Sketch the conv1d op.

This commit is contained in:
laurent
2023-07-04 10:52:34 +01:00
parent e6b01d0c18
commit 3aac1047fe
7 changed files with 97 additions and 1 deletions

View File

@ -144,6 +144,33 @@ impl Storage {
}
}
pub(crate) fn conv1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
padding: usize,
stride: usize,
) -> Result<Self> {
self.same_device(kernel, "conv1d")?;
self.same_dtype(kernel, "conv1d")?;
match (self, &kernel) {
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
Ok(Self::Cpu(s))
}
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, padding, stride)?;
Ok(Self::Cuda(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "conv1d",
}),
}
}
pub(crate) fn where_cond(
&self,
layout: &Layout,