mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Random initializers. (#128)
* Random initialization. * CPU rng generation.
This commit is contained in:
@ -5,7 +5,7 @@ use cudarc::driver::{
|
||||
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// cudarc related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -19,12 +19,18 @@ pub enum CudaError {
|
||||
#[error(transparent)]
|
||||
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||
|
||||
#[error(transparent)]
|
||||
Curand(#[from] cudarc::curand::result::CurandError),
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
#[error("missing kernel '{module_name}'")]
|
||||
MissingKernel { module_name: String },
|
||||
|
||||
#[error("unsupported dtype {dtype:?} for {op}")]
|
||||
UnsupportedDtype { dtype: DType, op: &'static str },
|
||||
|
||||
#[error("internal error '{0}'")]
|
||||
InternalError(&'static str),
|
||||
|
||||
@ -67,12 +73,21 @@ impl DeviceId {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Newtype over cudarc's cuRAND generator; having a local type is what lets this file
/// provide the `Send` impl that follows.
struct CudaRng(cudarc::curand::CudaRng);
|
||||
// SAFETY: NOTE(review): the soundness argument is not stated in the code. The only visible
// use keeps the generator behind `Arc<Mutex<_>>` inside `CudaDevice`, so access is
// serialized — confirm that the underlying cuRAND handle may be moved to and used from a
// thread other than the one that created it.
unsafe impl Send for CudaRng {}
|
||||
|
||||
/// Handle to a CUDA device bundled with the cuBLAS and cuRAND contexts created for it.
///
/// The driver device, cuBLAS handle and cuRAND generator are all held through `Arc`, so
/// cloning a `CudaDevice` shares them rather than creating new ones.
#[derive(Clone)]
|
||||
pub struct CudaDevice {
|
||||
/// Identifier obtained from `DeviceId::new()`; the only field printed by the `Debug` impl.
id: DeviceId,
|
||||
/// Shared cudarc driver device.
device: Arc<cudarc::driver::CudaDevice>,
|
||||
// `blas` is only written (in `new`) and never read in the code visible here, hence the
// lint allowance.
#[allow(dead_code)]
|
||||
/// Shared cuBLAS handle created for `device`.
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
/// cuRAND generator used by `rand_uniform` / `rand_normal`; the mutex serializes access
/// across clones of this device.
curand: Arc<Mutex<CudaRng>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "CudaDevice({:?})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaDevice {
|
||||
@ -87,10 +102,12 @@ impl CudaDevice {
|
||||
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
})
|
||||
}
|
||||
|
||||
@ -136,6 +153,68 @@ impl CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||
curand.0.fill_with_uniform(&mut data)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||
curand.0.fill_with_uniform(&mut data)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn rand_normal(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||
curand
|
||||
.0
|
||||
.fill_with_normal(&mut data, mean as f32, std as f32)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||
curand.0.fill_with_normal(&mut data, mean, std)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
|
Reference in New Issue
Block a user