diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 90bb5229..041bb6fb 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,6 +1,9 @@
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub(crate) struct ParamsConv1D {
     pub(crate) b_size: Option<usize>,
+    // Maybe we should have a version without l_in as this bit depends on the input and not only on
+    // the weights.
+    pub(crate) l_in: usize,
     pub(crate) c_out: usize,
     pub(crate) c_in: usize,
     pub(crate) k_size: usize,
@@ -9,13 +12,13 @@ pub(crate) struct ParamsConv1D {
 }
 
 impl ParamsConv1D {
-    pub(crate) fn l_out(&self, l_in: usize) -> usize {
+    pub(crate) fn l_out(&self) -> usize {
         let dilation = 1;
-        (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
     }
 
-    pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
-        let l_out = self.l_out(l_in);
+    pub(crate) fn out_dims(&self) -> Vec<usize> {
+        let l_out = self.l_out();
         match self.b_size {
             None => vec![self.c_out, l_out],
             Some(n) => vec![n, self.c_out, l_out],
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 54002184..718b071c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -202,6 +202,26 @@ fn copy_strided_src_(
     }
 }
 
+struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
+
+impl<'a> Map2 for Conv1D<'a> {
+    const OP: &'static str = "conv1d";
+    fn f<T: WithDType>(
+        &self,
+        _inp: &[T],
+        _inp_l: &Layout,
+        _k: &[T],
+        _k_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let p = self.0;
+        let l_out = p.l_out();
+        let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+        let dst = vec![T::zero(); out_elems];
+        // TODO: actually implement the ops.
+        Ok(dst)
+    }
+}
+
 struct MatMul((usize, usize, usize, usize));
 
 impl Map2 for MatMul {
@@ -629,12 +649,12 @@ impl CpuStorage {
 
     pub(crate) fn conv1d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConv1D,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
-        todo!()
+        Conv1D(params).map(self, l, kernel, kernel_l)
     }
 
     pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 590b81c4..25ab0a9b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -444,6 +444,7 @@ impl Tensor {
         }
         let params = crate::conv::ParamsConv1D {
             b_size,
+            l_in,
             c_out,
             c_in,
             k_size,
@@ -463,7 +464,7 @@ impl Tensor {
         } else {
             None
         };
-        let out_dims = params.out_dims(l_in);
+        let out_dims = params.out_dims();
         Ok(from_storage(storage, out_dims, op, false))
     }