Gaussian normal distribution of PRNG via Box-Muller transform

This commit is contained in:
Ivar Flakstad
2024-01-05 21:18:12 +01:00
parent 955e63c803
commit 6bf52b9fdf
5 changed files with 238 additions and 101 deletions

View File

@ -1385,7 +1385,7 @@ impl BackendDevice for MetalDevice {
compute_per_buffer,
buffers,
kernels,
seed
seed,
})
}
@ -1467,8 +1467,9 @@ impl BackendDevice for MetalDevice {
min as f32,
max as f32,
shape.elem_count(),
&buffer
).map_err(MetalError::from)?;
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
}
@ -1480,9 +1481,28 @@ impl BackendDevice for MetalDevice {
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
let name = match dtype {
DType::F32 => "rand_normal_f32",
DType::F16 => "rand_normal_f16",
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
*self.seed.lock().unwrap(),
mean as f32,
stddev as f32,
shape.elem_count(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
}
}