Add more of the conv1d op.

This commit is contained in:
laurent
2023-07-04 11:15:45 +01:00
parent 3aac1047fe
commit a424d95473
8 changed files with 52 additions and 19 deletions

View File

@ -236,8 +236,7 @@ impl Conv1D {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let (bsize, _, _) = x.shape().r3()?;
let w = self.weight.broadcast_left(bsize)?.t()?;
// TODO: Add the conv1d operation
let x = x.matmul(&w)?;
let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),