mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix cuda randn when generating an odd number of values. (#793)
This commit is contained in:
@ -312,6 +312,13 @@ impl BackendDevice for CudaDevice {
|
|||||||
// cudarc changes.
|
// cudarc changes.
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let curand = self.curand.lock().unwrap();
|
let curand = self.curand.lock().unwrap();
|
||||||
|
// curand can only generate an odd number of values.
|
||||||
|
// https://github.com/huggingface/candle/issues/734
|
||||||
|
let elem_count_round = if elem_count % 2 == 1 {
|
||||||
|
elem_count + 1
|
||||||
|
} else {
|
||||||
|
elem_count
|
||||||
|
};
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||||
Err(CudaError::UnsupportedDtype {
|
Err(CudaError::UnsupportedDtype {
|
||||||
@ -321,7 +328,7 @@ impl BackendDevice for CudaDevice {
|
|||||||
.w()?
|
.w()?
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
||||||
curand
|
curand
|
||||||
.0
|
.0
|
||||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||||
@ -329,7 +336,7 @@ impl BackendDevice for CudaDevice {
|
|||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
||||||
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
|
@ -877,6 +877,14 @@ fn broadcasting(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn randn(device: &Device) -> Result<()> {
|
||||||
|
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||||
|
assert_eq!(tensor.dims(), [5, 3]);
|
||||||
|
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||||
|
assert_eq!(tensor.dims(), [5, 3]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||||
@ -899,6 +907,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
|
|||||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||||
test_device!(gather, gather_cpu, gather_gpu);
|
test_device!(gather, gather_cpu, gather_gpu);
|
||||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
||||||
|
test_device!(randn, randn_cpu, randn_gpu);
|
||||||
|
|
||||||
// There was originally a bug on the CPU implementation for randn
|
// There was originally a bug on the CPU implementation for randn
|
||||||
// https://github.com/huggingface/candle/issues/381
|
// https://github.com/huggingface/candle/issues/381
|
||||||
|
Reference in New Issue
Block a user