Fix cuda randn when generating an odd number of values. (#793)

This commit is contained in:
Laurent Mazare
2023-09-09 18:44:21 +01:00
committed by GitHub
parent 31936c08fe
commit 258ac32c38
2 changed files with 18 additions and 2 deletions

View File

@ -877,6 +877,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@ -899,6 +907,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(randn, randn_cpu, randn_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381