mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Make the cuda rng seedable. (#1056)
This commit is contained in:
@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
|||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||||
|
|
||||||
|
fn set_seed(&self, _: u64) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice {
|
|||||||
Ok(Self)
|
Ok(Self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||||
|
crate::bail!("cannot seed the CPU rng with set_seed")
|
||||||
|
}
|
||||||
|
|
||||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
|
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
|
let mut curand = self.curand.lock().unwrap();
|
||||||
|
curand.0.set_seed(seed).w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
crate::DeviceLocation::Cuda {
|
crate::DeviceLocation::Cuda {
|
||||||
gpu_id: self.device.ordinal(),
|
gpu_id: self.device.ordinal(),
|
||||||
|
@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, _: u64) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user