Sketch the conv1d op.

This commit is contained in:
laurent
2023-07-04 10:52:34 +01:00
parent e6b01d0c18
commit 3aac1047fe
7 changed files with 97 additions and 1 deletions

View File

@ -432,6 +432,28 @@ impl Tensor {
Ok(from_storage(storage, dims, op, false))
}
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let storage = self.storage.conv1d(
self.layout(),
&kernel.storage,
kernel.layout(),
padding,
stride,
)?;
let op = if self.track_op() || kernel.track_op() {
Some(Op::Conv1D {
arg: self.clone(),
kernel: kernel.clone(),
padding,
stride,
})
} else {
None
};
let dims = self.dims();
Ok(from_storage(storage, dims, op, false))
}
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
let a_dims = self.shape().dims();
let b_dims = rhs.shape().dims();