mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions. * Small tweak. * Attempt at using apply to simplify the convnet definition.
This commit is contained in:
@ -797,7 +797,18 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
|
||||
}
|
||||
|
||||
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
|
||||
let sz = sz.to_usize2();
|
||||
self.avg_pool2d_with_stride(sz, sz)
|
||||
}
|
||||
|
||||
pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
|
||||
&self,
|
||||
kernel_size: T,
|
||||
stride: T,
|
||||
) -> Result<Self> {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
@ -813,7 +824,18 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
|
||||
let sz = sz.to_usize2();
|
||||
self.max_pool2d_with_stride(sz, sz)
|
||||
}
|
||||
|
||||
pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
|
||||
&self,
|
||||
kernel_size: T,
|
||||
stride: T,
|
||||
) -> Result<Self> {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
@ -1855,6 +1877,10 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||
m.forward(self)
|
||||
}
|
||||
|
||||
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
Reference in New Issue
Block a user