Add more conv2d support. (#340)

* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
This commit is contained in:
Laurent Mazare
2023-08-08 07:04:32 +02:00
committed by GitHub
parent d0d7010682
commit b5bb5e056d
7 changed files with 137 additions and 2 deletions

View File

@ -1033,6 +1033,21 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
&self,
_inp: &[T],
_inp_l: &Layout,
_k: &[T],
_k_l: &Layout,
) -> Result<Vec<T>> {
todo!()
}
}
struct MatMul((usize, usize, usize, usize));
impl MatMul {
@ -1804,6 +1819,16 @@ impl BackendStorage for CpuStorage {
Conv1D(params).map(self, l, kernel, kernel_l)
}
fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Conv2D(params).map(self, l, kernel, kernel_l)
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),