mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Group support in conv-transpose2d.
This commit is contained in:
@ -60,12 +60,13 @@ pub struct ParamsConvTranspose2D {
|
||||
pub(crate) i_w: usize,
|
||||
pub(crate) k_h: usize,
|
||||
pub(crate) k_w: usize,
|
||||
pub(crate) c_out: usize,
|
||||
pub(crate) c_out_per_group: usize,
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) output_padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
pub(crate) groups: usize,
|
||||
}
|
||||
|
||||
impl ParamsConvTranspose2D {
|
||||
@ -80,7 +81,8 @@ impl ParamsConvTranspose2D {
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
|
||||
let c_out = self.c_out_per_group * self.groups;
|
||||
vec![self.b_size, c_out, self.out_h(), self.out_w()]
|
||||
}
|
||||
}
|
||||
|
||||
@ -211,24 +213,29 @@ impl Tensor {
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
|
||||
let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
if c_in % groups != 0 {
|
||||
crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
|
||||
}
|
||||
let params = ParamsConvTranspose2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_out_per_group,
|
||||
c_in,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
};
|
||||
let storage = self.storage().conv_transpose2d(
|
||||
self.layout(),
|
||||
@ -243,6 +250,7 @@ impl Tensor {
|
||||
output_padding: params.output_padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
groups: params.groups,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
|
Reference in New Issue
Block a user