Add the var method. (#1315)

* Add the var method.

* Add a test.
Laurent Mazare
2023-11-10 22:47:57 +01:00
committed by GitHub
parent 1b12142a02
commit 9e666d4229
2 changed files with 31 additions and 0 deletions


@@ -856,6 +856,20 @@ impl Tensor {
        self.sum_impl(mean_dims, false)? * scale
    }
    /// Returns the unbiased variance over the selected dimension. The resulting tensor has the
    /// same number of dimensions as the original tensor and the selected dimension has a single
    /// element.
    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "var")?;
        let mean = self.mean_keepdim(dim)?;
        let squares = self.broadcast_sub(&mean)?.sqr()?;
        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
    }

    /// Returns the unbiased variance over the selected dimension. The selected dimension is
    /// removed from the resulting shape.
    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "var")?;
        self.var_keepdim(dim)?.squeeze(dim)
    }
    /// Gathers the maximum value across the selected dimension. The resulting shape has the same
    /// number of dimensions as the original tensor and the select dimension has a single element.
    pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
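
For anyone trying the new API, here is a minimal usage sketch (not part of the commit), assuming the candle_core crate name and its existing Tensor::new and dims methods:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A 2x3 tensor; reduce over dimension 1 (the columns).
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // var_keepdim keeps the reduced dimension with size one: shape (2, 1).
    assert_eq!(t.var_keepdim(1)?.dims(), &[2, 1]);
    // var squeezes the reduced dimension away: shape (2,).
    assert_eq!(t.var(1)?.dims(), &[2]);
    Ok(())
}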


@@ -180,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
    Ok(())
}
fn var(device: &Device) -> Result<()> {
    // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
    let data = &[
        [0.2035f32, 1.2959, 1.8101, -0.4644],
        [1.5027, -0.3270, 0.5905, 0.6538],
        [-1.5745, 1.3330, -0.5596, -0.6548],
        [0.1264, -0.5080, 1.6420, 0.1992],
    ];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
        &[[1.0631], [0.559], [1.4893], [0.8258]]
    );
    Ok(())
}
fn sum(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
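
As a sanity check on the expected values, the unbiased variance of the first row can be recomputed by hand with the formula the method uses, sum((x - mean)^2) / (n - 1); this standalone sketch (not part of the commit) uses only the standard library:

fn main() {
    let row = [0.2035f32, 1.2959, 1.8101, -0.4644];
    let n = row.len() as f32;
    let mean = row.iter().sum::<f32>() / n;
    // The unbiased estimator divides by n - 1 rather than n.
    let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / (n - 1.0);
    // Agrees with the 1.0631 expected by the test above (up to rounding).
    assert!((var - 1.0631).abs() < 1e-4);
}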
@@ -1082,6 +1098,7 @@ test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
test_device!(var, var_cpu, var_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381