Group support in conv-transpose2d.

This commit is contained in:
laurent
2023-09-09 18:11:41 +01:00
parent 31936c08fe
commit b4fe316fa1
6 changed files with 32 additions and 15 deletions

View File

@ -130,7 +130,7 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -164,7 +164,7 @@ fn conv2d(dev: &Device) -> Result<()> {
);
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -246,13 +246,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
);
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [2, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,