Proper conv1d dispatch.

This commit is contained in:
laurent
2023-07-04 11:29:28 +01:00
parent a424d95473
commit 950b4af49e
3 changed files with 34 additions and 10 deletions

View File

@ -1,6 +1,9 @@
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ParamsConv1D { pub(crate) struct ParamsConv1D {
pub(crate) b_size: Option<usize>, pub(crate) b_size: Option<usize>,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
pub(crate) c_out: usize, pub(crate) c_out: usize,
pub(crate) c_in: usize, pub(crate) c_in: usize,
pub(crate) k_size: usize, pub(crate) k_size: usize,
@ -9,13 +12,13 @@ pub(crate) struct ParamsConv1D {
} }
impl ParamsConv1D { impl ParamsConv1D {
pub(crate) fn l_out(&self, l_in: usize) -> usize { pub(crate) fn l_out(&self) -> usize {
let dilation = 1; let dilation = 1;
(l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1 (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
} }
pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> { pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out(l_in); let l_out = self.l_out();
match self.b_size { match self.b_size {
None => vec![self.c_out, l_out], None => vec![self.c_out, l_out],
Some(n) => vec![n, self.c_out, l_out], Some(n) => vec![n, self.c_out, l_out],

View File

@ -202,6 +202,26 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
} }
} }
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
fn f<T: 'static + num_traits::Num + Copy>(
&self,
_inp: &[T],
_inp_l: &Layout,
_k: &[T],
_k_l: &Layout,
) -> Result<Vec<T>> {
let p = self.0;
let l_out = p.l_out();
let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let dst = vec![T::zero(); out_elems];
// TODO: actually implement the ops.
Ok(dst)
}
}
struct MatMul((usize, usize, usize, usize)); struct MatMul((usize, usize, usize, usize));
impl Map2 for MatMul { impl Map2 for MatMul {
@ -629,12 +649,12 @@ impl CpuStorage {
pub(crate) fn conv1d( pub(crate) fn conv1d(
&self, &self,
_l: &Layout, l: &Layout,
_kernel: &Self, kernel: &Self,
_kernel_l: &Layout, kernel_l: &Layout,
_params: &crate::conv::ParamsConv1D, params: &crate::conv::ParamsConv1D,
) -> Result<Self> { ) -> Result<Self> {
todo!() Conv1D(params).map(self, l, kernel, kernel_l)
} }
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {

View File

@ -444,6 +444,7 @@ impl Tensor {
} }
let params = crate::conv::ParamsConv1D { let params = crate::conv::ParamsConv1D {
b_size, b_size,
l_in,
c_out, c_out,
c_in, c_in,
k_size, k_size,
@ -463,7 +464,7 @@ impl Tensor {
} else { } else {
None None
}; };
let out_dims = params.out_dims(l_in); let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false)) Ok(from_storage(storage, out_dims, op, false))
} }