mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Sketch the conv1d op.
This commit is contained in:
@ -432,6 +432,28 @@ impl Tensor {
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let storage = self.storage.conv1d(
|
||||
self.layout(),
|
||||
&kernel.storage,
|
||||
kernel.layout(),
|
||||
padding,
|
||||
stride,
|
||||
)?;
|
||||
let op = if self.track_op() || kernel.track_op() {
|
||||
Some(Op::Conv1D {
|
||||
arg: self.clone(),
|
||||
kernel: kernel.clone(),
|
||||
padding,
|
||||
stride,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let dims = self.dims();
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||
let a_dims = self.shape().dims();
|
||||
let b_dims = rhs.shape().dims();
|
||||
|
Reference in New Issue
Block a user