Random initializers. (#128)

* Random initialization.

* CPU rng generation.
This commit is contained in:
Laurent Mazare
2023-07-10 18:26:21 +01:00
committed by GitHub
parent e2807c78a4
commit f29b77ec19
6 changed files with 235 additions and 3 deletions

View File

@ -5,7 +5,7 @@ use cudarc::driver::{
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::Arc;
use std::sync::{Arc, Mutex};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
@ -19,12 +19,18 @@ pub enum CudaError {
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
@ -67,12 +73,21 @@ impl DeviceId {
}
}
#[derive(Debug, Clone)]
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
#[allow(dead_code)]
blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
@ -87,10 +102,12 @@ impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
@ -136,6 +153,68 @@ impl CudaDevice {
})
}
pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
curand.0.fill_with_uniform(&mut data)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
curand.0.fill_with_uniform(&mut data)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub(crate) fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
std: f64,
) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
curand.0.fill_with_normal(&mut data, mean, std)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);