diff --git a/Cargo.toml b/Cargo.toml
index 5f06fa80..0dec835b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,7 +23,7 @@ cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas
 # TODO: Switch back to the official gemm implementation if we manage to upstream the changes.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
 hf-hub = "0.1.3"
-half = { version = "2.3.1", features = ["num-traits"] }
+half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
 log = "0.4"
diff --git a/README.md b/README.md
index dfe62ab2..3a98aa6a 100644
--- a/README.md
+++ b/README.md
@@ -2,10 +2,11 @@ ML framework for Rust
 
 ```rust
-let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
-let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
+let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
+let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
 
 let c = a.matmul(&b)?;
+println!("{c}");
 ```
 
 ## Check out our examples
@@ -45,13 +46,15 @@ And then browse to
 
 ## Features
 
-- Simple syntax (looks and like PyTorch)
-- CPU and Cuda backends, m1, f16, bf16 (and tentatively wasm)
+- Simple syntax, looks and feels like PyTorch.
+- CPU and Cuda backends, m1, f16, bf16.
 - Enable serverless (CPU), small and fast deployments
-- Model training
-- Distributed computing (NCCL).
-- Models out of the box (Llama, Whisper, Falcon, ...)
-- Emphasis on enabling users to use custom ops/kernels
+- WASM support, run your models in a browser.
+- Model training.
+- Distributed computing using NCCL.
+- Models out of the box: Llama, Whisper, Falcon, BERT...
+- Embed user-defined ops/kernels, such as [flash-attention
+  v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
 
 ## How to use ?
@@ -59,9 +62,7 @@ Cheatsheet:
 
 |            | Using PyTorch                            | Using Candle                                                       |
 |------------|------------------------------------------|--------------------------------------------------------------------|
-| Creation   | `torch.Tensor([[1, 2], [3, 4]])`         | `Tensor::new(`                                                     |
-|            |                                          | `    &[[1f32, 2.]], [3., 4.]],`                                    |
-|            |                                          | `    &Device::Cpu)?`                                               |
+| Creation   | `torch.Tensor([[1, 2], [3, 4]])`         | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?`              |
 | Indexing   | `tensor[:, :4]`                          | `tensor.i((.., ..4))?`                                             |
 | Operations | `tensor.view((2, 2))`                    | `tensor.reshape((2, 2))?`                                          |
 | Operations | `a.matmul(b)`                            | `a.matmul(&b)?`                                                    |
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index a5e2b24e..88b110e0 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -5,6 +5,11 @@ use anyhow::Result;
 use candle::{Device, Tensor};
 
 fn main() -> Result<()> {
+    let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
+    let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
+    let c = a.matmul(&b)?;
+    println!("{a} {b} {c}");
+
     let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
     let t1 = Tensor::new(data, &Device::Cpu)?;
     let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8d38b158..83c7080f 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -369,8 +369,7 @@ pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
             block_start_index,
             block_len,
         } => {
-            let mut result = vec![];
-            result.reserve(layout.shape().elem_count());
+            let mut result = Vec::with_capacity(layout.shape().elem_count());
             // Specialize the case where block_len is one to avoid the second loop.
             if block_len == 1 {
                 for index in block_start_index {
@@ -1843,12 +1842,27 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
+            DType::BF16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let uniform =
+                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<bf16, _>(uniform))
+                }
+                Ok(CpuStorage::BF16(data))
+            }
+            DType::F16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let uniform =
+                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f16, _>(uniform))
+                }
+                Ok(CpuStorage::F16(data))
             }
             DType::F32 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(uniform))
@@ -1856,8 +1870,7 @@ impl BackendDevice for CpuDevice {
                 Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 let uniform = rand::distributions::Uniform::new(min, max);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(uniform))
@@ -1873,12 +1886,27 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
+            DType::BF16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let std = bf16::from_f64(std);
+                let mean = bf16::from_f64(mean);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+                }
+                Ok(CpuStorage::BF16(data))
+            }
+            DType::F16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let std = f16::from_f64(std);
+                let mean = f16::from_f64(mean);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+                }
+                Ok(CpuStorage::F16(data))
             }
             DType::F32 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 let std = std as f32;
                 let mean = mean as f32;
                 for _i in 0..elem_count {
@@ -1887,8 +1915,7 @@ impl BackendDevice for CpuDevice {
                 Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
                 }
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9cc454f1..b3d542b9 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
         let slice = match dtype {
+            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+            // cudarc changes.
             DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
                 dtype,
                 op: "rand_uniform",
@@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
     }
 
     fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
+        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+        // cudarc changes.
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
         let slice = match dtype {
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 53e2de43..89df8f84 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
     }
 
     fn to_cpu_storage(&self) -> CpuStorage {
-        let mut vec = Vec::new();
-        vec.reserve(N1 * N2 * N3);
+        let mut vec = Vec::with_capacity(N1 * N2 * N3);
         for i1 in 0..N1 {
             for i2 in 0..N2 {
                 vec.extend(self[i1][i2])
@@ -117,39 +116,41 @@ impl Device {
         }
     }
 
-    pub(crate) fn rand_uniform(
+    pub(crate) fn rand_uniform<T: crate::FloatDType>(
         &self,
+        lo: T,
+        up: T,
         shape: &Shape,
-        dtype: DType,
-        lo: f64,
-        up: f64,
     ) -> Result<Storage> {
+        let lo = lo.to_f64();
+        let up = up.to_f64();
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
+                let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
                 Ok(Storage::Cuda(storage))
             }
         }
     }
 
-    pub(crate) fn rand_normal(
+    pub(crate) fn rand_normal<T: crate::FloatDType>(
         &self,
+        mean: T,
+        std: T,
         shape: &Shape,
-        dtype: DType,
-        mean: f64,
-        std: f64,
     ) -> Result<Storage> {
+        let mean = mean.to_f64();
+        let std = std.to_f64();
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
+                let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
         }
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index c6befbb8..0e906119 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -120,7 +120,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
 with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
 with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
 
-pub trait IntDType {
+pub trait IntDType: WithDType {
     fn is_true(&self) -> bool;
     fn as_usize(&self) -> usize;
 }
@@ -142,3 +142,10 @@ impl IntDType for u8 {
         *self as usize
     }
 }
+
+pub trait FloatDType: WithDType {}
+
+impl FloatDType for f16 {}
+impl FloatDType for bf16 {}
+impl FloatDType for f32 {}
+impl FloatDType for f64 {}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 3dbae7fc..95cc189c 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -61,7 +61,7 @@ mod variable;
 
 pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
-pub use dtype::{DType, IntDType, WithDType};
+pub use dtype::{DType, FloatDType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;
 pub use layout::Layout;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 28ecc357..09f61340 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -232,55 +232,51 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
-    pub(crate) fn rand_impl<S: Into<Shape>>(
+    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        lo: f64,
-        up: f64,
         is_variable: bool,
     ) -> Result<Self> {
         let s = s.into();
-        let storage = device.rand_uniform(&s, dtype, lo, up)?;
+        let storage = device.rand_uniform(lo, up, &s)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
-    pub fn rand<S: Into<Shape>>(
+    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        lo: f64,
-        up: f64,
     ) -> Result<Self> {
-        Self::rand_impl(s, dtype, device, lo, up, false)
+        Self::rand_impl(lo, up, s, device, false)
     }
 
-    pub(crate) fn randn_impl<S: Into<Shape>>(
+    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        mean: f64,
-        std: f64,
         is_variable: bool,
     ) -> Result<Self> {
         let s = s.into();
-        let storage = device.rand_normal(&s, dtype, mean, std)?;
+        let storage = device.rand_normal(mean, std, &s)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
     /// specified `mean` and standard deviation `std`.
-    pub fn randn<S: Into<Shape>>(
+    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        mean: f64,
-        std: f64,
     ) -> Result<Self> {
-        Self::randn_impl(s, dtype, device, mean, std, false)
+        Self::randn_impl(mean, std, s, device, false)
     }
 
     pub(crate) fn new_impl<A: crate::device::NdArray>(
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index e26f1420..0cefee11 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,25 +34,23 @@ impl Var {
         Ok(Self(inner))
     }
 
-    pub fn rand<S: Into<Shape>>(
+    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        lo: f64,
-        up: f64,
     ) -> Result<Self> {
-        let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
+        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
        Ok(Self(inner))
     }
 
-    pub fn randn<S: Into<Shape>>(
+    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        mean: f64,
-        std: f64,
     ) -> Result<Self> {
-        let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
+        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
         Ok(Self(inner))
     }
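
A minimal usage sketch (not part of the patch) of the typed `rand`/`randn` constructors introduced above, mirroring the README and `basics.rs` snippets; it assumes the crate is imported as `candle` as in the examples and that the `half` crate (a candle-core dependency) is available for the f16 case.

```rust
// Sketch only: exercises the new typed constructors from this change set.
use candle::{Device, Result, Tensor, Var};
use half::f16;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Uniform samples in [0, 1); the dtype (F32 here) is inferred from the typed bounds.
    let a = Tensor::rand(0f32, 1f32, (2, 3), &device)?;
    // Normal samples with mean 0 and std 1, also F32.
    let b = Tensor::randn(0f32, 1f32, (3, 4), &device)?;
    // Half precision now works on the CPU backend thanks to the new F16/BF16 arms.
    let h = Tensor::randn(f16::from_f64(0.), f16::from_f64(1.), (2, 2), &device)?;
    // Trainable variables expose the same typed constructors.
    let _w = Var::randn(0f32, 1f32, (3, 4), &device)?;
    let c = a.matmul(&b)?;
    println!("{c} {h}");
    Ok(())
}
```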