Handle zero dims in some simple operations. (#2064)

* Handle zero dims in some simple operations.

* Handle zero-dims in matmul.

* More testing.
This commit is contained in:
Laurent Mazare
2024-04-15 09:18:54 +02:00
committed by GitHub
parent f7d5bf5b97
commit e198bb0816
2 changed files with 43 additions and 0 deletions

View File

@ -1083,6 +1083,27 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
fn zero_dim(device: &Device) -> Result<()> {
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
assert_eq!(t.dims3()?, (4, 0, 1));
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
assert_eq!(t_cat.dims3()?, (4, 3, 1));
let t_cat = Tensor::cat(&[&t, &t], 1)?;
assert_eq!(t_cat.dims3()?, (4, 0, 1));
let t_unary = t.sqrt()?;
assert_eq!(t_unary.dims3()?, (4, 0, 1));
let t_plus = (&t + 1.)?;
assert_eq!(t_plus.dims3()?, (4, 0, 1));
let t_mm = t2.matmul(&t.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 3, 0));
let t_mm = t.matmul(&t2.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 0, 3));
let t_mm = t.t()?.matmul(&t)?;
assert_eq!(t_mm.dims3()?, (4, 1, 1));
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1131,6 +1152,7 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381