diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d2099df7..b86d51dd 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -213,6 +213,7 @@ impl Tensor {
                     out_padding,
                     *stride,
                     *dilation,
+                    1,
                 )?;
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.add(&grad_arg)?;
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 1f3ef582..056edcd4 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -60,12 +60,13 @@ pub struct ParamsConvTranspose2D {
     pub(crate) i_w: usize,
     pub(crate) k_h: usize,
     pub(crate) k_w: usize,
-    pub(crate) c_out: usize,
+    pub(crate) c_out_per_group: usize,
     pub(crate) c_in: usize,
     pub(crate) padding: usize,
     pub(crate) output_padding: usize,
     pub(crate) stride: usize,
     pub(crate) dilation: usize,
+    pub(crate) groups: usize,
 }
 
 impl ParamsConvTranspose2D {
@@ -80,7 +81,8 @@ impl ParamsConvTranspose2D {
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
-        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+        let c_out = self.c_out_per_group * self.groups;
+        vec![self.b_size, c_out, self.out_h(), self.out_w()]
     }
 }
 
@@ -211,24 +213,29 @@ impl Tensor {
         output_padding: usize,
         stride: usize,
         dilation: usize,
+        groups: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
-        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
+        let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
         if c_in != c_in_k {
             crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
         }
+        if c_in % groups != 0 {
+            crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
+        }
         let params = ParamsConvTranspose2D {
             b_size,
             i_h,
             i_w,
             k_h,
             k_w,
-            c_out,
+            c_out_per_group,
             c_in,
             padding,
             output_padding,
             stride,
             dilation,
+            groups,
         };
         let storage = self.storage().conv_transpose2d(
             self.layout(),
@@ -243,6 +250,7 @@
             output_padding: params.output_padding,
             stride: params.stride,
             dilation: params.dilation,
+            groups: params.groups,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 01ccfde7..11250e82 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1190,8 +1190,9 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         let (out_h, out_w) = (p.out_h(), p.out_w());
 
         // Output shape: [b_size, c_out, out_h, out_w].
-        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
-        let dst_s0 = p.c_out * out_h * out_w;
+        let c_out = p.groups * p.c_out_per_group;
+        let dst = vec![T::zero(); p.b_size * c_out * out_h * out_w];
+        let dst_s0 = c_out * out_h * out_w;
         let dst_s1 = out_h * out_w;
         let dst_s2 = out_w;
         let dst_s3 = 1;
@@ -1214,12 +1215,16 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
             }
         }
 
+        let c_in_per_group = p.c_in / p.groups;
         for k_y in 0..p.k_h {
             for k_x in 0..p.k_w {
-                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
-                    let k_cont = (0..p.c_in)
+                (0..c_out).into_par_iter().for_each(|dst_c_idx| {
+                    let (group_idx, dst_c_idx_in_group) =
+                        (dst_c_idx / p.c_out_per_group, dst_c_idx % p.c_out_per_group);
+                    let k_cont = (0..c_in_per_group)
                         .map(|c_in_idx| {
-                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
+                            let c_in_idx = group_idx * c_in_per_group + c_in_idx;
+                            k[c_in_idx * k_s0 + dst_c_idx_in_group * k_s1 + k_y * k_s2 + k_x * k_s3]
                         })
                         .collect::<Vec<_>>();
                     for b_idx in 0..p.b_size {
@@ -1245,7 +1250,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
                                 inp_cont.as_ptr(),
                                 k_cont.as_ptr(),
                                 &mut d,
-                                p.c_in,
+                                c_in_per_group,
                             )
                         }
                        let dst_p = dst.as_ptr();
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 9382b217..206d536f 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -102,6 +102,7 @@ pub enum Op {
         output_padding: usize,
         stride: usize,
         dilation: usize,
+        groups: usize,
     },
 
     AvgPool2D {
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 937ddf67..720e51e1 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -130,7 +130,7 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7, 7]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -164,7 +164,7 @@
     );
 
     // Transpose and dilations.
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2, 1)?;
     assert_eq!(res.dims(), [1, 2, 9, 9]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -246,13 +246,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
             0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
     );
-    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
+    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [2, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index cfe86bfa..9becdb73 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -127,7 +127,7 @@ pub struct ConvTranspose2dConfig {
     pub output_padding: usize,
     pub stride: usize,
     pub dilation: usize,
-    // TODO: support groups.
+    pub groups: usize,
 }
 
 impl Default for ConvTranspose2dConfig {
@@ -137,6 +137,7 @@
             output_padding: 0,
             stride: 1,
             dilation: 1,
+            groups: 1,
         }
     }
 }
@@ -170,6 +171,7 @@ impl crate::Module for ConvTranspose2d {
             self.config.output_padding,
             self.config.stride,
             self.config.dilation,
+            self.config.groups,
         )?;
         match &self.bias {
             None => Ok(x),