Add more conv2d support. (#340)

* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
This commit is contained in:
Laurent Mazare
2023-08-08 07:04:32 +02:00
committed by GitHub
parent d0d7010682
commit b5bb5e056d
7 changed files with 137 additions and 2 deletions

View File

@ -817,8 +817,34 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
todo!()
pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = crate::conv::ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
stride,
};
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding,
stride,
});
let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false))
}
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {