mirror of https://github.com/huggingface/candle.git
Random initializers. (#128)

* Random initialization.
* CPU rng generation.
@@ -25,6 +25,7 @@ libc = { version = "0.2.147", optional = true }
 memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
+rand = "0.8.5"
 safetensors = "0.3.1"
 thiserror = "1"
 zip = { version = "0.6.6", default-features = false }
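The only dependency this commit adds is `rand`. For orientation, a minimal standalone sketch (not part of the diff) of the rand 0.8 API that the CPU backend below leans on; the vector length is illustrative:

use rand::prelude::*;

fn main() {
    let mut rng = rand::thread_rng();
    // Uniform::new(lo, hi) samples from the half-open range [lo, hi).
    let uniform = rand::distributions::Uniform::new(0f32, 1f32);
    let xs: Vec<f32> = (0..4).map(|_| rng.sample(uniform)).collect();
    println!("{xs:?}");
}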
@@ -895,6 +895,66 @@ impl CpuStorage {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }

    pub(crate) fn rand_uniform(shape: &Shape, dtype: DType) -> Result<Self> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform"))
            }
            DType::F32 => {
                let mut data = Vec::new();
                data.reserve(elem_count);
                let uniform = rand::distributions::Uniform::new(0f32, 1f32);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(Self::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::new();
                data.reserve(elem_count);
                let uniform = rand::distributions::Uniform::new(0f64, 1f64);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(Self::F64(data))
            }
        }
    }

    pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
            }
            DType::F32 => {
                let mut data = Vec::new();
                data.reserve(elem_count);
                let std = std as f32;
                let mean = mean as f32;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
                }
                Ok(Self::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::new();
                data.reserve(elem_count);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
                }
                Ok(Self::F64(data))
            }
        }
    }

    pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
        let elem_count = shape.elem_count();
        match dtype {
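One caveat worth flagging in the CPU path: for floats, rand's `Standard` distribution yields uniform samples in [0, 1), so `rand_normal` above scales a uniform draw rather than a Gaussian one. A minimal sketch of drawing genuinely normal values, assuming a `rand_distr` dependency that this diff does not add:

use rand::prelude::*;
use rand_distr::Normal;

fn normal_vec(elem_count: usize, mean: f64, std: f64) -> Vec<f64> {
    let mut rng = rand::thread_rng();
    // Normal::new rejects a negative or non-finite standard deviation.
    let normal = Normal::new(mean, std).unwrap();
    (0..elem_count).map(|_| normal.sample(&mut rng)).collect()
}

A Box-Muller transform over two uniform draws would achieve the same without the extra dependency.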
@@ -5,7 +5,7 @@ use cudarc::driver::{
     CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
 };
 use half::{bf16, f16};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};

 /// cudarc related errors
 #[derive(thiserror::Error, Debug)]
@@ -19,12 +19,18 @@ pub enum CudaError {
    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),

    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),

    #[error("{op} only supports contiguous tensors")]
    RequiresContiguous { op: &'static str },

    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },

    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),
@@ -67,12 +73,21 @@ impl DeviceId {
    }
}

#[derive(Debug, Clone)]
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}

#[derive(Clone)]
pub struct CudaDevice {
    id: DeviceId,
    device: Arc<cudarc::driver::CudaDevice>,
    #[allow(dead_code)]
    blas: Arc<cudarc::cublas::CudaBlas>,
    curand: Arc<Mutex<CudaRng>>,
}

impl std::fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CudaDevice({:?})", self.id)
    }
}

impl std::ops::Deref for CudaDevice {
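`CudaRng` is a newtype whose only job is to let the crate assert, unsafely, that the curand handle can move across threads; clones of `CudaDevice` then share one generator through `Arc<Mutex<...>>`. A generic sketch of that pattern, with hypothetical names:

use std::sync::{Arc, Mutex};

// Stand-in for a foreign handle the compiler cannot prove thread-safe.
struct RawHandle(*mut ());

struct SendableHandle(RawHandle);
// Safety: sound only if the underlying API tolerates calls from any thread.
unsafe impl Send for SendableHandle {}

// Clones share one generator; the Mutex serializes access to it.
type SharedRng = Arc<Mutex<SendableHandle>>;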
@@ -87,10 +102,12 @@ impl CudaDevice {
    pub(crate) fn new(ordinal: usize) -> Result<Self> {
        let device = cudarc::driver::CudaDevice::new(ordinal)?;
        let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
        // A fixed seed (299792458, the speed of light in m/s) keeps curand
        // output reproducible between runs.
        let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
        Ok(Self {
            id: DeviceId::new(),
            device,
            blas: Arc::new(blas),
            curand: Arc::new(Mutex::new(CudaRng(curand))),
        })
    }
@@ -136,6 +153,68 @@ impl CudaDevice {
        })
    }

    pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_uniform",
                })?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
                curand.0.fill_with_uniform(&mut data)?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
                curand.0.fill_with_uniform(&mut data)?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    pub(crate) fn rand_normal(
        &self,
        shape: &Shape,
        dtype: DType,
        mean: f64,
        std: f64,
    ) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_normal",
                })?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
                curand
                    .0
                    .fill_with_normal(&mut data, mean as f32, std as f32)?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
                curand.0.fill_with_normal(&mut data, mean, std)?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
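A small Rust idiom is doing the work in the unsupported-dtype arms above: applying `?` to a literal `Err(...)` converts the error with `From` and returns early, which lets the arm sit inside a match whose other arms produce a value. A self-contained sketch with hypothetical types:

#[derive(Debug)]
struct HighLevelError(String);

impl From<&'static str> for HighLevelError {
    fn from(msg: &'static str) -> Self {
        Self(msg.to_string())
    }
}

fn pick(flag: bool) -> Result<u32, HighLevelError> {
    let v = match flag {
        true => 42,
        // `?` converts the &str into HighLevelError and returns immediately.
        false => Err("unsupported")?,
    };
    Ok(v)
}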
@@ -109,6 +109,38 @@ impl Device {
        }
    }

    pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
        match self {
            Device::Cpu => {
                let storage = CpuStorage::rand_uniform(shape, dtype)?;
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                let storage = device.rand_uniform(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
        }
    }

    pub(crate) fn rand_normal(
        &self,
        shape: &Shape,
        dtype: DType,
        mean: f64,
        std: f64,
    ) -> Result<Storage> {
        match self {
            Device::Cpu => {
                let storage = CpuStorage::rand_normal(shape, dtype, mean, std)?;
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                let storage = device.rand_normal(shape, dtype, mean, std)?;
                Ok(Storage::Cuda(storage))
            }
        }
    }

    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
        match self {
            Device::Cpu => {
@@ -38,6 +38,14 @@ impl CudaDevice {
    pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub(crate) fn rand_uniform(&self, _: &Shape, _: DType) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }
}

#[derive(Debug)]
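These stubs are what builds without CUDA compile against: each method returns `Error::NotCompiledWithCudaSupport`, so callers see the same API either way. A self-contained sketch of the feature-gating idea, with hypothetical module and function names:

#[cfg(feature = "cuda")]
mod gpu {
    pub fn rand_uniform() -> &'static str { "real curand call" }
}

#[cfg(not(feature = "cuda"))]
mod gpu {
    pub fn rand_uniform() -> &'static str { "Error::NotCompiledWithCudaSupport" }
}

fn main() {
    // The same call site compiles either way; only the gated module differs.
    println!("{}", gpu::rand_uniform());
}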
@@ -222,6 +222,58 @@ impl Tensor {
        Tensor::zeros(self.shape(), self.dtype(), &self.device())
    }

    fn rand_uniform_impl<S: Into<Shape>>(
        s: S,
        dtype: DType,
        device: &Device,
        is_variable: bool,
    ) -> Result<Self> {
        let s = s.into();
        let storage = device.rand_uniform(&s, dtype)?;
        Ok(from_storage(storage, s, None, is_variable))
    }

    pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
        Self::rand_uniform_impl(s, dtype, device, false)
    }

    pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
        Self::rand_uniform_impl(s, dtype, device, true)
    }

    fn rand_normal_impl<S: Into<Shape>>(
        s: S,
        dtype: DType,
        device: &Device,
        mean: f64,
        std: f64,
        is_variable: bool,
    ) -> Result<Self> {
        let s = s.into();
        let storage = device.rand_normal(&s, dtype, mean, std)?;
        Ok(from_storage(storage, s, None, is_variable))
    }

    pub fn rand_normal<S: Into<Shape>>(
        s: S,
        dtype: DType,
        device: &Device,
        mean: f64,
        std: f64,
    ) -> Result<Self> {
        Self::rand_normal_impl(s, dtype, device, mean, std, false)
    }

    pub fn rand_normal_var<S: Into<Shape>>(
        s: S,
        dtype: DType,
        device: &Device,
        mean: f64,
        std: f64,
    ) -> Result<Self> {
        Self::rand_normal_impl(s, dtype, device, mean, std, true)
    }

    pub fn new_impl<A: crate::device::NdArray>(
        array: A,
        shape: Shape,
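Taken together, the user-facing surface of the commit is the four constructors `rand_uniform`, `rand_uniform_var`, `rand_normal`, and `rand_normal_var`; the `_var` variants differ only in setting `is_variable: true`. A hypothetical usage sketch; the crate import path, shape, and distribution parameters are illustrative:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // Uniform samples in [0, 1).
    let u = Tensor::rand_uniform((2, 3), DType::F32, &Device::Cpu)?;
    // Samples with mean 0.0 and standard deviation 1.0.
    let n = Tensor::rand_normal((2, 3), DType::F64, &Device::Cpu, 0.0, 1.0)?;
    println!("{:?} {:?}", u.shape(), n.shape());
    Ok(())
}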