From f29b77ec19444d4308f3aa20f1f9625e77145298 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 10 Jul 2023 18:26:21 +0100
Subject: [PATCH] Random initializers. (#128)

* Random initialization.

* CPU rng generation.
---
 candle-core/Cargo.toml                |  1 +
 candle-core/src/cpu_backend.rs        | 60 +++++++++++++++++++
 candle-core/src/cuda_backend.rs       | 85 ++++++++++++++++++++++++++-
 candle-core/src/device.rs             | 32 ++++++++++
 candle-core/src/dummy_cuda_backend.rs |  8 +++
 candle-core/src/tensor.rs             | 52 ++++++++++++++++
 6 files changed, 235 insertions(+), 3 deletions(-)

diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 91ca0cff..6af01139 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -25,6 +25,7 @@ libc = { version = "0.2.147", optional = true }
 memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
+rand = "0.8.5"
 safetensors = "0.3.1"
 thiserror = "1"
 zip = { version = "0.6.6", default-features=false }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 6663021d..1af694d7 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -895,6 +895,66 @@ impl CpuStorage {
         MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
     }
 
+    pub(crate) fn rand_uniform(shape: &Shape, dtype: DType) -> Result<Self> {
+        use rand::prelude::*;
+
+        let elem_count = shape.elem_count();
+        let mut rng = rand::thread_rng();
+        match dtype {
+            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform"))
+            }
+            DType::F32 => {
+                let mut data = Vec::new();
+                data.reserve(elem_count);
+                let uniform = rand::distributions::Uniform::new(0f32, 1f32);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f32, _>(uniform))
+                }
+                Ok(Self::F32(data))
+            }
+            DType::F64 => {
+                let mut data = Vec::new();
+                data.reserve(elem_count);
+                let uniform = rand::distributions::Uniform::new(0f64, 1f64);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f64, _>(uniform))
+                }
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
+        use rand::prelude::*;
+
+        let elem_count = shape.elem_count();
+        let mut rng = rand::thread_rng();
+        match dtype {
+            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
+            }
+            DType::F32 => {
+                let mut data = Vec::new();
+                data.reserve(elem_count);
+                let std = std as f32;
+                let mean = mean as f32;
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+                }
+                Ok(Self::F32(data))
+            }
+            DType::F64 => {
+                let mut data = Vec::new();
+                data.reserve(elem_count);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+                }
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
     pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
         match dtype {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 3df1b33c..7106d4d7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -5,7 +5,7 @@ use cudarc::driver::{
     CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
 };
 use half::{bf16, f16};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
@@ -19,12 +19,18 @@ pub enum CudaError {
     #[error(transparent)]
     Cublas(#[from] cudarc::cublas::result::CublasError),
 
+    #[error(transparent)]
+    Curand(#[from] cudarc::curand::result::CurandError),
+
     #[error("{op} only supports contiguous tensors")]
     RequiresContiguous { op: &'static str },
 
     #[error("missing kernel '{module_name}'")]
     MissingKernel { module_name: String },
 
+    #[error("unsupported dtype {dtype:?} for {op}")]
+    UnsupportedDtype { dtype: DType, op: &'static str },
+
     #[error("internal error '{0}'")]
     InternalError(&'static str),
 
@@ -67,12 +73,21 @@ impl DeviceId {
     }
 }
 
-#[derive(Debug, Clone)]
+struct CudaRng(cudarc::curand::CudaRng);
+unsafe impl Send for CudaRng {}
+
+#[derive(Clone)]
 pub struct CudaDevice {
     id: DeviceId,
     device: Arc<cudarc::driver::CudaDevice>,
-    #[allow(dead_code)]
     blas: Arc<cudarc::cublas::CudaBlas>,
+    curand: Arc<Mutex<CudaRng>>,
+}
+
+impl std::fmt::Debug for CudaDevice {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "CudaDevice({:?})", self.id)
+    }
 }
 
 impl std::ops::Deref for CudaDevice {
@@ -87,10 +102,12 @@ impl CudaDevice {
     pub(crate) fn new(ordinal: usize) -> Result<Self> {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
         let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
+        let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
         Ok(Self {
             id: DeviceId::new(),
             device,
             blas: Arc::new(blas),
+            curand: Arc::new(Mutex::new(CudaRng(curand))),
         })
     }
 
@@ -136,6 +153,68 @@ impl CudaDevice {
         })
     }
 
+    pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let curand = self.curand.lock().unwrap();
+        let slice = match dtype {
+            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
+                Err(CudaError::UnsupportedDtype {
+                    dtype,
+                    op: "rand_uniform",
+                })?
+            }
+            DType::F32 => {
+                let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
+                curand.0.fill_with_uniform(&mut data)?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
+                curand.0.fill_with_uniform(&mut data)?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    pub(crate) fn rand_normal(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        mean: f64,
+        std: f64,
+    ) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let curand = self.curand.lock().unwrap();
+        let slice = match dtype {
+            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
+                Err(CudaError::UnsupportedDtype {
+                    dtype,
+                    op: "rand_normal",
+                })?
+            }
+            DType::F32 => {
+                let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
+                curand
+                    .0
+                    .fill_with_normal(&mut data, mean as f32, std as f32)?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
+                curand.0.fill_with_normal(&mut data, mean, std)?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
     pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 3562d374..0faf7fa2 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -109,6 +109,38 @@ impl Device {
         }
     }
 
+    pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
+        match self {
+            Device::Cpu => {
+                let storage = CpuStorage::rand_uniform(shape, dtype)?;
+                Ok(Storage::Cpu(storage))
+            }
+            Device::Cuda(device) => {
+                let storage = device.rand_uniform(shape, dtype)?;
+                Ok(Storage::Cuda(storage))
+            }
+        }
+    }
+
+    pub(crate) fn rand_normal(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        mean: f64,
+        std: f64,
+    ) -> Result<Storage> {
+        match self {
+            Device::Cpu => {
+                let storage = CpuStorage::rand_normal(shape, dtype, mean, std)?;
+                Ok(Storage::Cpu(storage))
+            }
+            Device::Cuda(device) => {
+                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                Ok(Storage::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 9263adee..1fe6ba5d 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -38,6 +38,14 @@ impl CudaDevice {
     pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
         Err(Error::NotCompiledWithCudaSupport)
    }
+
+    pub(crate) fn rand_uniform(&self, _: &Shape, _: DType) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
 
 #[derive(Debug)]
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9b0681e0..98632a9b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -222,6 +222,58 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }
 
+    fn rand_uniform_impl<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_uniform(&s, dtype)?;
+        Ok(from_storage(storage, s, None, is_variable))
+    }
+
+    pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
+        Self::rand_uniform_impl(s, dtype, device, false)
+    }
+
+    pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
+        Self::rand_uniform_impl(s, dtype, device, true)
+    }
+
+    fn rand_normal_impl<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        mean: f64,
+        std: f64,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let s = s.into();
+        let storage = device.rand_normal(&s, dtype, mean, std)?;
+        Ok(from_storage(storage, s, None, is_variable))
+    }
+
+    pub fn rand_normal<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        mean: f64,
+        std: f64,
+    ) -> Result<Self> {
+        Self::rand_normal_impl(s, dtype, device, mean, std, false)
+    }
+
+    pub fn rand_normal_var<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        mean: f64,
+        std: f64,
+    ) -> Result<Self> {
+        Self::rand_normal_impl(s, dtype, device, mean, std, true)
+    }
+
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
         shape: Shape,
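
A minimal usage sketch for the tensor-level initializers added by this patch, not
taken from the commit itself: it assumes the library is imported as candle_core,
that Shape converts from a (usize, usize) tuple, and that DType, Device, Tensor and
Result are re-exported at the crate root.

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let device = Device::Cpu;
        // Uniform samples in [0, 1), f32, shape (2, 3).
        let uniform = Tensor::rand_uniform((2, 3), DType::F32, &device)?;
        // Backend-generated samples scaled by std and shifted by mean, f64, shape (2, 3).
        let normal = Tensor::rand_normal((2, 3), DType::F64, &device, 0.0, 1.0)?;
        // The *_var variants build the same tensors but mark them as variables.
        let trainable = Tensor::rand_uniform_var((2, 3), DType::F32, &device)?;
        println!("{:?} {:?} {:?}", uniform.shape(), normal.shape(), trainable.shape());
        Ok(())
    }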