From b5bb5e056d838ad23a95e3feaf464bdca2b677cd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 8 Aug 2023 07:04:32 +0200 Subject: [PATCH] Add more conv2d support. (#340) * Add more conv2d support. * Conv2d cpu work. * Conv2d output shape. --- candle-core/src/backend.rs | 8 +++++++ candle-core/src/conv.rs | 29 ++++++++++++++++++++++++++ candle-core/src/cpu_backend.rs | 25 ++++++++++++++++++++++ candle-core/src/cuda_backend.rs | 10 +++++++++ candle-core/src/dummy_cuda_backend.rs | 10 +++++++++ candle-core/src/storage.rs | 27 ++++++++++++++++++++++++ candle-core/src/tensor.rs | 30 +++++++++++++++++++++++++-- 7 files changed, 137 insertions(+), 2 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index c3f8aa3c..a8e5ac52 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -37,6 +37,14 @@ pub trait BackendStorage: Sized { _params: &crate::conv::ParamsConv1D, ) -> Result; + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv2D, + ) -> Result; + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 4cf9d0ad..30799459 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -25,3 +25,32 @@ impl ParamsConv1D { } } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConv2D { + pub(crate) b_size: usize, + pub(crate) i_h: usize, + pub(crate) i_w: usize, + pub(crate) k_h: usize, + pub(crate) k_w: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, +} + +impl ParamsConv2D { + pub(crate) fn out_h(&self) -> usize { + let dilation = 1; + (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_w(&self) -> usize { + let dilation = 1; + (self.i_w + 2 * 
self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec { + vec![self.b_size, self.c_out, self.out_h(), self.out_w()] + } +} diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index a04ed9a0..c997d767 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1033,6 +1033,21 @@ impl<'a> Map2 for Conv1D<'a> { } } +struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); + +impl<'a> Map2 for Conv2D<'a> { + const OP: &'static str = "conv2d"; + fn f( + &self, + _inp: &[T], + _inp_l: &Layout, + _k: &[T], + _k_l: &Layout, + ) -> Result> { + todo!() + } +} + struct MatMul((usize, usize, usize, usize)); impl MatMul { @@ -1804,6 +1819,16 @@ impl BackendStorage for CpuStorage { Conv1D(params).map(self, l, kernel, kernel_l) } + fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result { + Conv2D(params).map(self, l, kernel, kernel_l) + } + fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 3c37373a..727ea073 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1381,6 +1381,16 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv2D, + ) -> Result { + todo!() + } + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { todo!() } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 99cb7c4e..ae4dd09f 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage { 
Err(Error::NotCompiledWithCudaSupport) } + fn conv2d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv2D, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index b4fa02e8..3ed38e6a 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -266,6 +266,33 @@ impl Storage { } } + pub(crate) fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result { + self.same_device(kernel, "conv2d")?; + self.same_dtype(kernel, "conv2d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv2d", + } + .bt()), + } + } + pub(crate) fn avg_pool2d( &self, layout: &Layout, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ffa4bf8c..adba7376 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -817,8 +817,34 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } - pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result { - todo!() + pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k { + crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") + } + let params = crate::conv::ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + stride, + }; + let storage = + 
self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding, + stride, + }); + let out_dims = params.out_dims(); + Ok(from_storage(storage, out_dims, op, false)) } pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result {