diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index cb00441f..7cc85489 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice { // cudarc changes. let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. + // https://github.com/huggingface/candle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; let slice = match dtype { DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { @@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice { .w()? } DType::F32 => { - let mut data = unsafe { self.alloc::(elem_count) }.w()?; + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand .0 .fill_with_normal(&mut data, mean as f32, std as f32) @@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice { CudaStorageSlice::F32(data) } DType::F64 => { - let mut data = unsafe { self.alloc::(elem_count) }.w()?; + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand.0.fill_with_normal(&mut data, mean, std).w()?; CudaStorageSlice::F64(data) } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 6af43196..cd68908f 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -877,6 +877,14 @@ fn broadcasting(device: &Device) -> Result<()> { Ok(()) } +fn randn(device: &Device) -> Result<()> { + let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + Ok(()) +} + test_device!(zeros, zeros_cpu, zeros_gpu); test_device!(add_mul, add_mul_cpu, add_mul_gpu); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); @@ -899,6 +907,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu); test_device!(index_add, index_add_cpu, index_add_gpu); test_device!(gather, gather_cpu, gather_gpu); test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu); +test_device!(randn, randn_cpu, randn_gpu); // There was originally a bug on the CPU implementation for randn // https://github.com/huggingface/candle/issues/381