diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f4f90373..c6d55e61 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -256,7 +256,7 @@ impl Tensor {
                     // we scale the gradient for this case).
                     let node_upsampled = node.upsample_nearest2d(h, w)?;
                     let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
-                    let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+                    let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
                     let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index fa85f6e0..a0347416 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -91,3 +91,36 @@ extern crate intel_mkl_src;
 
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
+
+pub trait ToUsize2 {
+    fn to_usize2(self) -> (usize, usize);
+}
+
+impl ToUsize2 for usize {
+    fn to_usize2(self) -> (usize, usize) {
+        (self, self)
+    }
+}
+
+impl ToUsize2 for (usize, usize) {
+    fn to_usize2(self) -> (usize, usize) {
+        self
+    }
+}
+
+// A simple trait defining a module with forward method using a single argument.
+pub trait Module: std::fmt::Debug {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+
+    /// Change the module to use training mode vs eval mode.
+    ///
+    /// The default implementation does nothing as this is only used for a couple modules such as
+    /// dropout or batch-normalization.
+    fn set_training(&mut self, _training: bool) {}
+}
+
+impl Module for quantized::QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 75b3743d..f834e040 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -797,7 +797,18 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
     }
 
-    pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+        let sz = sz.to_usize2();
+        self.avg_pool2d_with_stride(sz, sz)
+    }
+
+    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
+        &self,
+        kernel_size: T,
+        stride: T,
+    ) -> Result<Self> {
+        let kernel_size = kernel_size.to_usize2();
+        let stride = stride.to_usize2();
         let (n, c, h, w) = self.dims4()?;
         // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
         let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -813,7 +824,18 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
     }
 
-    pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+        let sz = sz.to_usize2();
+        self.max_pool2d_with_stride(sz, sz)
+    }
+
+    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
+        &self,
+        kernel_size: T,
+        stride: T,
+    ) -> Result<Self> {
+        let kernel_size = kernel_size.to_usize2();
+        let stride = stride.to_usize2();
         let (n, c, h, w) = self.dims4()?;
         // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
         let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -1855,6 +1877,10 @@ impl Tensor {
         }
     }
 
+    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
+        m.forward(self)
+    }
+
     pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
         self.storage.read().unwrap()
     }
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index b8c007b8..c6db194d 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -6,14 +6,14 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
         1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
     ];
     let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
-    let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
     assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
 
     let data: Vec<f32> = vec![
         1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
     ];
     let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
-    let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
     assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
     Ok(())
 }
@@ -24,11 +24,11 @@ fn max_pool2d(dev: &Device) -> Result<()> {
     ];
 
     let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
-    let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+    let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
     assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
 
     let t = t.reshape((1, 1, 2, 8))?;
-    let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+    let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
     assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
     Ok(())
 }
@@ -53,7 +53,7 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
         dev,
     )?
     .reshape((1, 2, 4, 4))?;
-    let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
+    let pool = t.avg_pool2d(2)?.squeeze(0)?;
     assert_eq!(
         test_utils::to_vec3_round(&pool, 4)?,
         [
@@ -61,14 +61,14 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
             [[0.1835, -0.1606], [0.6249, 0.3217]]
         ]
     );
-    let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
+    let pool = t.avg_pool2d(3)?.squeeze(0)?;
     assert_eq!(
         test_utils::to_vec3_round(&pool, 4)?,
         [[[0.085]], [[0.0078]]]
     );
 
     let t = t.reshape((1, 1, 4, 8))?;
-    let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+    let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
     assert_eq!(
         test_utils::to_vec2_round(&pool, 4)?,
         [
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 5bbce31b..a90904c4 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -83,13 +83,15 @@ impl Model for ConvNet {
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let (b_sz, _img_dim) = xs.dims2()?;
-        let xs = xs.reshape((b_sz, 1, 28, 28))?;
-        let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
-        let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
-        let xs = xs.flatten_from(1)?;
-        let xs = self.fc1.forward(&xs)?;
-        let xs = xs.relu()?;
-        self.fc2.forward(&xs)
+        xs.reshape((b_sz, 1, 28, 28))?
+            .apply(&self.conv1)?
+            .max_pool2d(2)?
+            .apply(&self.conv2)?
+            .max_pool2d(2)?
+            .flatten_from(1)?
+            .apply(&self.fc1)?
+            .relu()?
+            .apply(&self.fc2)
     }
 }
 
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 1db65222..26a1035b 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -47,7 +47,7 @@ impl Downsample2D {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         match &self.conv {
-            None => xs.avg_pool2d((2, 2), (2, 2)),
+            None => xs.avg_pool2d(2),
             Some(conv) => {
                 if self.padding == 0 {
                     let xs = xs
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index d7fe5c12..b834f967 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -198,15 +198,15 @@ impl Module for Sppf {
         let xs2 = xs
             .pad_with_zeros(2, self.k / 2, self.k / 2)?
             .pad_with_zeros(3, self.k / 2, self.k / 2)?
-            .max_pool2d((self.k, self.k), (1, 1))?;
+            .max_pool2d_with_stride(self.k, 1)?;
         let xs3 = xs2
             .pad_with_zeros(2, self.k / 2, self.k / 2)?
             .pad_with_zeros(3, self.k / 2, self.k / 2)?
-            .max_pool2d((self.k, self.k), (1, 1))?;
+            .max_pool2d_with_stride(self.k, 1)?;
         let xs4 = xs3
             .pad_with_zeros(2, self.k / 2, self.k / 2)?
             .pad_with_zeros(3, self.k / 2, self.k / 2)?
-            .max_pool2d((self.k, self.k), (1, 1))?;
+            .max_pool2d_with_stride(self.k, 1)?;
         self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
     }
 }
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 34e2dbed..2e2c2545 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -1,5 +1,3 @@
-use candle::{Result, Tensor};
-
 pub mod activation;
 pub mod batch_norm;
 pub mod conv;
@@ -28,19 +26,4 @@ pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
 
-// A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
-    /// Change the module to use training mode vs eval mode.
-    ///
-    /// The default implementation does nothing as this is only used for a couple modules such as
-    /// dropout or batch-normalization.
-    fn set_training(&mut self, _training: bool) {}
-}
-
-impl Module for candle::quantized::QMatMul {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.forward(xs)
-    }
-}
+pub use candle::Module;
diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs
index e0fa7ac4..d49cf55f 100644
--- a/candle-wasm-examples/yolo/src/model.rs
+++ b/candle-wasm-examples/yolo/src/model.rs
@@ -202,15 +202,15 @@ impl Module for Sppf {
         let xs2 = xs
             .pad_with_zeros(2, self.k / 2, self.k / 2)?
             .pad_with_zeros(3, self.k / 2, self.k / 2)?
-            .max_pool2d((self.k, self.k), (1, 1))?;
+            .max_pool2d_with_stride(self.k, 1)?;
         let xs3 = xs2
             .pad_with_zeros(2, self.k / 2, self.k / 2)?
             .pad_with_zeros(3, self.k / 2, self.k / 2)?
-            .max_pool2d((self.k, self.k), (1, 1))?;
+            .max_pool2d_with_stride(self.k, 1)?;
         let xs4 = xs3
             .pad_with_zeros(2, self.k / 2, self.k / 2)?
             .pad_with_zeros(3, self.k / 2, self.k / 2)?
-            .max_pool2d((self.k, self.k), (1, 1))?;
+            .max_pool2d_with_stride(self.k, 1)?;
         self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
    }
 }
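For reference, a minimal sketch of how the reworked API reads from user code after this patch. Assumptions: the core crate is imported as `candle_core` on a CPU device, and the tiny `Relu` module is a hypothetical stand-in defined here only to demonstrate `Tensor::apply`; it is not part of the patch.

```rust
use candle_core::{Device, Module, Result, Tensor};

// Hypothetical stand-in module, defined only to show `Tensor::apply`.
#[derive(Debug)]
struct Relu;

impl Module for Relu {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.relu()
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 16f32, &dev)?.reshape((1, 1, 4, 4))?;

    // Square kernel with a matching stride: a single usize now suffices,
    // thanks to the ToUsize2 impls for both usize and (usize, usize).
    let pooled = t.avg_pool2d(2)?;
    assert_eq!(pooled.dims(), &[1, 1, 2, 2]);

    // Kernel and stride can still differ via the *_with_stride variants,
    // as in the Sppf blocks above.
    let strided = t.max_pool2d_with_stride(2, 1)?;
    assert_eq!(strided.dims(), &[1, 1, 3, 3]);

    // `Tensor::apply` threads a tensor through any Module, enabling the
    // chained style used in the mnist-training forward pass.
    let ys = t.avg_pool2d(2)?.apply(&Relu)?;
    assert_eq!(ys.dims(), &[1, 1, 2, 2]);
    Ok(())
}
```

Note that the `*_with_stride` variants take kernel size and stride through the same `T: ToUsize2` parameter, so the two arguments must use matching forms: two `usize` values as above, or two `(usize, usize)` tuples.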