Add more conv2d support. (#340)

* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
This commit is contained in:
Laurent Mazare
2023-08-08 07:04:32 +02:00
committed by GitHub
parent d0d7010682
commit b5bb5e056d
7 changed files with 137 additions and 2 deletions

View File

@ -266,6 +266,33 @@ impl Storage {
}
}
pub(crate) fn conv2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
self.same_device(kernel, "conv2d")?;
self.same_dtype(kernel, "conv2d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv2d",
}
.bt()),
}
}
pub(crate) fn avg_pool2d(
&self,
layout: &Layout,