mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Gaussian normal distribution of PRNG via Box-Muller transform
This commit is contained in:
@ -1385,7 +1385,7 @@ impl BackendDevice for MetalDevice {
|
||||
compute_per_buffer,
|
||||
buffers,
|
||||
kernels,
|
||||
seed
|
||||
seed,
|
||||
})
|
||||
}
|
||||
|
||||
@ -1467,8 +1467,9 @@ impl BackendDevice for MetalDevice {
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&buffer
|
||||
).map_err(MetalError::from)?;
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
@ -1480,9 +1481,28 @@ impl BackendDevice for MetalDevice {
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_normal_f32",
|
||||
DType::F16 => "rand_normal_f16",
|
||||
DType::BF16 => "rand_normal_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_normal(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
*self.seed.lock().unwrap(),
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user