From ae79c00e48089d889f900b4c05f90a1201e610c6 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 11 Jul 2023 08:52:29 +0100
Subject: [PATCH] Allow for uniform initialization in a single step. (#136)

---
 candle-core/src/cpu_backend.rs        |  6 +++---
 candle-core/src/cuda_backend.rs       | 12 +++++++++++-
 candle-core/src/device.rs             | 12 +++++++++---
 candle-core/src/dummy_cuda_backend.rs |  2 +-
 candle-core/src/tensor.rs             | 24 +++++++++++++++++++-----
 5 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 1af694d7..de32b549 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -895,7 +895,7 @@ impl CpuStorage {
         MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
     }
 
-    pub(crate) fn rand_uniform(shape: &Shape, dtype: DType) -> Result<Self> {
+    pub(crate) fn rand_uniform(shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<Self> {
         use rand::prelude::*;
 
         let elem_count = shape.elem_count();
@@ -907,7 +907,7 @@ impl CpuStorage {
             DType::F32 => {
                 let mut data = Vec::new();
                 data.reserve(elem_count);
-                let uniform = rand::distributions::Uniform::new(0f32, 1f32);
+                let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(uniform))
                 }
@@ -916,7 +916,7 @@ impl CpuStorage {
             DType::F64 => {
                 let mut data = Vec::new();
                 data.reserve(elem_count);
-                let uniform = rand::distributions::Uniform::new(0f64, 1f64);
+                let uniform = rand::distributions::Uniform::new(min, max);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(uniform))
                 }
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7106d4d7..9fc4ceca 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -153,7 +153,13 @@ impl CudaDevice {
         })
     }
 
-    pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+    pub(crate) fn rand_uniform(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        lo: f64,
+        up: f64,
+    ) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
         let slice = match dtype {
@@ -174,6 +180,10 @@ impl CudaDevice {
                 CudaStorageSlice::F64(data)
             }
         };
+        if lo != 0.0 || up != 1.0 {
+            let layout = Layout::contiguous(shape);
+            Affine(up - lo, lo).map(&slice, self, &layout)?;
+        }
         Ok(CudaStorage {
             slice,
             device: self.clone(),
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 0faf7fa2..1380cbc9 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -109,14 +109,20 @@ impl Device {
         }
     }
 
-    pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
+    pub(crate) fn rand_uniform(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        lo: f64,
+        up: f64,
+    ) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = CpuStorage::rand_uniform(shape, dtype)?;
+                let storage = CpuStorage::rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, dtype)?;
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 Ok(Storage::Cuda(storage))
             }
         }
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 1fe6ba5d..f5c80fcf 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -39,7 +39,7 @@ impl CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn rand_uniform(&self, _: &Shape, _: DType) -> Result<CudaStorage> {
+    pub(crate) fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index aba7b91a..ecc018f9 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -226,19 +226,33 @@ impl Tensor {
         s: S,
         dtype: DType,
         device: &Device,
+        lo: f64,
+        up: f64,
         is_variable: bool,
     ) -> Result<Self> {
         let s = s.into();
-        let storage = device.rand_uniform(&s, dtype)?;
+        let storage = device.rand_uniform(&s, dtype, lo, up)?;
         Ok(from_storage(storage, s, None, is_variable))
     }
 
-    pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::rand_uniform_impl(s, dtype, device, false)
+    pub fn rand_uniform<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        lo: f64,
+        up: f64,
+    ) -> Result<Self> {
+        Self::rand_uniform_impl(s, dtype, device, lo, up, false)
     }
 
-    pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::rand_uniform_impl(s, dtype, device, true)
+    pub fn rand_uniform_var<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        lo: f64,
+        up: f64,
+    ) -> Result<Self> {
+        Self::rand_uniform_impl(s, dtype, device, lo, up, true)
     }
 
     fn rand_normal_impl<S: Into<Shape>>(
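
The patch threads an explicit [lo, up) range through every backend. On CPU
the bounds feed straight into rand::distributions::Uniform; on CUDA, curand
only fills the buffer with samples from the unit interval, so a non-default
range is applied afterwards as the affine map x * (up - lo) + lo. A minimal
sketch of that rescaling, where `rescale` is a hypothetical stand-alone
helper rather than anything defined in the crate:

    // Map a uniform sample u from [0, 1) onto [lo, up),
    // mirroring Affine(up - lo, lo) in the CUDA path above.
    fn rescale(u: f64, lo: f64, up: f64) -> f64 {
        u * (up - lo) + lo
    }

    // e.g. with lo = -1.0 and up = 1.0, u = 0.25 maps to -0.5.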
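
With the new signatures, callers pick the range up front instead of sampling
from [0, 1) and rescaling the tensor by hand. A usage sketch, assuming the
library is imported as `candle_core`, that `Shape` converts from a
`(usize, usize)` tuple, and that the crate's `Result` alias is in scope:

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Single step: a 2x3 f32 tensor drawn uniformly from [-1.0, 1.0).
        let w = Tensor::rand_uniform((2, 3), DType::F32, &Device::Cpu, -1.0, 1.0)?;
        // rand_uniform_var does the same but creates a variable
        // (is_variable = true) so the tensor can be tracked for gradients.
        let v = Tensor::rand_uniform_var((2, 3), DType::F32, &Device::Cpu, 0.0, 0.1)?;
        println!("{w:?} {v:?}");
        Ok(())
    }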