diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f032a896..d51a3db7 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -856,6 +856,20 @@ impl Tensor {
         self.sum_impl(mean_dims, false)? * scale
     }
 
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        let mean = self.mean_keepdim(dim)?;
+        let squares = self.broadcast_sub(&mean)?.sqr()?;
+        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
+    }
+
+    /// Returns the unbiased variance over the selected dimension.
+    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "var")?;
+        self.var_keepdim(dim)?.squeeze(dim)
+    }
+
     /// Gathers the maximum value across the selected dimension. The resulting shape has the same
     /// number of dimensions as the original tensor and the select dimension has a single element.
     pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 734cb7e8..cc44ce94 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -180,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn var(device: &Device) -> Result<()> {
+    // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
+    let data = &[
+        [0.2035f32, 1.2959, 1.8101, -0.4644],
+        [1.5027, -0.3270, 0.5905, 0.6538],
+        [-1.5745, 1.3330, -0.5596, -0.6548],
+        [0.1264, -0.5080, 1.6420, 0.1992],
+    ];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
+        &[[1.0631], [0.559], [1.4893], [0.8258]]
+    );
+    Ok(())
+}
+
 fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
@@ -1082,6 +1098,7 @@ test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
 test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
 test_device!(randn, randn_cpu, randn_gpu);
 test_device!(clamp, clamp_cpu, clamp_gpu);
+test_device!(var, var_cpu, var_gpu);
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
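
The patch adds an unbiased (n - 1 denominator) variance reduction in both keepdim and squeezed forms. Below is a minimal usage sketch, not part of the patch, showing how the two new methods would be called from user code; it assumes the standard candle_core imports and a CPU device, and reuses the first two rows of the test data above.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Two rows of the torch.var documentation example used in the test.
    let data = &[
        [0.2035f32, 1.2959, 1.8101, -0.4644],
        [1.5027, -0.3270, 0.5905, 0.6538],
    ];
    let t = Tensor::new(data, &Device::Cpu)?;

    // Unbiased variance along dim 1, keeping the reduced dimension: shape (2, 1).
    let v_keep = t.var_keepdim(1)?;
    println!("{:?}", v_keep.to_vec2::<f32>()?);

    // Same reduction with the reduced dimension squeezed away: shape (2,).
    let v = t.var(1)?;
    println!("{:?}", v.to_vec1::<f32>()?);

    Ok(())
}
```

Note that `var` simply delegates to `var_keepdim` and then squeezes the reduced dimension, mirroring how the other `*_keepdim` reductions in `tensor.rs` pair with their squeezed counterparts.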