Group support in conv-transpose2d.

This commit is contained in:
laurent
2023-09-09 18:11:41 +01:00
parent 31936c08fe
commit b4fe316fa1
6 changed files with 32 additions and 15 deletions

View File

@ -213,6 +213,7 @@ impl Tensor {
out_padding, out_padding,
*stride, *stride,
*dilation, *dilation,
1,
)?; )?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;

View File

@ -60,12 +60,13 @@ pub struct ParamsConvTranspose2D {
pub(crate) i_w: usize, pub(crate) i_w: usize,
pub(crate) k_h: usize, pub(crate) k_h: usize,
pub(crate) k_w: usize, pub(crate) k_w: usize,
pub(crate) c_out: usize, pub(crate) c_out_per_group: usize,
pub(crate) c_in: usize, pub(crate) c_in: usize,
pub(crate) padding: usize, pub(crate) padding: usize,
pub(crate) output_padding: usize, pub(crate) output_padding: usize,
pub(crate) stride: usize, pub(crate) stride: usize,
pub(crate) dilation: usize, pub(crate) dilation: usize,
pub(crate) groups: usize,
} }
impl ParamsConvTranspose2D { impl ParamsConvTranspose2D {
@ -80,7 +81,8 @@ impl ParamsConvTranspose2D {
} }
pub(crate) fn out_dims(&self) -> Vec<usize> { pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()] let c_out = self.c_out_per_group * self.groups;
vec![self.b_size, c_out, self.out_h(), self.out_w()]
} }
} }
@ -211,24 +213,29 @@ impl Tensor {
output_padding: usize, output_padding: usize,
stride: usize, stride: usize,
dilation: usize, dilation: usize,
groups: usize,
) -> Result<Self> { ) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?; let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?; let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k { if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
} }
if c_in % groups != 0 {
crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
}
let params = ParamsConvTranspose2D { let params = ParamsConvTranspose2D {
b_size, b_size,
i_h, i_h,
i_w, i_w,
k_h, k_h,
k_w, k_w,
c_out, c_out_per_group,
c_in, c_in,
padding, padding,
output_padding, output_padding,
stride, stride,
dilation, dilation,
groups,
}; };
let storage = self.storage().conv_transpose2d( let storage = self.storage().conv_transpose2d(
self.layout(), self.layout(),
@ -243,6 +250,7 @@ impl Tensor {
output_padding: params.output_padding, output_padding: params.output_padding,
stride: params.stride, stride: params.stride,
dilation: params.dilation, dilation: params.dilation,
groups: params.groups,
}); });
let out_dims = params.out_dims(); let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false)) Ok(crate::tensor::from_storage(storage, out_dims, op, false))

View File

@ -1190,8 +1190,9 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
let (out_h, out_w) = (p.out_h(), p.out_w()); let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w]. // Output shape: [b_size, c_out, out_h, out_w].
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; let c_out = p.groups * p.c_out_per_group;
let dst_s0 = p.c_out * out_h * out_w; let dst = vec![T::zero(); p.b_size * c_out * out_h * out_w];
let dst_s0 = c_out * out_h * out_w;
let dst_s1 = out_h * out_w; let dst_s1 = out_h * out_w;
let dst_s2 = out_w; let dst_s2 = out_w;
let dst_s3 = 1; let dst_s3 = 1;
@ -1214,12 +1215,16 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
} }
} }
let c_in_per_group = p.c_in / p.groups;
for k_y in 0..p.k_h { for k_y in 0..p.k_h {
for k_x in 0..p.k_w { for k_x in 0..p.k_w {
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| { (0..c_out).into_par_iter().for_each(|dst_c_idx| {
let k_cont = (0..p.c_in) let (group_idx, dst_c_idx_in_group) =
(c_out / p.c_out_per_group, c_out % p.c_out_per_group);
let k_cont = (0..c_in_per_group)
.map(|c_in_idx| { .map(|c_in_idx| {
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3] let c_in_idx = group_idx * c_in_per_group + c_in_idx;
k[c_in_idx * k_s0 + dst_c_idx_in_group * k_s1 + k_y * k_s2 + k_x * k_s3]
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for b_idx in 0..p.b_size { for b_idx in 0..p.b_size {
@ -1245,7 +1250,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
inp_cont.as_ptr(), inp_cont.as_ptr(),
k_cont.as_ptr(), k_cont.as_ptr(),
&mut d, &mut d,
p.c_in, c_in_per_group,
) )
} }
let dst_p = dst.as_ptr(); let dst_p = dst.as_ptr();

View File

@ -102,6 +102,7 @@ pub enum Op {
output_padding: usize, output_padding: usize,
stride: usize, stride: usize,
dilation: usize, dilation: usize,
groups: usize,
}, },
AvgPool2D { AvgPool2D {

View File

@ -130,7 +130,7 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
] ]
); );
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]); assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!( assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?, test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -164,7 +164,7 @@ fn conv2d(dev: &Device) -> Result<()> {
); );
// Transpose and dilations. // Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 9, 9]); assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!( assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?, test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -246,13 +246,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
] ]
); );
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539], [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
); );
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?; let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [2, 2, 3, 3]); assert_eq!(res.dims(), [2, 2, 3, 3]);
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,

View File

@ -127,7 +127,7 @@ pub struct ConvTranspose2dConfig {
pub output_padding: usize, pub output_padding: usize,
pub stride: usize, pub stride: usize,
pub dilation: usize, pub dilation: usize,
// TODO: support groups. pub groups: usize,
} }
impl Default for ConvTranspose2dConfig { impl Default for ConvTranspose2dConfig {
@ -137,6 +137,7 @@ impl Default for ConvTranspose2dConfig {
output_padding: 0, output_padding: 0,
stride: 1, stride: 1,
dilation: 1, dilation: 1,
groups: 1,
} }
} }
} }
@ -170,6 +171,7 @@ impl crate::Module for ConvTranspose2d {
self.config.output_padding, self.config.output_padding,
self.config.stride, self.config.stride,
self.config.dilation, self.config.dilation,
self.config.groups,
)?; )?;
match &self.bias { match &self.bias {
None => Ok(x), None => Ok(x),