From a17a7c42c1b95f4d710d95be86efcc2665eadb19 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 7 Sep 2023 06:47:28 +0200
Subject: [PATCH] Add a nn layer for conv-transpose2d. (#760)

---
 candle-nn/src/conv.rs | 51 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 50 insertions(+), 1 deletion(-)

diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index dbf23aa5..f985cfd6 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -80,7 +80,6 @@ impl Default for Conv2dConfig {
     }
 }
 
-#[allow(dead_code)]
 #[derive(Debug)]
 pub struct Conv2d {
     weight: Tensor,
@@ -122,6 +121,56 @@ impl crate::Module for Conv2d {
     }
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ConvTranspose2dConfig {
+    pub padding: usize,
+    pub output_padding: usize,
+    pub stride: usize,
+    pub dilation: usize,
+    // TODO: support groups.
+}
+
+#[derive(Debug)]
+pub struct ConvTranspose2d {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    config: ConvTranspose2dConfig,
+}
+
+impl ConvTranspose2d {
+    pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
+        Self {
+            weight,
+            bias,
+            config,
+        }
+    }
+
+    pub fn config(&self) -> &ConvTranspose2dConfig {
+        &self.config
+    }
+}
+
+impl crate::Module for ConvTranspose2d {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.conv_transpose2d(
+            &self.weight,
+            self.config.padding,
+            self.config.output_padding,
+            self.config.stride,
+            self.config.dilation,
+        )?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => {
+                let b = bias.dims1()?;
+                let bias = bias.reshape((1, b, 1, 1))?;
+                Ok(x.broadcast_add(&bias)?)
+            }
+        }
+    }
+}
+
 pub fn conv1d(
     in_channels: usize,
     out_channels: usize,
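
Not part of the patch: a minimal usage sketch of the new layer. It assumes the core crate is imported as `candle` (the in-workspace alias for candle-core), that the items are reached through the public `candle_nn::conv` module (this commit does not touch the crate-root re-exports), and a PyTorch-style weight layout of (c_in, c_out, k_h, k_w); the shapes and config values below are illustrative only.

use candle::{Device, Result, Tensor};
use candle_nn::conv::{ConvTranspose2d, ConvTranspose2dConfig};
use candle_nn::Module;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Random weight for the transposed conv: 4 input channels, 2 output channels,
    // 3x3 kernel (assumed layout (c_in, c_out, k_h, k_w), as in PyTorch's ConvTranspose2d).
    let weight = Tensor::randn(0f32, 1f32, (4, 2, 3, 3), &dev)?;
    // The config derives no Default in this patch, so all fields are set explicitly.
    let cfg = ConvTranspose2dConfig {
        padding: 0,
        output_padding: 0,
        stride: 2,
        dilation: 1,
    };
    // No bias in this sketch; pass Some(tensor) with shape (c_out,) to add one.
    let layer = ConvTranspose2d::new(weight, None, cfg);

    // A single 4-channel 8x8 input, upsampled by the stride-2 transposed conv:
    // (h - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1 = 17.
    let x = Tensor::randn(0f32, 1f32, (1, 4, 8, 8), &dev)?;
    let y = layer.forward(&x)?;
    println!("{:?}", y.dims()); // expected: [1, 2, 17, 17]
    Ok(())
}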