From 3b0d1e7d03469ed17d1eec931bd76c857b99ff3a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 3 Nov 2023 11:18:25 +0100 Subject: [PATCH] Transposed conv1d in candle-nn. (#1252) --- candle-nn/src/conv.rs | 94 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 7c0bf841..b1168405 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -70,6 +70,67 @@ impl crate::Module for Conv1d { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConvTranspose1dConfig { + pub padding: usize, + pub output_padding: usize, + pub stride: usize, + pub dilation: usize, + // TODO: support groups. +} + +impl Default for ConvTranspose1dConfig { + fn default() -> Self { + Self { + padding: 0, + output_padding: 0, + stride: 1, + dilation: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConvTranspose1d { + weight: Tensor, + bias: Option, + config: ConvTranspose1dConfig, +} + +impl ConvTranspose1d { + pub fn new(weight: Tensor, bias: Option, config: ConvTranspose1dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &ConvTranspose1dConfig { + &self.config + } +} + +impl crate::Module for ConvTranspose1d { + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv_transpose1d( + &self.weight, + self.config.padding, + self.config.output_padding, + self.config.stride, + self.config.dilation, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Conv2dConfig { pub padding: usize, @@ -241,6 +302,39 @@ pub fn conv1d( Ok(Conv1d::new(ws, Some(bs), cfg)) } +pub fn conv_transpose1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose1dConfig, + vb: crate::VarBuilder, +) -> Result { + let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let bs = vb.get_with_hints(out_channels, "bias", init)?; + Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) +} + +pub fn conv_transpose1d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose1dConfig, + vb: crate::VarBuilder, +) -> Result { + let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + Ok(ConvTranspose1d::new(ws, None, cfg)) +} + pub fn conv2d( in_channels: usize, out_channels: usize,