mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Create a new curand instead of reseeding. (#1089)
This commit is contained in:
@ -224,8 +224,10 @@ impl BackendDevice for CudaDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
|
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||||
|
// state will be identical and the same random numbers will be generated.
|
||||||
let mut curand = self.curand.lock().unwrap();
|
let mut curand = self.curand.lock().unwrap();
|
||||||
curand.0.set_seed(seed).w()?;
|
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,6 +128,13 @@ impl Device {
|
|||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
|
||||||
|
Self::Cuda(c) => c.set_seed(seed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn same_device(&self, rhs: &Self) -> bool {
|
pub fn same_device(&self, rhs: &Self) -> bool {
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Self::Cpu, Self::Cpu) => true,
|
(Self::Cpu, Self::Cpu) => true,
|
||||||
|
Reference in New Issue
Block a user