mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Random initializers. (#128)
* Random initialization. * CPU rng generation.
This commit is contained in:
@ -25,6 +25,7 @@ libc = { version = "0.2.147", optional = true }
|
|||||||
memmap2 = "0.7.1"
|
memmap2 = "0.7.1"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
num_cpus = "1.15.0"
|
num_cpus = "1.15.0"
|
||||||
|
rand = "0.8.5"
|
||||||
safetensors = "0.3.1"
|
safetensors = "0.3.1"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
zip = { version = "0.6.6", default-features=false }
|
zip = { version = "0.6.6", default-features=false }
|
||||||
|
@ -895,6 +895,66 @@ impl CpuStorage {
|
|||||||
MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
|
MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_uniform(shape: &Shape, dtype: DType) -> Result<Self> {
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
match dtype {
|
||||||
|
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
|
||||||
|
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = Vec::new();
|
||||||
|
data.reserve(elem_count);
|
||||||
|
let uniform = rand::distributions::Uniform::new(0f32, 1f32);
|
||||||
|
for _i in 0..elem_count {
|
||||||
|
data.push(rng.sample::<f32, _>(uniform))
|
||||||
|
}
|
||||||
|
Ok(Self::F32(data))
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = Vec::new();
|
||||||
|
data.reserve(elem_count);
|
||||||
|
let uniform = rand::distributions::Uniform::new(0f64, 1f64);
|
||||||
|
for _i in 0..elem_count {
|
||||||
|
data.push(rng.sample::<f64, _>(uniform))
|
||||||
|
}
|
||||||
|
Ok(Self::F64(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_normal(shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<Self> {
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
match dtype {
|
||||||
|
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
|
||||||
|
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal"))
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = Vec::new();
|
||||||
|
data.reserve(elem_count);
|
||||||
|
let std = std as f32;
|
||||||
|
let mean = mean as f32;
|
||||||
|
for _i in 0..elem_count {
|
||||||
|
data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
|
||||||
|
}
|
||||||
|
Ok(Self::F32(data))
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = Vec::new();
|
||||||
|
data.reserve(elem_count);
|
||||||
|
for _i in 0..elem_count {
|
||||||
|
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
|
||||||
|
}
|
||||||
|
Ok(Self::F64(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
match dtype {
|
match dtype {
|
||||||
|
@ -5,7 +5,7 @@ use cudarc::driver::{
|
|||||||
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
/// cudarc related errors
|
/// cudarc related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -19,12 +19,18 @@ pub enum CudaError {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Cublas(#[from] cudarc::cublas::result::CublasError),
|
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Curand(#[from] cudarc::curand::result::CurandError),
|
||||||
|
|
||||||
#[error("{op} only supports contiguous tensors")]
|
#[error("{op} only supports contiguous tensors")]
|
||||||
RequiresContiguous { op: &'static str },
|
RequiresContiguous { op: &'static str },
|
||||||
|
|
||||||
#[error("missing kernel '{module_name}'")]
|
#[error("missing kernel '{module_name}'")]
|
||||||
MissingKernel { module_name: String },
|
MissingKernel { module_name: String },
|
||||||
|
|
||||||
|
#[error("unsupported dtype {dtype:?} for {op}")]
|
||||||
|
UnsupportedDtype { dtype: DType, op: &'static str },
|
||||||
|
|
||||||
#[error("internal error '{0}'")]
|
#[error("internal error '{0}'")]
|
||||||
InternalError(&'static str),
|
InternalError(&'static str),
|
||||||
|
|
||||||
@ -67,12 +73,21 @@ impl DeviceId {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
struct CudaRng(cudarc::curand::CudaRng);
|
||||||
|
unsafe impl Send for CudaRng {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct CudaDevice {
|
pub struct CudaDevice {
|
||||||
id: DeviceId,
|
id: DeviceId,
|
||||||
device: Arc<cudarc::driver::CudaDevice>,
|
device: Arc<cudarc::driver::CudaDevice>,
|
||||||
#[allow(dead_code)]
|
|
||||||
blas: Arc<cudarc::cublas::CudaBlas>,
|
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
|
curand: Arc<Mutex<CudaRng>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for CudaDevice {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "CudaDevice({:?})", self.id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for CudaDevice {
|
impl std::ops::Deref for CudaDevice {
|
||||||
@ -87,10 +102,12 @@ impl CudaDevice {
|
|||||||
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: DeviceId::new(),
|
id: DeviceId::new(),
|
||||||
device,
|
device,
|
||||||
blas: Arc::new(blas),
|
blas: Arc::new(blas),
|
||||||
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,6 +153,68 @@ impl CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let curand = self.curand.lock().unwrap();
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
|
||||||
|
Err(CudaError::UnsupportedDtype {
|
||||||
|
dtype,
|
||||||
|
op: "rand_uniform",
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||||
|
curand.0.fill_with_uniform(&mut data)?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||||
|
curand.0.fill_with_uniform(&mut data)?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_normal(
|
||||||
|
&self,
|
||||||
|
shape: &Shape,
|
||||||
|
dtype: DType,
|
||||||
|
mean: f64,
|
||||||
|
std: f64,
|
||||||
|
) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let curand = self.curand.lock().unwrap();
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
|
||||||
|
Err(CudaError::UnsupportedDtype {
|
||||||
|
dtype,
|
||||||
|
op: "rand_normal",
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||||
|
curand
|
||||||
|
.0
|
||||||
|
.fill_with_normal(&mut data, mean as f32, std as f32)?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||||
|
curand.0.fill_with_normal(&mut data, mean, std)?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
|
@ -109,6 +109,38 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_uniform(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
|
match self {
|
||||||
|
Device::Cpu => {
|
||||||
|
let storage = CpuStorage::rand_uniform(shape, dtype)?;
|
||||||
|
Ok(Storage::Cpu(storage))
|
||||||
|
}
|
||||||
|
Device::Cuda(device) => {
|
||||||
|
let storage = device.rand_uniform(shape, dtype)?;
|
||||||
|
Ok(Storage::Cuda(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_normal(
|
||||||
|
&self,
|
||||||
|
shape: &Shape,
|
||||||
|
dtype: DType,
|
||||||
|
mean: f64,
|
||||||
|
std: f64,
|
||||||
|
) -> Result<Storage> {
|
||||||
|
match self {
|
||||||
|
Device::Cpu => {
|
||||||
|
let storage = CpuStorage::rand_normal(shape, dtype, mean, std)?;
|
||||||
|
Ok(Storage::Cpu(storage))
|
||||||
|
}
|
||||||
|
Device::Cuda(device) => {
|
||||||
|
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||||
|
Ok(Storage::Cuda(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => {
|
Device::Cpu => {
|
||||||
|
@ -38,6 +38,14 @@ impl CudaDevice {
|
|||||||
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
|
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_uniform(&self, _: &Shape, _: DType) -> Result<CudaStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -222,6 +222,58 @@ impl Tensor {
|
|||||||
Tensor::zeros(self.shape(), self.dtype(), &self.device())
|
Tensor::zeros(self.shape(), self.dtype(), &self.device())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rand_uniform_impl<S: Into<Shape>>(
|
||||||
|
s: S,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
is_variable: bool,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let s = s.into();
|
||||||
|
let storage = device.rand_uniform(&s, dtype)?;
|
||||||
|
Ok(from_storage(storage, s, None, is_variable))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
Self::rand_uniform_impl(s, dtype, device, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||||
|
Self::rand_uniform_impl(s, dtype, device, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal_impl<S: Into<Shape>>(
|
||||||
|
s: S,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
mean: f64,
|
||||||
|
std: f64,
|
||||||
|
is_variable: bool,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let s = s.into();
|
||||||
|
let storage = device.rand_normal(&s, dtype, mean, std)?;
|
||||||
|
Ok(from_storage(storage, s, None, is_variable))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rand_normal<S: Into<Shape>>(
|
||||||
|
s: S,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
mean: f64,
|
||||||
|
std: f64,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::rand_normal_impl(s, dtype, device, mean, std, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rand_normal_var<S: Into<Shape>>(
|
||||||
|
s: S,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
mean: f64,
|
||||||
|
std: f64,
|
||||||
|
) -> Result<Self> {
|
||||||
|
Self::rand_normal_impl(s, dtype, device, mean, std, true)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_impl<A: crate::device::NdArray>(
|
pub fn new_impl<A: crate::device::NdArray>(
|
||||||
array: A,
|
array: A,
|
||||||
shape: Shape,
|
shape: Shape,
|
||||||
|
Reference in New Issue
Block a user