mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)
* Skeleton for the avg-pool2d and upsample-nearest2d ops. * Preliminary conv2d support.
This commit is contained in:
@ -817,6 +817,35 @@ impl Tensor {
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest2d(self.layout(), target_h, target_w)?;
|
||||
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
|
||||
}
|
||||
|
||||
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
|
||||
arg,
|
||||
kernel_size,
|
||||
stride,
|
||||
});
|
||||
let storage = self
|
||||
.storage()
|
||||
.avg_pool2d(self.layout(), kernel_size, stride)?;
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
|
Reference in New Issue
Block a user