From a044907ffce553a0394db3a1204f21e3691e54af Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 29 Aug 2023 16:12:11 +0100
Subject: [PATCH] Dilated convolutions (#657)

* Add the dilation parameter.

* Restore the basic optimizer example.

* Dilation support in cudnn.

* Use the dilation parameter in the cpu backend.

* More dilation support.

* No support for dilation in transposed convolutions.

* Add dilation to a test.

* Remove a print.

* Helper function.
---
 candle-core/examples/basics.rs              |   2 +-
 candle-core/examples/cpu_benchmarks.rs      |   4 +-
 candle-core/examples/cuda_basics.rs         |   6 +-
 candle-core/src/backprop.rs                 |  15 +-
 candle-core/src/conv.rs                     |  27 ++--
 candle-core/src/cpu_backend.rs              |  12 +-
 candle-core/src/cuda_backend.rs             |  15 +-
 candle-core/src/cudnn.rs                    |   2 +-
 candle-core/src/device.rs                   |  20 +++
 candle-core/src/op.rs                       |   3 +
 candle-core/tests/conv_tests.rs             | 135 ++++++++++++++++--
 .../examples/musicgen/encodec_model.rs      |   2 +
 .../examples/stable-diffusion/resnet.rs     |   2 +
 candle-examples/examples/whisper/model.rs   |   2 +
 candle-examples/examples/yolo-v3/darknet.rs |   1 +
 candle-examples/examples/yolo-v8/model.rs   |   1 +
 candle-kernels/src/conv.cu                  |  18 ++-
 candle-nn/src/conv.rs                       |   6 +
 candle-wasm-examples/whisper/src/model.rs   |   2 +
 candle-wasm-examples/yolo/src/model.rs      |   1 +
 20 files changed, 231 insertions(+), 45 deletions(-)

diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index 9d4734de..ad008177 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
     let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
     let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
     let start = std::time::Instant::now();
-    let res = inp.conv2d(&w, 0, 1, 1)?;
+    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
     println!("{:?}", start.elapsed());
     println!("{res:?}");
     Ok(())
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index 1ebd9b75..13175ac1 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
     }
 
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        d.0.conv1d(&d.1, 0, 1, 1)
+        d.0.conv1d(&d.1, 0, 1, 1, 1)
     }
 
     const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
     }
 
     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        d.0.conv2d(&d.1, 0, 1, 1)
+        d.0.conv2d(&d.1, 0, 1, 1, 1)
     }
 
     const ITERS: usize = 1;
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index cbdafd64..ad207461 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -11,11 +11,11 @@ fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
     let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
-    let out_t = in_t.conv2d(&k_t, 0, 1, 1)?;
+    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
     println!("{out_t}");
     let in_t = in_t.to_device(&Device::Cpu)?;
     let k_t = k_t.to_device(&Device::Cpu)?;
-    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?;
+    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
     let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
         .sqr()?
         .sum_all()?;
@@ -23,7 +23,7 @@
 
     let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
     let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
-    let res = t.conv2d(&w, 1, 1, 1)?;
+    let res = t.conv2d(&w, 1, 1, 1, 1)?;
     println!("{res:?}");
     Ok(())
 }
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 9ecdee4f..f4f90373 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -197,21 +197,28 @@ impl Tensor {
                     kernel,
                     padding,
                     stride,
+                    dilation,
                 } => {
                     // The output height for conv_transpose2d is:
                     // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
                     let grad_h = grad.dim(2)?;
                     let k_h = kernel.dim(2)?;
-                    let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+                    let out_size =
+                        (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                     let out_padding = arg.dim(2)? - out_size;
-                    let grad_arg =
-                        grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+                    let grad_arg = grad.conv_transpose2d(
+                        kernel,
+                        *padding,
+                        out_padding,
+                        *stride,
+                        *dilation,
+                    )?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                     let grad_kernel = arg
                         .transpose(0, 1)?
-                        .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+                        .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                         .transpose(0, 1)?;
                     let sum_grad = grads.or_insert(kernel)?;
                     *sum_grad = sum_grad.add(&grad_kernel)?;
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d9e0a9ab..1f3ef582 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -11,12 +11,12 @@ pub struct ParamsConv1D {
     pub(crate) k_size: usize,
     pub(crate) padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }
 
 impl ParamsConv1D {
     pub(crate) fn l_out(&self) -> usize {
-        let dilation = 1;
-        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -36,17 +36,16 @@ pub struct ParamsConv2D {
     pub(crate) c_in: usize,
     pub(crate) padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }
 
 impl ParamsConv2D {
     pub(crate) fn out_h(&self) -> usize {
-        let dilation = 1;
-        (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
     }
 
     pub(crate) fn out_w(&self) -> usize {
-        let dilation = 1;
-        (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
     pub(crate) padding: usize,
     pub(crate) output_padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }
 
 impl ParamsConvTranspose2D {
     pub(crate) fn out_h(&self) -> usize {
-        let dilation = 1;
-        (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
             - 2 * self.padding
     }
 
     pub(crate) fn out_w(&self) -> usize {
-        let dilation = 1;
-        (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
             - 2 * self.padding
     }
 
@@ -96,6 +94,7 @@ impl Tensor {
             kernel,
             padding: params.padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -107,6 +106,7 @@ impl Tensor {
         kernel: &Self,
         padding: usize,
         stride: usize,
+        dilation: usize,
         groups: usize,
     ) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.dims3()?;
@@ -130,6 +130,7 @@ impl Tensor {
             k_size,
             padding,
             stride,
+            dilation,
         };
         if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -154,6 +155,7 @@ impl Tensor {
             kernel,
             padding: params.padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -165,6 +167,7 @@ impl Tensor {
         kernel: &Self,
         padding: usize,
         stride: usize,
+        dilation: usize,
         groups: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
@@ -184,6 +187,7 @@ impl Tensor {
             c_in: c_in / groups,
             padding,
             stride,
+            dilation,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)
@@ -206,6 +210,7 @@ impl Tensor {
         padding: usize,
         output_padding: usize,
         stride: usize,
+        dilation: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
         let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@@ -223,6 +228,7 @@ impl Tensor {
             padding,
             output_padding,
             stride,
+            dilation,
         };
         let storage = self.storage().conv_transpose2d(
             self.layout(),
@@ -236,6 +242,7 @@ impl Tensor {
             padding: params.padding,
             output_padding: params.output_padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f52d53b1..60fac0c9 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
             let dst_idx = dst_idx + b_idx * p.c_out * l_out;
             for dst_l in 0..l_out {
                 let dst_idx = dst_idx + dst_l;
-                let src_l = p.stride * dst_l + offset;
+                let src_l = (p.stride * dst_l + offset) * p.dilation;
                 if src_l < p.padding || src_l >= p.padding + p.l_in {
                     continue;
                 }
@@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
             let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
             for dst_h in 0..out_h {
                 let dst_idx = dst_idx + dst_h * out_w;
-                let src_h = p.stride * dst_h + offset_h;
+                let src_h = (p.stride * dst_h + offset_h) * p.dilation;
                 if src_h < p.padding || src_h >= p.i_h + p.padding {
                     continue;
                 }
                 let src_h = src_h - p.padding;
                 for dst_w in 0..out_w {
                     let dst_idx = dst_idx + dst_w;
-                    let src_w = p.stride * dst_w + offset_w;
+                    let src_w = (p.stride * dst_w + offset_w) * p.dilation;
                     if src_w < p.padding || src_w >= p.i_w + p.padding {
                         continue;
                     }
@@ -1186,6 +1186,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
     const OP: &'static str = "conv_transpose2d";
     fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
         let p = self.0;
+        if p.dilation != 1 {
+            crate::bail!(
+                "dilation {} is not supported for conv-transpose2d",
+                p.dilation
+            )
+        }
         let inp = &inp[inp_l.start_offset()..];
         let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
         let k = &k[k_l.start_offset()..];
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index ed696368..cd06e8d7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> {
             crate::bail!("unexpected input shape for conv1d {dims:?}")
         };
         let ds = dev.htod_copy(ds).w()?;
-        let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
+        let params = (
+            el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
@@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> {
             crate::bail!("unexpected input shape for conv2d {dims:?}")
         };
         let ds = dev.htod_copy(ds).w()?;
-        let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
+        let params = (
+            el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
@@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         // Kernel shape: (c_in_k, c_out, h_k, w_k)
         // Input shape: (b_size, c_in, h_in, w_in)
         let p = &self.0;
+        if p.dilation != 1 {
+            crate::bail!(
+                "dilation {} is not supported for conv-transpose2d",
+                p.dilation
+            )
+        }
         let (out_w, out_h) = (p.out_w(), p.out_h());
         let dst_el = p.c_out * out_w * out_h * p.b_size;
         let inp = &inp.slice(inp_l.start_offset()..);
@@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
             p.stride,
             p.padding,
             p.output_padding,
+            p.dilation,
             &ds,
             inp,
             k,
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 3e943e51..235ad6e3 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -48,7 +48,7 @@ pub(crate) fn launch_conv2d<
     let conv = cudnn.create_conv2d::<T>(
         /* pad */ [params.padding as i32, params.padding as i32],
         /* stride */ [params.stride as i32, params.stride as i32],
-        /* dilation */ [1, 1],
+        /* dilation */ [params.dilation as i32, params.dilation as i32],
         cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
     )?;
     let x_shape = [
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 65232839..84716249 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
     }
 }
 
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
+    for &[[[[S; N4]; N3]; N2]; N1]
+{
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from((N1, N2, N3, N4)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
+        for i1 in 0..N1 {
+            for i2 in 0..N2 {
+                for i3 in 0..N3 {
+                    vec.extend(self[i1][i2][i3])
+                }
+            }
+        }
+        S::to_cpu_storage_owned(vec)
+    }
+}
+
 impl Device {
     pub fn new_cuda(ordinal: usize) -> Result<Self> {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b18f868d..3fe52ebc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -81,6 +81,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         stride: usize,
+        dilation: usize,
     },
 
     #[allow(dead_code)]
@@ -89,6 +90,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         stride: usize,
+        dilation: usize,
     },
 
     #[allow(dead_code)]
@@ -98,6 +100,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         output_padding: usize,
         stride: usize,
+        dilation: usize,
     },
     AvgPool2D {
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 1c378e5e..05015995 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -32,13 +32,13 @@ fn conv1d(dev: &Device) -> Result<()> {
         dev,
     )?
     .reshape((2, 4, 3))?;
-    let res = t.conv1d(&w, 0, 1, 1)?;
+    let res = t.conv1d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
     );
-    let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
+    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 5]);
     // Same as pytorch default padding: use zeros.
     assert_eq!(
@@ -51,13 +51,13 @@ fn conv1d_small(dev: &Device) -> Result<()> {
     let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
     let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
-    let res = t.conv1d(&w, 0, 1, 1)?;
+    let res = t.conv1d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 2]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.4056, -0.8689]
     );
-    let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
+    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 4]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -81,6 +81,10 @@ w_t = w.transpose(0, 1)
 res = torch.nn.functional.conv_transpose2d(t, w_t)
 print(res.shape)
 print(res)
+
+res = torch.nn.functional.conv2d(t, w, dilation=2)
+print(res.shape)
+print(res[0])
 */
 fn conv2d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -113,7 +117,7 @@ fn conv2d(dev: &Device) -> Result<()> {
     )?;
     let t = t.reshape((1, 4, 5, 5))?;
     let w = w.reshape((2, 4, 3, 3))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -122,7 +126,7 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 7, 7]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -147,6 +151,13 @@ fn conv2d(dev: &Device) -> Result<()> {
             ]
         ]
     );
+    // Dilations.
+    let res = t.conv2d(&w, 0, 1, 2, 1)?;
+    assert_eq!(res.dims(), [1, 2, 1, 1]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+        [2.45, -2.3504],
+    );
     Ok(())
 }
 
@@ -182,13 +193,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
     let t = t.reshape((1, 2, 3, 3))?;
     let w = w.reshape((1, 2, 1, 1))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
     );
-    let res = t.conv2d(&w, 2, 1, 1)?;
+    let res = t.conv2d(&w, 2, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 7, 7]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -200,13 +211,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
             0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
     );
-    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [2, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -230,7 +241,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
     let t = t.reshape((1, 1, 3, 3))?;
     let w = w.reshape((1, 1, 3, 3))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 1, 1]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -261,7 +272,7 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
     let t = t.reshape((1, 2, 4, 2))?;
     let w = w.reshape((1, 2, 1, 1))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 4, 2]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -270,6 +281,36 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
     Ok(())
 }
 
+/*
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 4, 5, 5), requires_grad=True)
+w = torch.randn((2, 4, 3, 3), requires_grad=True)
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv2d(t, w)
+print(res.flatten())
+loss = (res ** 2).sum()
+print(loss)
+loss.backward()
+print(t.grad.shape)
+print(t.grad.flatten())
+print(w.grad.shape)
+print(w.grad.flatten())
+
+t.grad.zero_()
+w.grad.zero_()
+res = torch.nn.functional.conv2d(t, w, stride=2)
+print(res.flatten())
+loss = (res ** 2).sum()
+print(loss)
+loss.backward()
+print(t.grad.shape)
+print(t.grad[0])
+print(w.grad.shape)
+print(w.grad[0])
+*/
 fn conv2d_grad(dev: &Device) -> Result<()> {
     use candle_core::Var;
     let t = Var::from_slice(
@@ -302,7 +343,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
         (2, 4, 3, 3),
         dev,
     )?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     let loss = res.sqr()?.sum_all()?;
     assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
     let grads = loss.backward()?;
@@ -335,6 +376,74 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
         ]
     );
+
+    // Same as before but with stride.
+    let res = t.conv2d(&w, 0, 2, 1, 1)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
+    let grads = loss.backward()?;
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
+        [
+            [
+                [9.29, -7.03, 0.94, 3.49, -7.71],
+                [-1.8, -7.82, 8.9, 8.46, 7.43],
+                [-25.84, 22.09, -19.27, -0.22, 1.69],
+                [4.02, 18.53, -18.37, 2.3, -24.51],
+                [7.72, -9.68, -12.34, 5.6, -20.22]
+            ],
+            [
+                [21.73, 3.39, -18.27, 3.86, -3.65],
+                [8.25, 3.73, 30.73, -8.61, -11.93],
+                [-72.15, -15.36, -17.53, -12.32, -1.61],
+                [-22.32, -7.79, -91.82, 6.44, -37.69],
+                [52.88, 14.44, 42.75, 9.88, 2.01]
+            ],
+            [
+                [-8.98, 9.91, 6.75, -4.68, 15.38],
+                [4.93, -0.33, 9.94, -1.46, 14.78],
+                [13.62, -30.63, 3.96, -3.58, -4.48],
+                [-14.13, 1.19, -34.43, 3.08, -33.83],
+                [17.28, 12.94, 31.83, -3.35, 6.81]
+            ],
+            [
+                [23.54, 6.98, -24.52, 0.52, 4.87],
+                [9.65, 6.18, 1.71, -25.23, -4.93],
+                [-54.99, -23.66, 3.19, -3.73, 18.58],
+                [-21.35, -10.39, -39.88, 28.73, -30.76],
+                [-9.13, 11.12, -14.0, -8.23, -11.25]
+            ]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
+        [
+            [
+                [28.34, -45.75, 7.32],
+                [0.72, -35.28, 19.23],
+                [-28.29, 20.89, -5.18]
+            ],
+            [
+                [-16.04, -16.38, 32.12],
+                [57.5, 25.81, 11.96],
+                [-18.66, 8.48, -9.92]
+            ],
+            [
+                [2.93, 1.57, -23.76],
+                [12.74, -26.2, -17.88],
+                [-14.98, -9.35, 12.2]
+            ],
+            [
+                [-0.18, -6.82, 20.79],
+                [-2.54, 27.11, -10.11],
+                [-0.41, -3.18, -0.07]
+            ]
+        ]
+    );
     Ok(())
 }
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index 86e3b6e9..53b252ed 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -278,6 +278,7 @@ impl EncodecConv1d {
                     padding: 0,
                     stride,
                     groups: 1,
+                    dilation: 1,
                 },
                 vb.pp("conv"),
             )?,
@@ -289,6 +290,7 @@ impl EncodecConv1d {
                     padding: 0,
                     stride,
                     groups: 1,
+                    dilation: 1,
                 },
                 vb.pp("conv"),
             )?,
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
index 5f6a2558..4cfd386d 100644
--- a/candle-examples/examples/stable-diffusion/resnet.rs
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -66,6 +66,7 @@ impl ResnetBlock2D {
             stride: 1,
             padding: 1,
             groups: 1,
+            dilation: 1,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
         let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -80,6 +81,7 @@ impl ResnetBlock2D {
             stride: 1,
             padding: 0,
             groups: 1,
+            dilation: 1,
         };
         Some(conv2d(
             in_channels,
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 9f0c9ef8..d6bea09a 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -281,11 +281,13 @@ impl AudioEncoder {
             padding: 1,
             stride: 1,
             groups: 1,
+            dilation: 1,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
             groups: 1,
+            dilation: 1,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs
index de8fcf09..0c81bca8 100644
--- a/candle-examples/examples/yolo-v3/darknet.rs
+++ b/candle-examples/examples/yolo-v3/darknet.rs
@@ -132,6 +132,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
         stride,
         padding,
         groups: 1,
+        dilation: 1,
     };
     let conv = if bias {
         conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index 98a0cb63..d7fe5c12 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -93,6 +93,7 @@ impl ConvBlock {
             padding,
             stride,
             groups: 1,
+            dilation: 1,
         };
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 5ccce317..c67a4300 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -8,6 +8,7 @@ __device__ void conv1d(
     const size_t l_out,
     const size_t stride,
     const size_t padding,
+    const size_t dilation,
     const size_t *info,
     const T *src,
     const T *kernel,
@@ -36,7 +37,7 @@ __device__ void conv1d(
   const size_t src_idx0 = b_idx * src_s[0];
   A d = 0;
   for (size_t offset = 0; offset < k_size; ++offset) {
-    size_t src_l = stride * dst_l + offset;
+    size_t src_l = (stride * dst_l + offset) * dilation;
    if (src_l < padding || src_l >= padding + l_in) {
       continue;
     }
@@ -58,6 +59,7 @@ __device__ void conv2d(
     const size_t h_out,
     const size_t stride,
     const size_t padding,
+    const size_t dilation,
     const size_t *info,
     const T *src,
     const T *kernel,
@@ -90,13 +92,13 @@ __device__ void conv2d(
   const size_t src_idx0 = b_idx * src_s[0];
   A d = 0;
   for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
-    size_t src_w = stride * dst_w + w_offset;
+    size_t src_w = (stride * dst_w + w_offset) * dilation;
     if (src_w < padding || src_w >= w_in + padding) {
       continue;
     }
     src_w -= padding;
     for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
-      size_t src_h = stride * dst_h + h_offset;
+      size_t src_h = (stride * dst_h + h_offset) * dilation;
       if (src_h < padding || src_h >= h_in + padding) {
         continue;
       }
@@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
     const size_t stride,
     const size_t padding,
     const size_t out_padding,
+    const size_t dilation,
     const size_t *info,
     const T *src,
     const T *kernel,
@@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
     const size_t num_dims, \
     const size_t stride, \
     const size_t padding, \
+    const size_t dilation, \
     const size_t *info, \
     const TYPENAME *src, \
     const TYPENAME *kernel, \
     TYPENAME *dst \
 ) { \
-  conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
+  conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
 } \
 
 #define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
     const size_t w_out, \
     const size_t h_out, \
     const size_t stride, \
     const size_t padding, \
+    const size_t dilation, \
     const size_t *info, \
     const TYPENAME *src, \
     const TYPENAME *kernel, \
     TYPENAME *dst \
 ) { \
-  conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
+  conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
 } \
 
 #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
     const size_t w_out, \
     const size_t h_out, \
     const size_t stride, \
     const size_t padding, \
     const size_t out_padding, \
+    const size_t dilation, \
     const size_t *info, \
     const TYPENAME *src, \
     const TYPENAME *kernel, \
     TYPENAME *dst \
 ) { \
-  conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+  conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
 } \
 
 #define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index e43de8ef..dbf23aa5 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
 pub struct Conv1dConfig {
     pub padding: usize,
     pub stride: usize,
+    pub dilation: usize,
     pub groups: usize,
 }
 
@@ -13,6 +14,7 @@ impl Default for Conv1dConfig {
         Self {
             padding: 0,
             stride: 1,
+            dilation: 1,
             groups: 1,
         }
     }
@@ -45,6 +47,7 @@ impl crate::Module for Conv1d {
             &self.weight,
             self.config.padding,
             self.config.stride,
+            self.config.dilation,
             self.config.groups,
         )?;
         match &self.bias {
@@ -62,6 +65,7 @@ impl crate::Module for Conv1d {
 pub struct Conv2dConfig {
     pub padding: usize,
     pub stride: usize,
+    pub dilation: usize,
     pub groups: usize,
 }
 
@@ -70,6 +74,7 @@ impl Default for Conv2dConfig {
         Self {
             padding: 0,
             stride: 1,
+            dilation: 1,
             groups: 1,
         }
     }
@@ -103,6 +108,7 @@ impl crate::Module for Conv2d {
             &self.weight,
             self.config.padding,
             self.config.stride,
+            self.config.dilation,
             self.config.groups,
         )?;
         match &self.bias {
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 72dbdcdd..239ceee5 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -269,11 +269,13 @@ impl AudioEncoder {
             padding: 1,
             stride: 1,
             groups: 1,
+            dilation: 1,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
             groups: 1,
+            dilation: 1,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs
index a63c6e94..e0fa7ac4 100644
--- a/candle-wasm-examples/yolo/src/model.rs
+++ b/candle-wasm-examples/yolo/src/model.rs
@@ -97,6 +97,7 @@ impl ConvBlock {
             padding,
             stride,
             groups: 1,
+            dilation: 1,
         };
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
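
The output-size rule used by this patch for the forward convolutions is (size + 2 * padding - dilation * (k - 1) - 1) / stride + 1, so a 3x3 kernel with dilation 2 spans 5 input positions and a 5x5 input collapses to a single output pixel, which is exactly what the new test asserts. Below is a minimal, illustrative caller of the updated 5-argument signature conv2d(&kernel, padding, stride, dilation, groups); it is a sketch written against this patch (the candle_core imports and shapes mirror the examples and the dilation test above, not code from the commit itself):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 3x3 kernel with dilation 2 covers (3 - 1) * 2 + 1 = 5 input positions,
    // so on a 5x5 input with no padding the output side is
    // (5 + 2 * 0 - 2 * (3 - 1) - 1) / 1 + 1 = 1.
    let t = Tensor::randn(0f32, 1., (1, 4, 5, 5), &dev)?;
    let w = Tensor::randn(0f32, 1., (2, 4, 3, 3), &dev)?;
    // conv2d(kernel, padding, stride, dilation, groups)
    let res = t.conv2d(&w, 0, 1, 2, 1)?;
    assert_eq!(res.dims(), [1, 2, 1, 1]);
    Ok(())
}

At the candle-nn layer the same knob is exposed as the new dilation field on Conv1dConfig and Conv2dConfig, which defaults to 1 (the previous behaviour); transposed convolutions reject any other value for now.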