mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions. * Small tweak. * Attempt at using apply to simplify the convnet definition.
This commit is contained in:
@ -256,7 +256,7 @@ impl Tensor {
|
||||
// we scale the gradient for this case).
|
||||
let node_upsampled = node.upsample_nearest2d(h, w)?;
|
||||
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
|
||||
let avg = mask.avg_pool2d(*kernel_size, *stride)?;
|
||||
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
|
||||
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
@ -91,3 +91,36 @@ extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
pub trait ToUsize2 {
|
||||
fn to_usize2(self) -> (usize, usize);
|
||||
}
|
||||
|
||||
impl ToUsize2 for usize {
|
||||
fn to_usize2(self) -> (usize, usize) {
|
||||
(self, self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToUsize2 for (usize, usize) {
|
||||
fn to_usize2(self) -> (usize, usize) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// A simple trait defining a module with forward method using a single argument.
|
||||
pub trait Module: std::fmt::Debug {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||
|
||||
/// Change the module to use training mode vs eval mode.
|
||||
///
|
||||
/// The default implementation does nothing as this is only used for a couple modules such as
|
||||
/// dropout or batch-normalization.
|
||||
fn set_training(&mut self, _training: bool) {}
|
||||
}
|
||||
|
||||
impl Module for quantized::QMatMul {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.forward(xs)
|
||||
}
|
||||
}
|
||||
|
@ -797,7 +797,18 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
|
||||
}
|
||||
|
||||
pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
|
||||
let sz = sz.to_usize2();
|
||||
self.avg_pool2d_with_stride(sz, sz)
|
||||
}
|
||||
|
||||
pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
|
||||
&self,
|
||||
kernel_size: T,
|
||||
stride: T,
|
||||
) -> Result<Self> {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
@ -813,7 +824,18 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
||||
pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
|
||||
let sz = sz.to_usize2();
|
||||
self.max_pool2d_with_stride(sz, sz)
|
||||
}
|
||||
|
||||
pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
|
||||
&self,
|
||||
kernel_size: T,
|
||||
stride: T,
|
||||
) -> Result<Self> {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
@ -1855,6 +1877,10 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||
m.forward(self)
|
||||
}
|
||||
|
||||
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
@ -6,14 +6,14 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
|
||||
|
||||
let data: Vec<f32> = vec![
|
||||
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
|
||||
Ok(())
|
||||
}
|
||||
@ -24,11 +24,11 @@ fn max_pool2d(dev: &Device) -> Result<()> {
|
||||
];
|
||||
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
|
||||
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
|
||||
|
||||
let t = t.reshape((1, 1, 2, 8))?;
|
||||
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
|
||||
Ok(())
|
||||
}
|
||||
@ -53,7 +53,7 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
dev,
|
||||
)?
|
||||
.reshape((1, 2, 4, 4))?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&pool, 4)?,
|
||||
[
|
||||
@ -61,14 +61,14 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
[[0.1835, -0.1606], [0.6249, 0.3217]]
|
||||
]
|
||||
);
|
||||
let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d(3)?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&pool, 4)?,
|
||||
[[[0.085]], [[0.0078]]]
|
||||
);
|
||||
|
||||
let t = t.reshape((1, 1, 4, 8))?;
|
||||
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
|
||||
let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&pool, 4)?,
|
||||
[
|
||||
|
Reference in New Issue
Block a user