From 9abeddd750fe13632136a9807fcb0b6d1c999bd3 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 8 Oct 2023 09:32:36 +0100 Subject: [PATCH] Make the cuda rng seedable. (#1056) --- candle-core/src/backend.rs | 2 ++ candle-core/src/cpu_backend.rs | 4 ++++ candle-core/src/cuda_backend.rs | 6 ++++++ candle-core/src/dummy_cuda_backend.rs | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 03a07434..7f0e2fc7 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; + + fn set_seed(&self, _: u64) -> Result<()>; } diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4e808b34..86cbeb78 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice { Ok(Self) } + fn set_seed(&self, _seed: u64) -> Result<()> { + crate::bail!("cannot seed the CPU rng with set_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { use rand::prelude::*; diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index f7518067..f0f48327 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice { }) } + fn set_seed(&self, seed: u64) -> Result<()> { + let mut curand = self.curand.lock().unwrap(); + curand.0.set_seed(seed).w()?; + Ok(()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.device.ordinal(), diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 5cc9c6d8..53574458 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn set_seed(&self, _: u64) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() }