mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations. * Handle zero-dims in matmul. * More testing.
This commit is contained in:
@ -1083,6 +1083,27 @@ fn randn(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn zero_dim(device: &Device) -> Result<()> {
|
||||
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
|
||||
assert_eq!(t.dims3()?, (4, 0, 1));
|
||||
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
|
||||
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
|
||||
assert_eq!(t_cat.dims3()?, (4, 3, 1));
|
||||
let t_cat = Tensor::cat(&[&t, &t], 1)?;
|
||||
assert_eq!(t_cat.dims3()?, (4, 0, 1));
|
||||
let t_unary = t.sqrt()?;
|
||||
assert_eq!(t_unary.dims3()?, (4, 0, 1));
|
||||
let t_plus = (&t + 1.)?;
|
||||
assert_eq!(t_plus.dims3()?, (4, 0, 1));
|
||||
let t_mm = t2.matmul(&t.t()?)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 3, 0));
|
||||
let t_mm = t.matmul(&t2.t()?)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 0, 3));
|
||||
let t_mm = t.t()?.matmul(&t)?;
|
||||
assert_eq!(t_mm.dims3()?, (4, 1, 1));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||
@ -1131,6 +1152,7 @@ test_device!(
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
||||
|
||||
// There was originally a bug on the CPU implementation for randn
|
||||
// https://github.com/huggingface/candle/issues/381
|
||||
|
Reference in New Issue
Block a user