mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add more conv2d support. (#340)
* Add more conv2d support. * Conv2d cpu work. * Conv2d output shape.
This commit is contained in:
@ -1033,6 +1033,21 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
const OP: &'static str = "conv2d";
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||
&self,
|
||||
_inp: &[T],
|
||||
_inp_l: &Layout,
|
||||
_k: &[T],
|
||||
_k_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
struct MatMul((usize, usize, usize, usize));
|
||||
|
||||
impl MatMul {
|
||||
@ -1804,6 +1819,16 @@ impl BackendStorage for CpuStorage {
|
||||
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
Conv2D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
match ids {
|
||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
|
Reference in New Issue
Block a user