Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)

* Skeleton for the avg-pool2d and upsample-nearest2d ops.

* Preliminary conv2d support.
This commit is contained in:
Laurent Mazare
2023-08-07 17:15:38 +02:00
committed by GitHub
parent f53a333ea9
commit 2345b8ce3f
7 changed files with 88 additions and 17 deletions

View File

@ -85,8 +85,16 @@ impl Conv2d {
&self.config
}
pub fn forward(&self, _x: &Tensor) -> Result<Tensor> {
todo!()
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
}