Group support in conv-transpose2d.

This commit is contained in:
laurent
2023-09-09 18:11:41 +01:00
parent 31936c08fe
commit b4fe316fa1
6 changed files with 32 additions and 15 deletions

View File

@ -127,7 +127,7 @@ pub struct ConvTranspose2dConfig {
pub output_padding: usize,
pub stride: usize,
pub dilation: usize,
// TODO: support groups.
pub groups: usize,
}
impl Default for ConvTranspose2dConfig {
@ -137,6 +137,7 @@ impl Default for ConvTranspose2dConfig {
output_padding: 0,
stride: 1,
dilation: 1,
groups: 1,
}
}
}
@ -170,6 +171,7 @@ impl crate::Module for ConvTranspose2d {
self.config.output_padding,
self.config.stride,
self.config.dilation,
self.config.groups,
)?;
match &self.bias {
None => Ok(x),