mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn. * Also switch Tensor::rand to use a generic dtype. * Support sampling for f16. * Cleanup.
This commit is contained in:
@ -23,7 +23,7 @@ cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas
|
||||
# TODO: Switch back to the official gemm implementation if we manage to upstream the changes.
|
||||
gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
|
||||
hf-hub = "0.1.3"
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
|
23
README.md
23
README.md
@ -2,10 +2,11 @@
|
||||
ML framework for Rust
|
||||
|
||||
```rust
|
||||
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{c}");
|
||||
```
|
||||
|
||||
## Check out our examples
|
||||
@ -45,13 +46,15 @@ And then browse to
|
||||
|
||||
## Features
|
||||
|
||||
- Simple syntax (looks and like PyTorch)
|
||||
- CPU and Cuda backends, m1, f16, bf16 (and tentatively wasm)
|
||||
- Simple syntax, looks and feels like PyTorch.
|
||||
- CPU and Cuda backends, m1, f16, bf16.
|
||||
- Enable serverless (CPU), small and fast deployments
|
||||
- Model training
|
||||
- Distributed computing (NCCL).
|
||||
- Models out of the box (Llama, Whisper, Falcon, ...)
|
||||
- Emphasis on enabling users to use custom ops/kernels
|
||||
- WASM support, run your models in a browser.
|
||||
- Model training.
|
||||
- Distributed computing using NCCL.
|
||||
- Models out of the box: Llama, Whisper, Falcon, BERT...
|
||||
- Embed user-defined ops/kernels, such as [flash-attention
|
||||
v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
||||
|
||||
## How to use ?
|
||||
|
||||
@ -59,9 +62,7 @@ Cheatsheet:
|
||||
|
||||
| | Using PyTorch | Using Candle |
|
||||
|------------|------------------------------------------|------------------------------------------------------------------|
|
||||
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(` |
|
||||
| | | ` &[[1f32, 2.]], [3., 4.]],` |
|
||||
| | | ` &Device::Cpu)?` |
|
||||
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
|
||||
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
|
||||
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
|
||||
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
|
||||
|
@ -5,6 +5,11 @@ use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
||||
let c = a.matmul(&b)?;
|
||||
println!("{a} {b} {c}");
|
||||
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
|
||||
let t1 = Tensor::new(data, &Device::Cpu)?;
|
||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
|
||||
|
@ -369,8 +369,7 @@ pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
result.reserve(layout.shape().elem_count());
|
||||
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
for index in block_start_index {
|
||||
@ -1843,12 +1842,27 @@ impl BackendDevice for CpuDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let mut rng = rand::thread_rng();
|
||||
match dtype {
|
||||
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
|
||||
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let uniform =
|
||||
rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<bf16, _>(uniform))
|
||||
}
|
||||
Ok(CpuStorage::BF16(data))
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let uniform =
|
||||
rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f16, _>(uniform))
|
||||
}
|
||||
Ok(CpuStorage::F16(data))
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = Vec::new();
|
||||
data.reserve(elem_count);
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f32, _>(uniform))
|
||||
@ -1856,8 +1870,7 @@ impl BackendDevice for CpuDevice {
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::new();
|
||||
data.reserve(elem_count);
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let uniform = rand::distributions::Uniform::new(min, max);
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f64, _>(uniform))
|
||||
@ -1873,12 +1886,27 @@ impl BackendDevice for CpuDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let mut rng = rand::thread_rng();
|
||||
match dtype {
|
||||
DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
|
||||
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
|
||||
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
|
||||
DType::BF16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = bf16::from_f64(std);
|
||||
let mean = bf16::from_f64(mean);
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
|
||||
}
|
||||
Ok(CpuStorage::BF16(data))
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = f16::from_f64(std);
|
||||
let mean = f16::from_f64(mean);
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
|
||||
}
|
||||
Ok(CpuStorage::F16(data))
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = Vec::new();
|
||||
data.reserve(elem_count);
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
let std = std as f32;
|
||||
let mean = mean as f32;
|
||||
for _i in 0..elem_count {
|
||||
@ -1887,8 +1915,7 @@ impl BackendDevice for CpuDevice {
|
||||
Ok(CpuStorage::F32(data))
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = Vec::new();
|
||||
data.reserve(elem_count);
|
||||
let mut data = Vec::with_capacity(elem_count);
|
||||
for _i in 0..elem_count {
|
||||
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
|
||||
}
|
||||
|
@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
|
||||
}
|
||||
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
|
@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let mut vec = Vec::new();
|
||||
vec.reserve(N1 * N2 * N3);
|
||||
let mut vec = Vec::with_capacity(N1 * N2 * N3);
|
||||
for i1 in 0..N1 {
|
||||
for i2 in 0..N2 {
|
||||
vec.extend(self[i1][i2])
|
||||
@ -117,39 +116,41 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rand_uniform(
|
||||
pub(crate) fn rand_uniform<T: crate::FloatDType>(
|
||||
&self,
|
||||
lo: T,
|
||||
up: T,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Storage> {
|
||||
let lo = lo.to_f64();
|
||||
let up = up.to_f64();
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
|
||||
let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rand_normal(
|
||||
pub(crate) fn rand_normal<T: crate::FloatDType>(
|
||||
&self,
|
||||
mean: T,
|
||||
std: T,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Storage> {
|
||||
let mean = mean.to_f64();
|
||||
let std = std.to_f64();
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
|
||||
let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
|
||||
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
||||
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
||||
|
||||
pub trait IntDType {
|
||||
pub trait IntDType: WithDType {
|
||||
fn is_true(&self) -> bool;
|
||||
fn as_usize(&self) -> usize;
|
||||
}
|
||||
@ -142,3 +142,10 @@ impl IntDType for u8 {
|
||||
*self as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FloatDType: WithDType {}
|
||||
|
||||
impl FloatDType for f16 {}
|
||||
impl FloatDType for bf16 {}
|
||||
impl FloatDType for f32 {}
|
||||
impl FloatDType for f64 {}
|
||||
|
@ -61,7 +61,7 @@ mod variable;
|
||||
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use device::{Device, DeviceLocation};
|
||||
pub use dtype::{DType, IntDType, WithDType};
|
||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
pub use layout::Layout;
|
||||
|
@ -232,55 +232,51 @@ impl Tensor {
|
||||
Tensor::zeros(self.shape(), self.dtype(), self.device())
|
||||
}
|
||||
|
||||
pub(crate) fn rand_impl<S: Into<Shape>>(
|
||||
pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
|
||||
lo: T,
|
||||
up: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_uniform(&s, dtype, lo, up)?;
|
||||
let storage = device.rand_uniform(lo, up, &s)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
|
||||
pub fn rand<S: Into<Shape>>(
|
||||
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
|
||||
lo: T,
|
||||
up: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_impl(s, dtype, device, lo, up, false)
|
||||
Self::rand_impl(lo, up, s, device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn randn_impl<S: Into<Shape>>(
|
||||
pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_normal(&s, dtype, mean, std)?;
|
||||
let storage = device.rand_normal(mean, std, &s)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled from a normal distribution with the
|
||||
/// specified `mean` and standard deviation `std`.
|
||||
pub fn randn<S: Into<Shape>>(
|
||||
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
Self::randn_impl(s, dtype, device, mean, std, false)
|
||||
Self::randn_impl(mean, std, s, device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn new_impl<A: crate::device::NdArray>(
|
||||
|
@ -34,25 +34,23 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn rand<S: Into<Shape>>(
|
||||
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
|
||||
lo: T,
|
||||
up: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Self> {
|
||||
let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
|
||||
let inner = Tensor::rand_impl(lo, up, s, device, true)?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn randn<S: Into<Shape>>(
|
||||
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
|
||||
mean: T,
|
||||
std: T,
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
|
||||
let inner = Tensor::randn_impl(mean, std, s, device, true)?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user