Make the cuda rng seedable. (#1056)

This commit is contained in:
Laurent Mazare
2023-10-08 09:32:36 +01:00
committed by GitHub
parent 2e5fb0b251
commit 9abeddd750
4 changed files with 16 additions and 0 deletions

View File

@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
} }

View File

@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice {
Ok(Self) Ok(Self)
} }
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("cannot seed the CPU rng with set_seed")
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> { fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*; use rand::prelude::*;

View File

@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
}) })
} }
fn set_seed(&self, seed: u64) -> Result<()> {
let mut curand = self.curand.lock().unwrap();
curand.0.set_seed(seed).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda { crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(), gpu_id: self.device.ordinal(),

View File

@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
fail!() fail!()
} }