mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
@ -856,6 +856,20 @@ impl Tensor {
|
|||||||
self.sum_impl(mean_dims, false)? * scale
|
self.sum_impl(mean_dims, false)? * scale
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the unbiased variance over the selected dimension.
|
||||||
|
pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "var")?;
|
||||||
|
let mean = self.mean_keepdim(dim)?;
|
||||||
|
let squares = self.broadcast_sub(&mean)?.sqr()?;
|
||||||
|
squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the unbiased variance over the selected dimension.
|
||||||
|
pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "var")?;
|
||||||
|
self.var_keepdim(dim)?.squeeze(dim)
|
||||||
|
}
|
||||||
|
|
||||||
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
|
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
|
||||||
/// number of dimensions as the original tensor and the select dimension has a single element.
|
/// number of dimensions as the original tensor and the select dimension has a single element.
|
||||||
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
|
@ -180,6 +180,22 @@ fn transpose(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn var(device: &Device) -> Result<()> {
|
||||||
|
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
|
||||||
|
let data = &[
|
||||||
|
[0.2035f32, 1.2959, 1.8101, -0.4644],
|
||||||
|
[1.5027, -0.3270, 0.5905, 0.6538],
|
||||||
|
[-1.5745, 1.3330, -0.5596, -0.6548],
|
||||||
|
[0.1264, -0.5080, 1.6420, 0.1992],
|
||||||
|
];
|
||||||
|
let tensor = Tensor::new(data, device)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
|
||||||
|
&[[1.0631], [0.559], [1.4893], [0.8258]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn sum(device: &Device) -> Result<()> {
|
fn sum(device: &Device) -> Result<()> {
|
||||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||||
let tensor = Tensor::new(data, device)?;
|
let tensor = Tensor::new(data, device)?;
|
||||||
@ -1082,6 +1098,7 @@ test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
|||||||
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
||||||
test_device!(randn, randn_cpu, randn_gpu);
|
test_device!(randn, randn_cpu, randn_gpu);
|
||||||
test_device!(clamp, clamp_cpu, clamp_gpu);
|
test_device!(clamp, clamp_cpu, clamp_gpu);
|
||||||
|
test_device!(var, var_cpu, var_gpu);
|
||||||
|
|
||||||
// There was originally a bug on the CPU implementation for randn
|
// There was originally a bug on the CPU implementation for randn
|
||||||
// https://github.com/huggingface/candle/issues/381
|
// https://github.com/huggingface/candle/issues/381
|
||||||
|
Reference in New Issue
Block a user