Group support in conv-transpose2d.

This commit is contained in:
laurent
2023-09-09 18:11:41 +01:00
parent 31936c08fe
commit b4fe316fa1
6 changed files with 32 additions and 15 deletions

View File

@ -60,12 +60,13 @@ pub struct ParamsConvTranspose2D {
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_out_per_group: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub(crate) groups: usize,
}
impl ParamsConvTranspose2D {
@ -80,7 +81,8 @@ impl ParamsConvTranspose2D {
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
let c_out = self.c_out_per_group * self.groups;
vec![self.b_size, c_out, self.out_h(), self.out_w()]
}
}
@ -211,24 +213,29 @@ impl Tensor {
output_padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
if c_in % groups != 0 {
crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
}
let params = ParamsConvTranspose2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_out_per_group,
c_in,
padding,
output_padding,
stride,
dilation,
groups,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
@ -243,6 +250,7 @@ impl Tensor {
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
groups: params.groups,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))