Make the cuda rng seedable. (#1056)

This commit is contained in:
Laurent Mazare
2023-10-08 09:32:36 +01:00
committed by GitHub
parent 2e5fb0b251
commit 9abeddd750
4 changed files with 16 additions and 0 deletions

View File

@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
let mut curand = self.curand.lock().unwrap();
curand.0.set_seed(seed).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),