Random initializers. (#128)

* Random initialization.

* CPU rng generation.
This commit is contained in:
Laurent Mazare
2023-07-10 18:26:21 +01:00
committed by GitHub
parent e2807c78a4
commit f29b77ec19
6 changed files with 235 additions and 3 deletions

View File

@ -25,6 +25,7 @@ libc = { version = "0.2.147", optional = true }
memmap2 = "0.7.1"
num-traits = "0.2.15"
num_cpus = "1.15.0"
rand = "0.8.5"
safetensors = "0.3.1"
thiserror = "1"
zip = { version = "0.6.6", default-features=false }

View File

@ -895,6 +895,66 @@ impl CpuStorage {
MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
}
pub(crate) fn rand_uniform(shape: &Shape, dtype: DType) -> Result<Self> {
use rand::prelude::*;
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
}
DType::F32 => {
let mut data = Vec::new();
data.reserve(elem_count);
let uniform = rand::distributions::Uniform::new(0f32, 1f32);
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(uniform))
}
Ok(Self::F32(data))
}
DType::F64 => {
let mut data = Vec::new();
data.reserve(elem_count);
let uniform = rand::distributions::Uniform::new(0f64, 1f64);
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(uniform))
}
Ok(Self::F64(data))
}
}
}
pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
use rand::prelude::*;
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
}
DType::F32 => {
let mut data = Vec::new();
data.reserve(elem_count);
let std = std as f32;
let mean = mean as f32;
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
}
Ok(Self::F32(data))
}
DType::F64 => {
let mut data = Vec::new();
data.reserve(elem_count);
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
}
Ok(Self::F64(data))
}
}
}
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {

View File

@ -5,7 +5,7 @@ use cudarc::driver::{
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::Arc;
use std::sync::{Arc, Mutex};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
@ -19,12 +19,18 @@ pub enum CudaError {
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
@ -67,12 +73,21 @@ impl DeviceId {
}
}
#[derive(Debug, Clone)]
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
#[allow(dead_code)]
blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
@ -87,10 +102,12 @@ impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
@ -136,6 +153,68 @@ impl CudaDevice {
})
}
pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
curand.0.fill_with_uniform(&mut data)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
curand.0.fill_with_uniform(&mut data)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub(crate) fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
std: f64,
) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
curand.0.fill_with_normal(&mut data, mean, std)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);

View File

@ -109,6 +109,38 @@ impl Device {
}
}
pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuStorage::rand_uniform(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
std: f64,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuStorage::rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {

View File

@ -38,6 +38,14 @@ impl CudaDevice {
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn rand_uniform(&self, _: &Shape, _: DType) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
}
#[derive(Debug)]

View File

@ -222,6 +222,58 @@ impl Tensor {
Tensor::zeros(self.shape(), self.dtype(), &self.device())
}
fn rand_uniform_impl<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform(&s, dtype)?;
Ok(from_storage(storage, s, None, is_variable))
}
pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
Self::rand_uniform_impl(s, dtype, device, false)
}
pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
Self::rand_uniform_impl(s, dtype, device, true)
}
fn rand_normal_impl<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal(&s, dtype, mean, std)?;
Ok(from_storage(storage, s, None, is_variable))
}
pub fn rand_normal<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
Self::rand_normal_impl(s, dtype, device, mean, std, false)
}
pub fn rand_normal_var<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
mean: f64,
std: f64,
) -> Result<Self> {
Self::rand_normal_impl(s, dtype, device, mean, std, true)
}
pub fn new_impl<A: crate::device::NdArray>(
array: A,
shape: Shape,