Add more conv2d support. (#340)

* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
This commit is contained in:
Laurent Mazare
2023-08-08 07:04:32 +02:00
committed by GitHub
parent d0d7010682
commit b5bb5e056d
7 changed files with 137 additions and 2 deletions

View File

@ -37,6 +37,14 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D, _params: &crate::conv::ParamsConv1D,
) -> Result<Self>; ) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConv2D,
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

View File

@ -25,3 +25,32 @@ impl ParamsConv1D {
} }
} }
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}

View File

@ -1033,6 +1033,21 @@ impl<'a> Map2 for Conv1D<'a> {
} }
} }
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
&self,
_inp: &[T],
_inp_l: &Layout,
_k: &[T],
_k_l: &Layout,
) -> Result<Vec<T>> {
todo!()
}
}
struct MatMul((usize, usize, usize, usize)); struct MatMul((usize, usize, usize, usize));
impl MatMul { impl MatMul {
@ -1804,6 +1819,16 @@ impl BackendStorage for CpuStorage {
Conv1D(params).map(self, l, kernel, kernel_l) Conv1D(params).map(self, l, kernel, kernel_l)
} }
fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Conv2D(params).map(self, l, kernel, kernel_l)
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids { match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),

View File

@ -1381,6 +1381,16 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
todo!()
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!() todo!()
} }

View File

@ -75,6 +75,16 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> { fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }

View File

@ -266,6 +266,33 @@ impl Storage {
} }
} }
pub(crate) fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
self.same_device(kernel, "conv2d")?;
self.same_dtype(kernel, "conv2d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv2d",
}
.bt()),
}
}
pub(crate) fn avg_pool2d( pub(crate) fn avg_pool2d(
&self, &self,
layout: &Layout, layout: &Layout,

View File

@ -817,8 +817,34 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false)) Ok(from_storage(storage, out_dims, op, false))
} }
pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> { pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
todo!() let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = crate::conv::ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
stride,
};
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding,
stride,
});
let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false))
} }
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {