diff --git a/Cargo.toml b/Cargo.toml index 40f51fea..fe50b356 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "candle-transformers", "candle-wasm-examples/*", "candle-wasm-tests", + "tensor-tools", ] exclude = [ "candle-flash-attn", @@ -19,7 +20,7 @@ exclude = [ resolver = "2" [workspace.package] -version = "0.4.1" +version = "0.5.0" edition = "2021" description = "Minimalist ML framework." repository = "https://github.com/huggingface/candle" @@ -28,17 +29,18 @@ categories = ["science"] license = "MIT OR Apache-2.0" [workspace.dependencies] +ab_glyph = "0.2.23" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" -candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" } -candle-datasets = { path = "./candle-datasets", version = "0.4.1" } -candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" } -candle-kernels = { path = "./candle-kernels", version = "0.4.1" } -candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" } -candle-nn = { path = "./candle-nn", version = "0.4.1" } -candle-onnx = { path = "./candle-onnx", version = "0.4.1" } -candle-transformers = { path = "./candle-transformers", version = "0.4.1" } +candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" } +candle-datasets = { path = "./candle-datasets", version = "0.5.0" } +candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" } +candle-kernels = { path = "./candle-kernels", version = "0.5.0" } +candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" } +candle-nn = { path = "./candle-nn", version = "0.5.0" } +candle-onnx = { path = "./candle-onnx", version = "0.5.0" } +candle-transformers = { path = "./candle-transformers", version = "0.5.0" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.10.0", features = ["f16"] } @@ -46,19 +48,18 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } -image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] } -imageproc = { version = "0.23.0", default-features = false } +image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] } +imageproc = { version = "0.24.0", default-features = false } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } libc = { version = "0.2.147" } log = "0.4" memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } num_cpus = "1.15.0" num-traits = "0.2.15" -parquet = { version = "50.0.0" } +parquet = { version = "51.0.0" } rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" -rusttype = { version = "0.9", default-features = false } safetensors = "0.4.1" serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" diff --git a/README.md b/README.md index fd80069e..b9e603b2 100644 --- a/README.md +++ b/README.md @@ -125,10 +125,14 @@ We also provide a some command line based examples using state of the art models [RepVGG](./candle-examples/examples/repvgg): computer vision models. - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to generate captions for an image. +- [CLIP](./candle-examples/examples/clip/): multi-model vision and language + model. 
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with dedicated submodels for hand-writing and printed recognition. - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation model, generates the translated text from the input text. +- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model + that can answer real-world questions about images. Run them using commands like: ``` @@ -172,9 +176,11 @@ And then head over to - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and serving local LLMs including an OpenAI compatible API server. - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle. +- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course. - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more. - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle. - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem. +- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library. If you have an addition to this list, please submit a pull request. @@ -205,7 +211,7 @@ If you have an addition to this list, please submit a pull request. - Replit-code-v1.5-3B. - Bert. - Yi-6B and Yi-34B. - - Qwen1.5. + - Qwen1.5, Qwen1.5 MoE. - RWKV v5 and v6. - Quantized LLMs. - Llama 7b, 13b, 70b, as well as the chat and code variants. 
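The next hunks re-enable the full benchmark suite and add a new conv_transpose2d benchmark (plus, further down, a backward pass for the op). As a rough editorial illustration of the call being timed, here is a minimal sketch reusing the benchmark's own shapes and arguments; it is not part of the patch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input laid out as (batch, c_in, h, w); kernel as (c_in, c_out, k_h, k_w).
    let x = Tensor::arange(0.0f32, 10_000.0, &dev)?.reshape((1, 4, 50, 50))?;
    let k = Tensor::arange(0.0f32, 100.0, &dev)?.reshape((4, 1, 5, 5))?;
    // Arguments: kernel, padding, output_padding, stride, dilation.
    let y = x.conv_transpose2d(&k, 1, 0, 1, 2)?;
    // Output spatial size follows (h - 1) * stride - 2 * padding + dilation * (k - 1)
    // + output_padding + 1, so this should print [1, 1, 56, 56].
    println!("{:?}", y.dims());
    Ok(())
}
```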
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 661bdd2a..9f94b252 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -2,8 +2,9 @@ mod benchmarks; use criterion::criterion_main; criterion_main!( - //benchmarks::affine::benches, + benchmarks::affine::benches, benchmarks::matmul::benches, - //benchmarks::random::benches, - //benchmarks::where_cond::benches + benchmarks::random::benches, + benchmarks::where_cond::benches, + benchmarks::conv_transpose2d::benches, ); diff --git a/candle-core/benches/benchmarks/conv_transpose2d.rs b/candle-core/benches/benchmarks/conv_transpose2d.rs new file mode 100644 index 00000000..7b252ec6 --- /dev/null +++ b/candle-core/benches/benchmarks/conv_transpose2d.rs @@ -0,0 +1,59 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run( + x: &Tensor, + k: &Tensor, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, +) { + x.conv_transpose2d(k, padding, output_padding, stride, dilation) + .unwrap(); +} + +fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let t = Tensor::arange(0.0f32, 10000.0, device) + .unwrap() + .reshape((1, 4, 50, 50)) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let kernel = Tensor::arange(0.0f32, 100.0, device) + .unwrap() + .reshape((4, 1, 5, 5)) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let flops = t.dims().iter().product::() * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&t), black_box(&kernel), 1, 0, 1, 2); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32"); + run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16"); + run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index c45effee..a0ffa3eb 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod affine; +pub(crate) mod conv_transpose2d; pub(crate) mod matmul; pub(crate) mod random; pub(crate) mod where_cond; diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 2125af69..27ffe934 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -98,6 +98,19 @@ pub trait BackendStorage: Sized { ) -> Result; fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + // Similar to cudaMemcpy2D, though values are in elements and not in bytes. 
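// (Editorial illustration, not part of the patch.) The intended semantics mirror the
// CPU helper `copy2d_` added further down: copy d1 rows of d2 contiguous elements,
// where row i starts at src_offset + i * src_stride1 in the source and at
// dst_offset + i * dst_stride1 in the destination. For instance, with d1 = 2, d2 = 3,
// src_stride1 = 5 and dst_stride1 = 4, the elements src[src_offset..src_offset + 3]
// land in dst[dst_offset..dst_offset + 3], and src[src_offset + 5..src_offset + 8]
// land in dst[dst_offset + 4..dst_offset + 7].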
+ fn copy2d( + &self, + _: &mut Self, + _d1: usize, + _d2: usize, + _src_stride1: usize, + _dst_stride1: usize, + _src_offset: usize, + _dst_offset: usize, + ) -> Result<()>; } pub trait BackendDevice: Sized + std::fmt::Debug + Clone { @@ -114,8 +127,16 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result; + /// # Safety + /// This function is unsafe as it doesn't initialize the underlying data store. + /// The caller should ensure that the data is properly initialized as early as possible + /// after this call. + unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result; + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result; + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result; + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 2a1db58a..65d91849 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -1,3 +1,4 @@ +/// Methods for backpropagation of gradients. use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp}; use crate::{Error, Result, Tensor, TensorId}; use std::collections::HashMap; @@ -111,7 +112,8 @@ impl Tensor { } Op::Unary(_node, UnaryOp::Ceil) | Op::Unary(_node, UnaryOp::Floor) - | Op::Unary(_node, UnaryOp::Round) => nodes, + | Op::Unary(_node, UnaryOp::Round) + | Op::Unary(_node, UnaryOp::Sign) => nodes, Op::Reshape(node) | Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. } @@ -310,9 +312,32 @@ impl Tensor { Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported { op: "conv-transpose1d", })?, - Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { - op: "conv-transpose2d", - })?, + Op::ConvTranspose2D { + arg, + kernel, + padding, + stride, + dilation, + output_padding: _output_padding, + } => { + let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad_arg)?; + + let grad_kernel = grad + .transpose(0, 1)? + .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)? + .transpose(0, 1)?; + let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; + *sum_grad = sum_grad.add(&grad_kernel)?; + } Op::AvgPool2D { arg, kernel_size, @@ -464,7 +489,6 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad)?; } - Op::Cmp(_args, _) => {} Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { let node = broadcast_back(arg, node, reduced_dims)?; let grad = broadcast_back(arg, &grad, reduced_dims)?; @@ -554,20 +578,18 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Reduce(_, ReduceOp::ArgMin, _) => {} - Op::Reduce(_, ReduceOp::ArgMax, _) => {} + Op::Unary(_, UnaryOp::Floor) + | Op::Unary(_, UnaryOp::Round) + | Op::Reduce(_, ReduceOp::ArgMin, _) + | Op::Reduce(_, ReduceOp::ArgMax, _) + | Op::Unary(_, UnaryOp::Sign) + | Op::Cmp(_, _) => {} Op::Reshape(arg) => { let arg_grad = grad.reshape(arg.dims())?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? 
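// (Editorial sketch, not part of the patch.) With the conv-transpose2d backward pass
// added above, gradients now flow to both the input and the kernel. A minimal check,
// assuming the usual candle-core Var / backward API:
use candle_core::{Device, Result, Var};

fn grad_through_conv_transpose2d() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::randn(0f32, 1f32, (1, 2, 4, 4), &dev)?;
    let k = Var::randn(0f32, 1f32, (2, 3, 3, 3), &dev)?;
    // kernel, padding, output_padding, stride, dilation
    let y = x.conv_transpose2d(&k, 0, 0, 1, 1)?;
    let loss = y.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    // Both operands should now have an entry in the gradient store.
    assert!(grads.get(&x).is_some());
    assert!(grads.get(&k).is_some());
    Ok(())
}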
} Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, - Op::Unary(_, UnaryOp::Floor) => { - Err(Error::BackwardNotSupported { op: "floor" })? - } - Op::Unary(_, UnaryOp::Round) => { - Err(Error::BackwardNotSupported { op: "round" })? - } Op::Unary(arg, UnaryOp::Gelu) => { let sum_grad = grads.or_insert(arg)?; let cube = arg.powf(3.)?; @@ -690,30 +712,38 @@ impl Tensor { } } +/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation. #[derive(Debug)] pub struct GradStore(HashMap); impl GradStore { + /// Create a new gradient store fn new() -> Self { GradStore(HashMap::new()) } + /// Get the gradient tensor corresponding to the given tensor id pub fn get_id(&self, id: TensorId) -> Option<&Tensor> { self.0.get(&id) } + /// Get the gradient tensor associated with the given tensor pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> { self.0.get(&tensor.id()) } + /// Remove the gradient tensor associated with the given tensor, returning it if it exists pub fn remove(&mut self, tensor: &Tensor) -> Option { self.0.remove(&tensor.id()) } + /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option { self.0.insert(tensor.id(), grad) } + /// Get the gradient tensor associated with the given tensor, or, if it does not exist, + /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> { use std::collections::hash_map::Entry; let grad = match self.0.entry(tensor.id()) { diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend/mod.rs similarity index 85% rename from candle-core/src/cpu_backend.rs rename to candle-core/src/cpu_backend/mod.rs index 8ae39020..6f8250f0 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -4,7 +4,13 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; use half::{bf16, f16}; use rayon::prelude::*; +mod utils; +pub use utils::{ + binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, +}; + const USE_IM2COL_CONV1D: bool = true; +const USE_IM2COL_CONV1D_TR: bool = true; const USE_IM2COL_CONV2D: bool = true; // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + @@ -23,102 +29,6 @@ pub enum CpuStorage { #[derive(Debug, Clone)] pub struct CpuDevice; -pub trait Map1 { - fn f(&self, vs: &[T], layout: &Layout) -> Result>; - - fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { - match vs { - CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)), - CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)), - CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)), - CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)), - CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)), - CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)), - CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)), - } - } -} - -pub trait Map1Any { - fn f) -> CpuStorage>( - &self, - vs: &[T], - layout: &Layout, - wrap: W, - ) -> Result; - - fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { - match vs { - CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?), - CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?), - CpuStorage::I64(vs) => Ok(self.f(vs, layout, 
CpuStorage::I64)?), - CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?), - CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?), - CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?), - CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?), - } - } -} - -type C = CpuStorage; -pub trait Map2 { - const OP: &'static str; - fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; - - fn map( - &self, - v1: &CpuStorage, - l1: &Layout, - v2: &CpuStorage, - l2: &Layout, - ) -> Result { - match (v1, v2) { - (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), - (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), - (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), - (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), - (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), - (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), - _ => Err(Error::DTypeMismatchBinaryOp { - lhs: v1.dtype(), - rhs: v2.dtype(), - op: Self::OP, - } - .bt()), - } - } -} - -pub trait Map2U8 { - const OP: &'static str; - fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; - - fn map( - &self, - v1: &CpuStorage, - l1: &Layout, - v2: &CpuStorage, - l2: &Layout, - ) -> Result { - match (v1, v2) { - (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - _ => Err(Error::DTypeMismatchBinaryOp { - lhs: v1.dtype(), - rhs: v2.dtype(), - op: Self::OP, - } - .bt()), - } - } -} - struct Cmp(CmpOp); impl Map2U8 for Cmp { const OP: &'static str = "cmp"; @@ -365,275 +275,6 @@ impl<'a> Map1 for ReduceSum<'a> { } } -pub fn unary_map U>( - vs: &[T], - layout: &Layout, - mut f: F, -) -> Vec { - match layout.strided_blocks() { - crate::StridedBlocks::SingleBlock { start_offset, len } => vs - [start_offset..start_offset + len] - .iter() - .map(|&v| f(v)) - .collect(), - crate::StridedBlocks::MultipleBlocks { - block_start_index, - block_len, - } => { - let mut result = Vec::with_capacity(layout.shape().elem_count()); - // Specialize the case where block_len is one to avoid the second loop. - if block_len == 1 { - for index in block_start_index { - let v = unsafe { vs.get_unchecked(index) }; - result.push(f(*v)) - } - } else { - for index in block_start_index { - for offset in 0..block_len { - let v = unsafe { vs.get_unchecked(index + offset) }; - result.push(f(*v)) - } - } - } - result - } - } -} - -pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( - vs: &[T], - layout: &Layout, - mut f: F, - mut f_vec: FV, -) -> Vec { - match layout.strided_blocks() { - crate::StridedBlocks::SingleBlock { start_offset, len } => { - let mut ys: Vec = Vec::with_capacity(len); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; - f_vec(&vs[start_offset..start_offset + len], ys_to_set); - // SAFETY: values are all set by f_vec. 
- unsafe { ys.set_len(len) }; - ys - } - crate::StridedBlocks::MultipleBlocks { - block_start_index, - block_len, - } => { - let el_count = layout.shape().elem_count(); - // Specialize the case where block_len is one to avoid the second loop. - if block_len == 1 { - let mut result = Vec::with_capacity(el_count); - for index in block_start_index { - let v = unsafe { vs.get_unchecked(index) }; - result.push(f(*v)) - } - result - } else { - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; - let mut dst_index = 0; - for src_index in block_start_index { - let vs = &vs[src_index..src_index + block_len]; - let ys = &mut ys_to_set[dst_index..dst_index + block_len]; - f_vec(vs, ys); - dst_index += block_len; - } - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(el_count) }; - ys - } - } - } -} - -// This function maps over two strided index sequences. -pub fn binary_map U>( - lhs_l: &Layout, - rhs_l: &Layout, - lhs: &[T], - rhs: &[T], - mut f: F, -) -> Vec { - match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { - (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] - .iter() - .zip(rhs[o_r1..o_r2].iter()) - .map(|(&l, &r)| f(l, r)) - .collect(), - (Some((o_l1, o_l2)), None) => { - // TODO: Maybe we want to avoid going through the layout twice. - match rhs_l.offsets_b() { - Some(ob) => { - let mut i_in_block = 0; - let mut i_right_broadcast = 0; - lhs[o_l1..o_l2] - .iter() - .map(|&l| { - let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; - i_right_broadcast += 1; - if i_right_broadcast >= ob.right_broadcast { - i_in_block += 1; - i_right_broadcast = 0; - } - if i_in_block >= ob.len { - i_in_block = 0 - } - f(l, *r) - }) - .collect() - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } - } - (None, Some((o_r1, o_r2))) => { - // TODO: Maybe we want to avoid going through the layout twice. - match lhs_l.offsets_b() { - Some(ob) => { - let mut i_in_block = 0; - let mut i_right_broadcast = 0; - rhs[o_r1..o_r2] - .iter() - .map(|&r| { - let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; - i_right_broadcast += 1; - if i_right_broadcast >= ob.right_broadcast { - i_in_block += 1; - i_right_broadcast = 0; - } - if i_in_block >= ob.len { - i_in_block = 0 - } - f(*l, r) - }) - .collect() - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } - } - _ => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } -} - -// Similar to binary_map but with vectorized variants. -pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])>( - lhs_l: &Layout, - rhs_l: &Layout, - lhs: &[T], - rhs: &[T], - mut f: F, - mut f_vec: FV, -) -> Vec { - let el_count = lhs_l.shape().elem_count(); - match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { - (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; - f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); - // SAFETY: values are all set by f_vec. 
- unsafe { ys.set_len(el_count) }; - ys - } - (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { - Some(ob) if ob.right_broadcast == 1 => { - let rhs = &rhs[ob.start..ob.start + ob.len]; - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; - let mut dst_i = 0; - for src_i in (o_l1..o_l2).step_by(ob.len) { - f_vec( - &lhs[src_i..src_i + ob.len], - rhs, - &mut ys_to_set[dst_i..dst_i + ob.len], - ); - dst_i += ob.len; - } - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(el_count) }; - ys - } - Some(ob) => { - let rhs = &rhs[ob.start..ob.start + ob.len]; - let mut ys = lhs[o_l1..o_l2].to_vec(); - for idx_l in 0..ob.left_broadcast { - let start = idx_l * ob.len * ob.right_broadcast; - for (i, &r) in rhs.iter().enumerate() { - let start = start + i * ob.right_broadcast; - for v in ys[start..start + ob.right_broadcast].iter_mut() { - *v = f(*v, r) - } - } - } - ys - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - }, - (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { - Some(ob) if ob.right_broadcast == 1 => { - let lhs = &lhs[ob.start..ob.start + ob.len]; - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; - let mut dst_i = 0; - for src_i in (o_r1..o_r2).step_by(ob.len) { - f_vec( - lhs, - &rhs[src_i..src_i + ob.len], - &mut ys_to_set[dst_i..dst_i + ob.len], - ); - dst_i += ob.len; - } - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(el_count) }; - ys - } - Some(ob) => { - let lhs = &lhs[ob.start..ob.start + ob.len]; - let mut ys = rhs[o_r1..o_r2].to_vec(); - for idx_l in 0..ob.left_broadcast { - let start = idx_l * ob.len * ob.right_broadcast; - for (i, &l) in lhs.iter().enumerate() { - let start = start + i * ob.right_broadcast; - for v in ys[start..start + ob.right_broadcast].iter_mut() { - *v = f(l, *v) - } - } - } - ys - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - }, - _ => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } -} - struct Affine(f64, f64); impl Map1 for Affine { @@ -1022,6 +663,26 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { } } +#[allow(clippy::too_many_arguments)] +fn copy2d_( + src: &[T], + dst: &mut [T], + d1: usize, + d2: usize, + src_stride1: usize, + dst_stride1: usize, + src_offset: usize, + dst_offset: usize, +) { + for i1 in 0..d1 { + let dst_idx = i1 * dst_stride1 + dst_offset; + let src_idx = i1 * src_stride1 + src_offset; + let dst = &mut dst[dst_idx..dst_idx + d2]; + let src = &src[src_idx..src_idx + d2]; + dst.copy_from_slice(src) + } +} + fn copy_strided_src_(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) { match src_l.strided_blocks() { crate::StridedBlocks::SingleBlock { start_offset, len } => { @@ -1256,6 +917,34 @@ impl Map1 for Im2Col { } } +struct Col2Im1D { + stride: usize, +} + +impl Map1 for Col2Im1D { + fn f(&self, col: &[T], l: &Layout) -> Result> { + let (b_size, l_in, c_out, k_size) = l.shape().dims4()?; + let stride = self.stride; + let l_out = (l_in - 1) * stride + k_size; + let mut im = vec![T::zero(); b_size * c_out * l_out]; + let (dst_s0, dst_s1) = (c_out * l_out, l_out); + let (src_s0, src_s1, src_s2) = 
(c_out * k_size * l_in, c_out * k_size, k_size); + for l_in_i in 0..l_in { + for k_i in 0..k_size { + let l_out_i = l_in_i * stride + k_i; + for b_i in 0..b_size { + for c_i in 0..c_out { + let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i; + let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i; + im[dst_idx] += col[src_idx] + } + } + } + } + Ok(im) + } +} + struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); impl<'a> Map2 for ConvTranspose1D<'a> { @@ -1515,6 +1204,30 @@ impl MatMul { })) .bt() } + + fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> { + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let (_b, m, n, k) = self.0; + let a_skip: usize = match lhs_stride[..rank - 2] { + [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, + [_, stride] if lhs_l.dims()[0] == 1 => stride, + [stride, _] if lhs_l.dims()[1] == 1 => stride, + [stride] => stride, + [] => m * k, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, + }; + let b_skip: usize = match rhs_stride[..rank - 2] { + [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, + [_, stride] if rhs_l.dims()[0] == 1 => stride, + [stride, _] if rhs_l.dims()[1] == 1 => stride, + [stride] => stride, + [] => n * k, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, + }; + Ok((a_skip, b_skip)) + } } impl Map2 for MatMul { @@ -1548,18 +1261,7 @@ impl Map2 for MatMul { let rhs_cs = rhs_stride[rank - 1]; let rhs_rs = rhs_stride[rank - 2]; - let a_skip: usize = match lhs_stride[..rank - 2] { - [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, - [stride] => stride, - [] => m * k, - _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, - }; - let b_skip: usize = match rhs_stride[..rank - 2] { - [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, - [stride] => stride, - [] => n * k, - _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, - }; + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; let c_skip: usize = m * n; let dst_shape: Shape = (m, n).into(); @@ -1619,20 +1321,8 @@ impl Map2 for MatMul { let lhs_stride = lhs_l.stride(); let rhs_stride = rhs_l.stride(); - let rank = lhs_stride.len(); - let a_skip: usize = match lhs_stride[..rank - 2] { - [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, - [stride] => stride, - [] => m * k, - _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, - }; - let b_skip: usize = match rhs_stride[..rank - 2] { - [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, - [stride] => stride, - [] => n * k, - _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, - }; + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; let c_skip: usize = m * n; let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; @@ -1640,7 +1330,7 @@ impl Map2 for MatMul { let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, b'N') } else if rhs_m1 == k && rhs_m2 == 1 { (k as i32, b'T') @@ -1648,7 +1338,7 @@ impl Map2 for MatMul { Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? 
}; // The b tensor has dims batching, m, k (lhs) - let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, b'N') } else if lhs_m1 == m && lhs_m2 == 1 { (m as i32, b'T') @@ -1722,20 +1412,8 @@ impl Map2 for MatMul { let lhs_stride = lhs_l.stride(); let rhs_stride = rhs_l.stride(); - let rank = lhs_stride.len(); - let a_skip: usize = match lhs_stride[..rank - 2] { - [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, - [stride] => stride, - [] => m * k, - _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, - }; - let b_skip: usize = match rhs_stride[..rank - 2] { - [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, - [stride] => stride, - [] => n * k, - _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, - }; + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; let c_skip: usize = m * n; let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; @@ -1743,7 +1421,7 @@ impl Map2 for MatMul { let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, b'N') } else if rhs_m1 == k && rhs_m2 == 1 { (k as i32, b'T') @@ -1751,7 +1429,7 @@ impl Map2 for MatMul { Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? }; // The b tensor has dims batching, m, k (lhs) - let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, b'N') } else if lhs_m1 == m && lhs_m2 == 1 { (m as i32, b'T') @@ -2423,6 +2101,48 @@ impl BackendStorage for CpuStorage { } } + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + match (self, dst) { + (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::U32(src), Self::U32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I64(src), Self::I64(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::BF16(src), Self::BF16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F16(src), Self::F16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F32(src), Self::F32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F64(src), Self::F64(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (_, dst) => { + return Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: dst.dtype(), + op: "copy2d", + } + .bt()); + } + } + Ok(()) + } + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -2491,7 +2211,10 @@ impl BackendStorage for CpuStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. - let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? 
@@ -2499,7 +2222,7 @@ impl BackendStorage for CpuStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; - let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } @@ -2511,7 +2234,52 @@ impl BackendStorage for CpuStorage { kernel_l: &Layout, params: &crate::conv::ParamsConvTranspose1D, ) -> Result { - ConvTranspose1D(params).map(self, l, kernel, kernel_l) + let can_use_col2im = kernel_l.is_contiguous() + && params.dilation == 1 + && params.padding == 0 + && params.output_padding == 0; + if USE_IM2COL_CONV1D_TR && can_use_col2im { + let (b_size, c_in, l_in) = l.shape().dims3()?; + let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?; + if !kernel_l.is_contiguous() { + crate::bail!( + "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}" + ) + } + if c_in != c_in2 { + crate::bail!( + "convtr1d: shape mismatch on c_in {:?} {:?}", + l.shape(), + kernel_l.shape() + ) + } + let col = { + // This merges the last two dimensions of the kernel together. + let kernel_l_mm = Layout::new( + (b_size, c_in, k_size * c_out).into(), + vec![0, k_size * c_out, 1], + kernel_l.start_offset(), + ); + self.matmul( + kernel, + ( + b_size, + /* m */ l_in, + /* n */ c_out * k_size, + /* k */ c_in, + ), + &l.transpose(1, 2)?, + &kernel_l_mm, + )? + }; + let col_l = Layout::contiguous((b_size, l_in, c_out, k_size)); + Col2Im1D { + stride: params.stride, + } + .map(&col, &col_l) + } else { + ConvTranspose1D(params).map(self, l, kernel, kernel_l) + } } fn conv2d( @@ -2545,7 +2313,10 @@ impl BackendStorage for CpuStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. - let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? @@ -2555,7 +2326,7 @@ impl BackendStorage for CpuStorage { let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) .transpose(1, 2)? .transpose(1, 3)?; - let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } @@ -2678,6 +2449,10 @@ impl BackendDevice for CpuDevice { Ok(s.clone()) } + fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result { + Ok(s) + } + fn new(_: usize) -> Result { Ok(Self) } @@ -2779,6 +2554,53 @@ impl BackendDevice for CpuDevice { } } + #[allow(clippy::uninit_vec)] + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + // The code below is highly unsafe but hopefully not directly unsound as we only consider + // types that are Copy, not Drop, and for which all bit patterns are proper values. 
+ // It's still pretty risky, see the following for more details: + // https://github.com/rust-lang/rust-clippy/issues/4483 + let storage = match dtype { + DType::U8 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::U8(v) + } + DType::U32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::U32(v) + } + DType::I64 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I64(v) + } + DType::BF16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::BF16(v) + } + DType::F16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F16(v) + } + DType::F32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F32(v) + } + DType::F64 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F64(v) + } + }; + Ok(storage) + } + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { let elem_count = shape.elem_count(); let storage = match dtype { diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs new file mode 100644 index 00000000..af25a2af --- /dev/null +++ b/candle-core/src/cpu_backend/utils.rs @@ -0,0 +1,350 @@ +/// Helper functions to write CPU kernels. +use crate::backend::BackendStorage; +use crate::{Error, Layout, Result, WithDType}; + +type C = super::CpuStorage; +pub trait Map1 { + fn f(&self, vs: &[T], layout: &Layout) -> Result>; + + fn map(&self, vs: &C, layout: &Layout) -> Result { + match vs { + C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), + C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), + C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), + C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), + C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), + C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + } + } +} + +pub trait Map1Any { + fn f) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result; + + fn map(&self, vs: &C, layout: &Layout) -> Result { + match vs { + C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), + C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), + C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), + C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), + C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), + C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + } + } +} + +pub trait Map2 { + const OP: &'static str; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; + + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2U8 { + const OP: &'static str; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; + + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + 
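// (Editorial sketch, not part of the patch.) A typical Map1 implementor dispatches one
// generic function over every CpuStorage dtype; assuming the trait's method signature is
// `fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>`, a
// scale-by-constant kernel follows the same pattern as the existing Affine op:
struct Scale(f64);

impl Map1 for Scale {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let m = T::from_f64(self.0);
        // unary_map walks the (possibly strided) layout and applies the closure element-wise.
        Ok(unary_map(vs, layout, |v| v * m))
    }
}
// `Scale(2.0).map(&storage, &layout)` then yields a new CpuStorage of the same dtype.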
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub fn binary_map U>( + lhs_l: &Layout, + rhs_l: &Layout, + lhs: &[T], + rhs: &[T], + mut f: F, +) -> Vec { + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] + .iter() + .zip(rhs[o_r1..o_r2].iter()) + .map(|(&l, &r)| f(l, r)) + .collect(), + (Some((o_l1, o_l2)), None) => { + // TODO: Maybe we want to avoid going through the layout twice. + match rhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + lhs[o_l1..o_l2] + .iter() + .map(|&l| { + let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(l, *r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + (None, Some((o_r1, o_r2))) => { + // TODO: Maybe we want to avoid going through the layout twice. + match lhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + rhs[o_r1..o_r2] + .iter() + .map(|&r| { + let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(*l, r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } +} + +// Similar to binary_map but with vectorized variants. +pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])>( + lhs_l: &Layout, + rhs_l: &Layout, + lhs: &[T], + rhs: &[T], + mut f: F, + mut f_vec: FV, +) -> Vec { + let el_count = lhs_l.shape().elem_count(); + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { + Some(ob) if ob.right_broadcast == 1 => { + let rhs = &rhs[ob.start..ob.start + ob.len]; + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let mut dst_i = 0; + for src_i in (o_l1..o_l2).step_by(ob.len) { + f_vec( + &lhs[src_i..src_i + ob.len], + rhs, + &mut ys_to_set[dst_i..dst_i + ob.len], + ); + dst_i += ob.len; + } + // SAFETY: values are all set by f_vec. 
+ unsafe { ys.set_len(el_count) }; + ys + } + Some(ob) => { + let rhs = &rhs[ob.start..ob.start + ob.len]; + let mut ys = lhs[o_l1..o_l2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &r) in rhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(*v, r) + } + } + } + ys + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + }, + (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { + Some(ob) if ob.right_broadcast == 1 => { + let lhs = &lhs[ob.start..ob.start + ob.len]; + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let mut dst_i = 0; + for src_i in (o_r1..o_r2).step_by(ob.len) { + f_vec( + lhs, + &rhs[src_i..src_i + ob.len], + &mut ys_to_set[dst_i..dst_i + ob.len], + ); + dst_i += ob.len; + } + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + Some(ob) => { + let lhs = &lhs[ob.start..ob.start + ob.len]; + let mut ys = rhs[o_r1..o_r2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &l) in lhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(l, *v) + } + } + } + ys + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + }, + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } +} + +pub fn unary_map U>( + vs: &[T], + layout: &Layout, + mut f: F, +) -> Vec { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => vs + [start_offset..start_offset + len] + .iter() + .map(|&v| f(v)) + .collect(), + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let mut result = Vec::with_capacity(layout.shape().elem_count()); + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + } else { + for index in block_start_index { + for offset in 0..block_len { + let v = unsafe { vs.get_unchecked(index + offset) }; + result.push(f(*v)) + } + } + } + result + } + } +} + +pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( + vs: &[T], + layout: &Layout, + mut f: F, + mut f_vec: FV, +) -> Vec { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + let mut ys: Vec = Vec::with_capacity(len); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + f_vec(&vs[start_offset..start_offset + len], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(len) }; + ys + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let el_count = layout.shape().elem_count(); + // Specialize the case where block_len is one to avoid the second loop. 
+ if block_len == 1 { + let mut result = Vec::with_capacity(el_count); + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + result + } else { + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + let mut dst_index = 0; + for src_index in block_start_index { + let vs = &vs[src_index..src_index + block_len]; + let ys = &mut ys_to_set[dst_index..dst_index + block_len]; + f_vec(vs, ys); + dst_index += block_len; + } + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + } + } +} diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs similarity index 100% rename from candle-core/src/cudnn.rs rename to candle-core/src/cuda_backend/cudnn.rs diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs new file mode 100644 index 00000000..0859d756 --- /dev/null +++ b/candle-core/src/cuda_backend/device.rs @@ -0,0 +1,410 @@ +use crate::backend::BackendDevice; +use crate::{CpuStorage, DType, Layout, Result, Shape}; +pub use candle_kernels as kernels; +pub use cudarc; +use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use half::{bf16, f16}; +use std::sync::{Arc, Mutex}; + +use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; + +/// Unique identifier for cuda devices. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +struct CudaRng(cudarc::curand::CudaRng); +unsafe impl Send for CudaRng {} + +#[derive(Clone)] +pub struct CudaDevice { + id: DeviceId, + device: Arc, + pub(crate) blas: Arc, + curand: Arc>, +} + +impl std::fmt::Debug for CudaDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CudaDevice({:?})", self.id) + } +} + +impl std::ops::Deref for CudaDevice { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl CudaDevice { + pub fn cuda_device(&self) -> Arc { + self.device.clone() + } + + pub fn id(&self) -> DeviceId { + self.id + } + + fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let slice = match dtype { + DType::U8 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_u8", kernels::FILL)?; + let params = (&data, v as u8, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_u32", kernels::FILL)?; + let params = (&data, v as u32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + // SAFETY: Set later by running the fill kernel. 
+ let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i64", kernels::FILL)?; + let params = (&data, v as i64, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; + let params = (&data, bf16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f16", kernels::FILL)?; + let params = (&data, f16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f32", kernels::FILL)?; + let params = (&data, v as f32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f64", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result { + if !self.has_func(module_name, module_name) { + // Leaking the string here is a bit sad but we need a &'static str and this is only + // done once per kernel name. + let static_module_name = Box::leak(module_name.to_string().into_boxed_str()); + self.load_ptx(ptx.into(), module_name, &[static_module_name]) + .map_err(|cuda| CudaError::Load { + cuda, + module_name: module_name.to_string(), + }) + .w()?; + } + self.get_func(module_name, module_name) + // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is + // able to only build the error value if needed. + .ok_or(CudaError::MissingKernel { + module_name: module_name.to_string(), + }) + .w() + } +} + +impl BackendDevice for CudaDevice { + type Storage = CudaStorage; + + fn new(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new(ordinal).w()?; + let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; + let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; + Ok(Self { + id: DeviceId::new(), + device, + blas: Arc::new(blas), + curand: Arc::new(Mutex::new(CudaRng(curand))), + }) + } + + fn set_seed(&self, seed: u64) -> Result<()> { + // We do not call set_seed but instead create a new curand object. This ensures that the + // state will be identical and the same random numbers will be generated. 
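// (Editorial note, not part of the patch.) In practice: after `device.set_seed(42)?`, a
// draw such as `Tensor::randn(0f32, 1f32, (2, 3), &device)?` is reproducible across runs,
// and calling `set_seed(42)` again before a second draw of the same shape and dtype
// replays the exact same values, since the curand state is rebuilt from scratch each time.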
+ let mut curand = self.curand.lock().unwrap(); + curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; + Ok(()) + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Cuda { + gpu_id: self.device.ordinal(), + } + } + + fn same_device(&self, rhs: &Self) -> bool { + self.id == rhs.id + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result { + let elem_count = shape.elem_count(); + let curand = self.curand.lock().unwrap(); + let slice = match dtype { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()? + } + DType::F32 => { + let mut data = unsafe { self.alloc::(elem_count) }.w()?; + curand.0.fill_with_uniform(&mut data).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let mut data = unsafe { self.alloc::(elem_count) }.w()?; + curand.0.fill_with_uniform(&mut data).w()?; + CudaStorageSlice::F64(data) + } + }; + let slice = if lo == 0. && up == 1.0 { + slice + } else { + use super::utils::Map1; + let layout = Layout::contiguous(shape); + super::Affine(up - lo, lo).map(&slice, self, &layout)? + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. + let elem_count = shape.elem_count(); + let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. + // https://github.com/huggingface/candle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; + let slice = match dtype { + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()? 
+ } + DType::F32 => { + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; + curand + .0 + .fill_with_normal(&mut data, mean as f32, std as f32) + .w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; + curand.0.fill_with_normal(&mut data, mean, std).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + self.const_impl(1., shape, dtype) + } + + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } +} diff --git a/candle-core/src/cuda_backend/error.rs b/candle-core/src/cuda_backend/error.rs new file mode 100644 index 00000000..bd6f8ac6 --- /dev/null +++ b/candle-core/src/cuda_backend/error.rs @@ -0,0 +1,62 @@ +use 
crate::{DType, Layout};
+
+/// cudarc related errors
+#[derive(thiserror::Error, Debug)]
+pub enum CudaError {
+    #[error(transparent)]
+    Cuda(#[from] cudarc::driver::DriverError),
+
+    #[error(transparent)]
+    Compiler(#[from] cudarc::nvrtc::CompileError),
+
+    #[error(transparent)]
+    Cublas(#[from] cudarc::cublas::result::CublasError),
+
+    #[error(transparent)]
+    Curand(#[from] cudarc::curand::result::CurandError),
+
+    #[error("missing kernel '{module_name}'")]
+    MissingKernel { module_name: String },
+
+    #[error("unsupported dtype {dtype:?} for {op}")]
+    UnsupportedDtype { dtype: DType, op: &'static str },
+
+    #[error("internal error '{0}'")]
+    InternalError(&'static str),
+
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Layout,
+        rhs_stride: Layout,
+        mnk: (usize, usize, usize),
+    },
+
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },
+
+    #[error("{cuda} when loading {module_name}")]
+    Load {
+        cuda: cudarc::driver::DriverError,
+        module_name: String,
+    },
+}
+
+impl From<CudaError> for crate::Error {
+    fn from(val: CudaError) -> Self {
+        crate::Error::Cuda(Box::new(val)).bt()
+    }
+}
+
+pub trait WrapErr<O> {
+    fn w(self) -> std::result::Result<O, crate::Error>;
+}
+
+impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
+    fn w(self) -> std::result::Result<O, crate::Error> {
+        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
+    }
+}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend/mod.rs
similarity index 75%
rename from candle-core/src/cuda_backend.rs
rename to candle-core/src/cuda_backend/mod.rs
index b7756fa6..6fecf7c7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -5,395 +5,41 @@ pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
-    CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
-    ValidAsZeroBits,
+    CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
 };
 use half::{bf16, f16};
-use std::sync::{Arc, Mutex};

-/// cudarc related errors
-#[derive(thiserror::Error, Debug)]
-pub enum CudaError {
-    #[error(transparent)]
-    Cuda(#[from] cudarc::driver::DriverError),
+#[cfg(feature = "cudnn")]
+pub mod cudnn;
+mod device;
+mod error;
+mod utils;
+pub use device::{CudaDevice, DeviceId};
+pub use error::{CudaError, WrapErr};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};

-    #[error(transparent)]
-    Compiler(#[from] cudarc::nvrtc::CompileError),
-
-    #[error(transparent)]
-    Cublas(#[from] cudarc::cublas::result::CublasError),
-
-    #[error(transparent)]
-    Curand(#[from] cudarc::curand::result::CurandError),
-
-    #[error("missing kernel '{module_name}'")]
-    MissingKernel { module_name: String },
-
-    #[error("unsupported dtype {dtype:?} for {op}")]
-    UnsupportedDtype { dtype: DType, op: &'static str },
-
-    #[error("internal error '{0}'")]
-    InternalError(&'static str),
-
-    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
-    MatMulNonContiguous {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
-        mnk: (usize, usize, usize),
-    },
-
-    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
-    UnexpectedDType {
-        msg: &'static str,
-        expected: DType,
-        got: DType,
-    },
-
-    #[error("{cuda} when loading {module_name}")]
-    Load {
-        cuda:
cudarc::driver::DriverError, - module_name: String, - }, +enum SlicePtrOrNull { + Ptr(CudaSlice), + Null, } -impl From for crate::Error { - fn from(val: CudaError) -> Self { - crate::Error::Cuda(Box::new(val)).bt() - } -} - -/// Unique identifier for cuda devices. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct DeviceId(usize); - -impl DeviceId { - fn new() -> Self { - // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 - use std::sync::atomic; - static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); - Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) - } -} - -struct CudaRng(cudarc::curand::CudaRng); -unsafe impl Send for CudaRng {} - -#[derive(Clone)] -pub struct CudaDevice { - id: DeviceId, - device: Arc, - blas: Arc, - curand: Arc>, -} - -impl std::fmt::Debug for CudaDevice { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "CudaDevice({:?})", self.id) - } -} - -impl std::ops::Deref for CudaDevice { - type Target = Arc; - - fn deref(&self) -> &Self::Target { - &self.device - } -} - -pub trait WrapErr { - fn w(self) -> std::result::Result; -} - -impl> WrapErr for std::result::Result { - fn w(self) -> std::result::Result { - self.map_err(|e| crate::Error::Cuda(Box::new(e.into()))) - } -} - -impl CudaDevice { - pub fn cuda_device(&self) -> Arc { - self.device.clone() - } - - pub fn id(&self) -> DeviceId { - self.id - } - - fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let slice = match dtype { - DType::U8 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_u8", kernels::FILL)?; - let params = (&data, v as u8, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::U8(data) - } - DType::U32 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_u32", kernels::FILL)?; - let params = (&data, v as u32, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::U32(data) - } - DType::I64 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_i64", kernels::FILL)?; - let params = (&data, v as i64, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::I64(data) - } - DType::BF16 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; - let params = (&data, bf16::from_f64(v), elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::BF16(data) - } - DType::F16 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_f16", kernels::FILL)?; - let params = (&data, f16::from_f64(v), elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F16(data) - } - DType::F32 => { - // SAFETY: Set later by running the fill kernel. 
- let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_f32", kernels::FILL)?; - let params = (&data, v as f32, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_f64", kernels::FILL)?; - let params = (&data, v, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result { - if !self.has_func(module_name, module_name) { - // Leaking the string here is a bit sad but we need a &'static str and this is only - // done once per kernel name. - let static_module_name = Box::leak(module_name.to_string().into_boxed_str()); - self.load_ptx(ptx.into(), module_name, &[static_module_name]) - .map_err(|cuda| CudaError::Load { - cuda, - module_name: module_name.to_string(), - }) - .w()?; - } - self.get_func(module_name, module_name) - // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is - // able to only build the error value if needed. - .ok_or(CudaError::MissingKernel { - module_name: module_name.to_string(), - }) - .w() - } -} - -impl BackendDevice for CudaDevice { - type Storage = CudaStorage; - - fn new(ordinal: usize) -> Result { - let device = cudarc::driver::CudaDevice::new(ordinal).w()?; - let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; - let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; - Ok(Self { - id: DeviceId::new(), - device, - blas: Arc::new(blas), - curand: Arc::new(Mutex::new(CudaRng(curand))), - }) - } - - fn set_seed(&self, seed: u64) -> Result<()> { - // We do not call set_seed but instead create a new curand object. This ensures that the - // state will be identical and the same random numbers will be generated. 
- let mut curand = self.curand.lock().unwrap(); - curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; - Ok(()) - } - - fn location(&self) -> crate::DeviceLocation { - crate::DeviceLocation::Cuda { - gpu_id: self.device.ordinal(), +unsafe impl DeviceRepr for &SlicePtrOrNull { + fn as_kernel_param(&self) -> *mut std::ffi::c_void { + match self { + SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(), + SlicePtrOrNull::Null => 0usize.as_kernel_param(), } } +} - fn same_device(&self, rhs: &Self) -> bool { - self.id == rhs.id - } - - fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let elem_count = shape.elem_count(); - let slice = match dtype { - DType::U8 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::U8(data) - } - DType::U32 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::U32(data) - } - DType::I64 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::I64(data) - } - DType::BF16 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::BF16(data) - } - DType::F16 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::F16(data) - } - DType::F32 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result { - let elem_count = shape.elem_count(); - let curand = self.curand.lock().unwrap(); - let slice = match dtype { - // TODO: Add support for F16 and BF16 though this is likely to require some upstream - // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_uniform", - }) - .w()? - } - DType::F32 => { - let mut data = unsafe { self.alloc::(elem_count) }.w()?; - curand.0.fill_with_uniform(&mut data).w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let mut data = unsafe { self.alloc::(elem_count) }.w()?; - curand.0.fill_with_uniform(&mut data).w()?; - CudaStorageSlice::F64(data) - } - }; - let slice = if lo == 0. && up == 1.0 { - slice +impl SlicePtrOrNull { + fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result { + let ds = if l.is_contiguous() { + SlicePtrOrNull::Null } else { - let layout = Layout::contiguous(shape); - Affine(up - lo, lo).map(&slice, self, &layout)? + SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?) }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result { - // TODO: Add support for F16 and BF16 though this is likely to require some upstream - // cudarc changes. - let elem_count = shape.elem_count(); - let curand = self.curand.lock().unwrap(); - // curand can only generate an odd number of values. - // https://github.com/huggingface/candle/issues/734 - let elem_count_round = if elem_count % 2 == 1 { - elem_count + 1 - } else { - elem_count - }; - let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? 
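The set_seed comment above is the observable contract: re-seeding rebuilds the curand generator rather than mutating it, so the same seed replays the same sequence of draws. A small sanity-check sketch using the public API (the seeding_is_reproducible helper and the seed value are illustrative):

    use candle_core::{Device, Result, Tensor};

    fn seeding_is_reproducible(device: &Device) -> Result<bool> {
        device.set_seed(42)?;
        let a = Tensor::randn(0f32, 1f32, (4, 4), device)?;
        device.set_seed(42)?;
        let b = Tensor::randn(0f32, 1f32, (4, 4), device)?;
        // Identical seeds should produce bit-identical draws.
        let diff = (&a - &b)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        Ok(diff == 0.0)
    }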
- } - DType::F32 => { - let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; - curand - .0 - .fill_with_normal(&mut data, mean as f32, std as f32) - .w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; - curand.0.fill_with_normal(&mut data, mean, std).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - self.const_impl(1., shape, dtype) - } - - fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { - let slice = match storage { - CpuStorage::U8(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::U8(data) - } - CpuStorage::U32(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::U32(data) - } - CpuStorage::I64(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::I64(data) - } - CpuStorage::BF16(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::BF16(data) - } - CpuStorage::F16(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::F16(data) - } - CpuStorage::F32(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::F32(data) - } - CpuStorage::F64(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) + Ok(ds) } } @@ -407,133 +53,6 @@ pub enum CudaStorageSlice { F32(CudaSlice), F64(CudaSlice), } -type S = CudaStorageSlice; - -pub trait Map1 { - fn f( - &self, - src: &CudaSlice, - dev: &CudaDevice, - layout: &Layout, - ) -> Result>; - - fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { - let out = match s { - S::U8(s) => S::U8(self.f(s, d, l)?), - S::U32(s) => S::U32(self.f(s, d, l)?), - S::I64(s) => S::I64(self.f(s, d, l)?), - S::BF16(s) => S::BF16(self.f(s, d, l)?), - S::F16(s) => S::F16(self.f(s, d, l)?), - S::F32(s) => S::F32(self.f(s, d, l)?), - S::F64(s) => S::F64(self.f(s, d, l)?), - }; - Ok(out) - } -} - -pub trait Map2 { - fn f( - &self, - src1: &CudaSlice, - layout1: &Layout, - src2: &CudaSlice, - layout2: &Layout, - dev: &CudaDevice, - ) -> Result>; - - fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { - let out = match (s1, s2) { - (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), - (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), - (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), - (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), - (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), - (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), - (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), - _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, - }; - Ok(out) - } -} - -pub trait Map2InPlace { - fn f( - &self, - dst: &mut CudaSlice, - dst_shape: &Shape, - src: &CudaSlice, - src_l: &Layout, - dev: &CudaDevice, - ) -> Result<()>; - - fn map( - &self, - dst: &mut S, - dst_s: &Shape, - src: &S, - src_l: &Layout, - d: &CudaDevice, - ) -> Result<()> { - match (dst, src) { - (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d), - (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d), - (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d), - (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F16(dst), 
S::F16(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), - _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, - } - } -} - -pub trait Map1Any { - fn f) -> S>( - &self, - src: &CudaSlice, - dev: &CudaDevice, - layout: &Layout, - wrap: W, - ) -> Result; - - fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { - let out = match s { - S::U8(s) => self.f(s, d, l, S::U8)?, - S::U32(s) => self.f(s, d, l, S::U32)?, - S::I64(s) => self.f(s, d, l, S::I64)?, - S::BF16(s) => self.f(s, d, l, S::BF16)?, - S::F16(s) => self.f(s, d, l, S::F16)?, - S::F32(s) => self.f(s, d, l, S::F32)?, - S::F64(s) => self.f(s, d, l, S::F64)?, - }; - Ok(out) - } -} - -pub trait Map2Any { - fn f( - &self, - src1: &CudaSlice, - layout1: &Layout, - src2: &CudaSlice, - layout2: &Layout, - dev: &CudaDevice, - ) -> Result; - - fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { - let out = match (s1, s2) { - (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, - _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, - }; - Ok(out) - } -} struct Clone; impl Map1 for Clone { @@ -564,7 +83,7 @@ impl Map1 for Affine { let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); - let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("affine"), kernels::AFFINE)?; // SAFETY: Set later by running the kernel. @@ -596,7 +115,7 @@ impl Map1 for Elu { let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); - let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("uelu"), kernels::UNARY)?; // SAFETY: Set later by running the kernel. @@ -719,7 +238,7 @@ impl Map1 for Powf { let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); - let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("upowf"), kernels::UNARY)?; // SAFETY: Set later by running the kernel. @@ -852,7 +371,7 @@ impl Map1 for U { let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); - let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::UNARY)?; // SAFETY: Set later by running the kernel. 
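The change repeated across the hunks above is the point of SlicePtrOrNull: for contiguous inputs the kernel no longer needs a dims-and-strides buffer, so the htod_copy is skipped and a null pointer is passed as the kernel argument instead. A device-free sketch of that decision (the strided_info helper is illustrative, not the actual launch code):

    /// Only build (and upload) the dims+strides table when the layout is strided;
    /// a contiguous layout maps to SlicePtrOrNull::Null, i.e. a null kernel argument.
    fn strided_info(dims: &[usize], stride: &[usize], is_contiguous: bool) -> Option<Vec<usize>> {
        if is_contiguous {
            None
        } else {
            Some([dims, stride].concat())
        }
    }

    fn main() {
        assert_eq!(strided_info(&[2, 3], &[3, 1], true), None);
        assert_eq!(strided_info(&[2, 3], &[1, 2], false), Some(vec![2, 3, 1, 2]));
    }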
@@ -1402,9 +921,14 @@ impl Map2 for U { let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dims_and_strides = dev - .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) - .w()?; + let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + SlicePtrOrNull::Null + } else { + SlicePtrOrNull::Ptr( + dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) + .w()?, + ) + }; let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::BINARY)?; @@ -1431,9 +955,14 @@ impl Map2Any for Cmp { let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dims_and_strides = dev - .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) - .w()?; + let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + SlicePtrOrNull::Null + } else { + SlicePtrOrNull::Ptr( + dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) + .w()?, + ) + }; let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let name = match self.0 { @@ -1541,26 +1070,30 @@ fn gemm_config( let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; // The a tensor has dims batching, k, n (rhs) - let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n { + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, cublasOperation_t::CUBLAS_OP_N) - } else if rhs_m1 == k && rhs_m2 == 1 { + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { (k as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })? }; // The b tensor has dims batching, m, k (lhs) - let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, cublasOperation_t::CUBLAS_OP_N) - } else if lhs_m1 == m && lhs_m2 == 1 { + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { (m as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })? 
}; @@ -1581,21 +1114,25 @@ fn gemm_config( let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, + [_, stride] if lhs_l.dims()[0] == 1 => stride, + [stride, _] if lhs_l.dims()[1] == 1 => stride, [stride] => stride, [] => m * k, _ => Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })?, }; let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, + [_, stride] if rhs_l.dims()[0] == 1 => stride, + [stride, _] if rhs_l.dims()[1] == 1 => stride, [stride] => stride, [] => n * k, _ => Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })?, }; @@ -1640,7 +1177,7 @@ impl BackendStorage for CudaStorage { let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let start_o = layout.start_offset(); // This returns an i64 rather than a &i64, this is useful to get around some temporary // lifetime issue and is safe as long as self.slice does not go out of scope before inp @@ -1844,7 +1381,10 @@ impl BackendStorage for CudaStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. - let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? @@ -1852,7 +1392,7 @@ impl BackendStorage for CudaStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; - let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } @@ -1909,7 +1449,10 @@ impl BackendStorage for CudaStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. - let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? @@ -1919,7 +1462,7 @@ impl BackendStorage for CudaStorage { let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? .transpose(1, 3)?; - let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } @@ -2056,7 +1599,7 @@ impl BackendStorage for CudaStorage { dim: usize, ) -> Result { let device = self.device().clone(); - let mut acc = device.zeros_impl(l.shape(), self.dtype())?; + let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? 
}; self.copy_strided_src(&mut acc, 0, l)?; ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; Ok(acc) @@ -2071,7 +1614,7 @@ impl BackendStorage for CudaStorage { dim: usize, ) -> Result { let device = self.device().clone(); - let mut acc = device.zeros_impl(l.shape(), self.dtype())?; + let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; self.copy_strided_src(&mut acc, 0, l)?; IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; Ok(acc) @@ -2145,6 +1688,72 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + let dev = &self.device; + let d1 = d1 as u32; + let d2 = d2 as u32; + // Nothing to copy so we exit early to avoid launching a kernel and some potential invalid + // argument with a null pointer. + if d1 == 0 || d2 == 0 { + return Ok(()); + } + let dst_s = dst_s as u32; + let src_s = src_s as u32; + let (src, dst, kname) = match (&self.slice, &mut dst.slice) { + (S::U8(s), S::U8(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_u8", + ), + (S::U32(s), S::U32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_u32", + ), + (S::I64(s), S::I64(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i64", + ), + (S::BF16(s), S::BF16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_bf16", + ), + (S::F16(s), S::F16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f16", + ), + (S::F32(s), S::F32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f32", + ), + (S::F64(s), S::F64(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f64", + ), + _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, + }; + let func = dev.get_or_load_func(kname, kernels::FILL)?; + let cfg = LaunchConfig::for_num_elems(d1 * d2); + let params = (src, dst, d1, d2, src_s, dst_s); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(()) + } + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let dims = src_shape.dims(); @@ -2154,7 +1763,7 @@ impl BackendStorage for CudaStorage { } let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; - let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?; + let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?; match (&self.slice, &mut dst.slice) { (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs new file mode 100644 index 00000000..8dd5be77 --- /dev/null +++ b/candle-core/src/cuda_backend/utils.rs @@ -0,0 +1,134 @@ +/// Helper functions to plug cuda kernels in candle. 
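The new copy2d entry point above copies d1 rows of d2 elements, with separate row strides and start offsets on the source and destination side. A plain CPU sketch of the same contract (the copy2d_cpu helper is illustrative, not the CUDA kernel):

    /// Copy d1 rows of d2 contiguous elements each; rows are src_s (resp. dst_s)
    /// elements apart and start at element offsets src_o (resp. dst_o).
    fn copy2d_cpu<T: Copy>(
        src: &[T],
        dst: &mut [T],
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) {
        for i in 0..d1 {
            let s = &src[src_o + i * src_s..][..d2];
            let d = &mut dst[dst_o + i * dst_s..][..d2];
            d.copy_from_slice(s);
        }
    }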
+use crate::{Layout, Result, Shape, WithDType}; +pub use cudarc; +use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits}; + +use super::{CudaDevice, CudaError, WrapErr}; + +pub type S = super::CudaStorageSlice; + +pub trait Map1 { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result>; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U8(s) => S::U8(self.f(s, d, l)?), + S::U32(s) => S::U32(self.f(s, d, l)?), + S::I64(s) => S::I64(self.f(s, d, l)?), + S::BF16(s) => S::BF16(self.f(s, d, l)?), + S::F16(s) => S::F16(self.f(s, d, l)?), + S::F32(s) => S::F32(self.f(s, d, l)?), + S::F64(s) => S::F64(self.f(s, d, l)?), + }; + Ok(out) + } +} + +pub trait Map2 { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result>; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), + (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), + (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), + (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), + (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), + (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, + }; + Ok(out) + } +} + +pub trait Map2InPlace { + fn f( + &self, + dst: &mut CudaSlice, + dst_shape: &Shape, + src: &CudaSlice, + src_l: &Layout, + dev: &CudaDevice, + ) -> Result<()>; + + fn map( + &self, + dst: &mut S, + dst_s: &Shape, + src: &S, + src_l: &Layout, + d: &CudaDevice, + ) -> Result<()> { + match (dst, src) { + (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d), + (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d), + (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, + } + } +} + +pub trait Map1Any { + fn f) -> S>( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + wrap: W, + ) -> Result; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U8(s) => self.f(s, d, l, S::U8)?, + S::U32(s) => self.f(s, d, l, S::U32)?, + S::I64(s) => self.f(s, d, l, S::I64)?, + S::BF16(s) => self.f(s, d, l, S::BF16)?, + S::F16(s) => self.f(s, d, l, S::F16)?, + S::F32(s) => self.f(s, d, l, S::F32)?, + S::F64(s) => self.f(s, d, l, S::F64)?, + }; + Ok(out) + } +} + +pub trait Map2Any { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, + 
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, + }; + Ok(out) + } +} diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs new file mode 100644 index 00000000..3a85dba9 --- /dev/null +++ b/candle-core/src/custom_op.rs @@ -0,0 +1,377 @@ +use crate::op::{BackpropOp, Op}; +use crate::tensor::from_storage; +use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor}; +use std::sync::Arc; + +/// Unary ops that can be defined in user-land. +pub trait CustomOp1 { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _storage: &MetalStorage, + _layout: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + + /// This function takes as argument the argument `arg` used in the forward pass, the result + /// produced by the forward operation `res` and the gradient of the result `grad_res`. + /// The function should return the gradient of the argument. + fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result> { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +pub trait CustomOp2 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd( + &self, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + ) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + + fn bwd( + &self, + _arg1: &Tensor, + _arg2: &Tensor, + _res: &Tensor, + _grad_res: &Tensor, + ) -> Result<(Option, Option)> { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +pub trait CustomOp3 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd( + &self, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + ) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + + fn bwd( + &self, + _arg1: &Tensor, + _arg2: &Tensor, + _arg3: &Tensor, + _res: &Tensor, + _grad_res: &Tensor, + ) -> Result<(Option, Option, Option)> { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +impl Tensor { + /// Applies a unary custom op without backward support + pub fn apply_op1_no_bwd(&self, c: &C) -> Result { + let (storage, shape) = self.storage().apply_op1(self.layout(), c)?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a binary custom op without backward support + pub fn apply_op2_no_bwd(&self, rhs: &Self, c: &C) -> Result { + let (storage, shape) = + self.storage() + .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a ternary custom op without backward support + pub fn apply_op3_no_bwd(&self, t2: &Self, t3: &Self, c: &C) -> Result { + let (storage, shape) = self.storage().apply_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c, + )?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a unary custom op. + pub fn apply_op1_arc(&self, c: Arc>) -> Result { + let (storage, shape) = self + .storage() + .apply_op1(self.layout(), c.as_ref().as_ref())?; + let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone())); + Ok(from_storage(storage, shape, op, false)) + } + + pub fn apply_op1(&self, c: C) -> Result { + self.apply_op1_arc(Arc::new(Box::new(c))) + } + + /// Applies a binary custom op. 
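As a usage sketch for the CustomOp1/apply_op1 pair above: a toy user-land op that doubles every element, CPU forward only (the Double type is illustrative; contiguous_offsets and bail! are existing candle_core helpers, and backward falls back to BackwardNotSupported unless bwd is implemented).

    use candle_core::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

    struct Double;

    impl CustomOp1 for Double {
        fn name(&self) -> &'static str {
            "double"
        }

        fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
            let vs = match storage {
                CpuStorage::F32(vs) => vs,
                _ => candle_core::bail!("double only supports f32"),
            };
            // Keep the sketch simple: only handle contiguous inputs.
            let (o1, o2) = match layout.contiguous_offsets() {
                Some(o) => o,
                None => candle_core::bail!("double expects a contiguous input"),
            };
            let out: Vec<f32> = vs[o1..o2].iter().map(|v| v * 2.0).collect();
            Ok((CpuStorage::F32(out), layout.shape().clone()))
        }
    }

    fn double(xs: &Tensor) -> Result<Tensor> {
        xs.apply_op1(Double)
    }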
+ pub fn apply_op2_arc( + &self, + rhs: &Self, + c: Arc>, + ) -> Result { + let (storage, shape) = self.storage().apply_op2( + self.layout(), + &rhs.storage(), + rhs.layout(), + c.as_ref().as_ref(), + )?; + let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone())); + Ok(from_storage(storage, shape, op, false)) + } + + pub fn apply_op2(&self, r: &Self, c: C) -> Result { + self.apply_op2_arc(r, Arc::new(Box::new(c))) + } + + /// Applies a ternary custom op. + pub fn apply_op3_arc( + &self, + t2: &Self, + t3: &Self, + c: Arc>, + ) -> Result { + let (storage, shape) = self.storage().apply_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c.as_ref().as_ref(), + )?; + let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| { + Op::CustomOp3(t1, t2, t3, c.clone()) + }); + Ok(from_storage(storage, shape, op, false)) + } + + pub fn apply_op3( + &self, + t2: &Self, + t3: &Self, + c: C, + ) -> Result { + self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) + } +} + +// In place ops. + +/// Unary ops that can be defined in user-land. +/// These ops work in place and as such back-prop is unsupported. +pub trait InplaceOp1 { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } +} + +pub trait InplaceOp2 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout) + -> Result<()>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &mut MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<()> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } +} + +pub trait InplaceOp3 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. 
Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + s1: &mut CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<()>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd( + &self, + _: &mut CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + ) -> Result<()> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &mut MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<()> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } +} + +impl Tensor { + /// Applies a unary custom op in place. + pub fn inplace_op1(&self, c: &C) -> Result<()> { + self.storage_mut().inplace_op1(self.layout(), c) + } + + /// Applies a unary custom op in place (for the first tensor). + pub fn inplace_op2(&self, rhs: &Self, c: &C) -> Result<()> { + self.storage_mut() + .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c) + } + + /// Applies a ternary custom op in place (for the first tensor). + pub fn inplace_op3(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> { + self.storage_mut().inplace_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c, + ) + } +} diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 1e33021b..846c62ce 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -289,17 +289,34 @@ impl Device { } } + pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + match self { + Device::Cpu => { + let storage = CpuDevice.alloc_uninit(shape, dtype)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + let storage = device.alloc_uninit(shape, dtype)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = device.alloc_uninit(shape, dtype)?; + Ok(Storage::Metal(storage)) + } + } + } + pub(crate) fn storage(&self, array: A) -> Result { match self { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cuda(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Metal(storage)) } } @@ -310,12 +327,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))), Device::Cuda(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Metal(storage)) } } diff --git 
a/candle-core/src/display.rs b/candle-core/src/display.rs index 4f5a390e..7e6e3cf8 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -65,12 +65,13 @@ impl std::fmt::Debug for Tensor { } /// Options for Tensor pretty printing +#[derive(Debug, Clone)] pub struct PrinterOptions { - precision: usize, - threshold: usize, - edge_items: usize, - line_width: usize, - sci_mode: Option, + pub precision: usize, + pub threshold: usize, + pub edge_items: usize, + pub line_width: usize, + pub sci_mode: Option, } static PRINT_OPTS: std::sync::Mutex = @@ -89,6 +90,10 @@ impl PrinterOptions { } } +pub fn print_options() -> &'static std::sync::Mutex { + &PRINT_OPTS +} + pub fn set_print_options(options: PrinterOptions) { *PRINT_OPTS.lock().unwrap() = options } @@ -117,6 +122,26 @@ pub fn set_print_options_full() { } } +pub fn set_line_width(line_width: usize) { + PRINT_OPTS.lock().unwrap().line_width = line_width +} + +pub fn set_precision(precision: usize) { + PRINT_OPTS.lock().unwrap().precision = precision +} + +pub fn set_edge_items(edge_items: usize) { + PRINT_OPTS.lock().unwrap().edge_items = edge_items +} + +pub fn set_threshold(threshold: usize) { + PRINT_OPTS.lock().unwrap().threshold = threshold +} + +pub fn set_sci_mode(sci_mode: Option) { + PRINT_OPTS.lock().unwrap().sci_mode = sci_mode +} + struct FmtSize { current_size: usize, } diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 94ca57d8..1a698a35 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -23,7 +23,15 @@ pub enum DType { } #[derive(Debug, PartialEq, Eq)] -pub struct DTypeParseError; +pub struct DTypeParseError(String); + +impl std::fmt::Display for DTypeParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "cannot parse '{}' as a dtype", self.0) + } +} + +impl std::error::Error for DTypeParseError {} impl std::str::FromStr for DType { type Err = DTypeParseError; @@ -36,7 +44,7 @@ impl std::str::FromStr for DType { "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), - _ => Err(DTypeParseError), + _ => Err(DTypeParseError(s.to_string())), } } } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 34c5d97f..5348233c 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -154,6 +154,19 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn copy2d( + &self, + _: &mut Self, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { Err(Error::NotCompiledWithCudaSupport) } @@ -197,10 +210,18 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { Err(Error::NotCompiledWithCudaSupport) } + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index e9d92331..322f81d2 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ 
b/candle-core/src/dummy_metal_backend.rs @@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } + fn copy2d( + &self, + _: &mut Self, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { Err(Error::NotCompiledWithMetalSupport) } @@ -209,10 +222,18 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { Err(Error::NotCompiledWithMetalSupport) } + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index bf346cf2..e6824b29 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -70,7 +70,7 @@ impl Layout { self.shape.is_fortran_contiguous(&self.stride) } - pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result { + pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result { let dims = self.shape().dims(); if dim >= dims.len() { Err(Error::DimOutOfRange { @@ -99,7 +99,7 @@ impl Layout { }) } - pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result { + pub fn transpose(&self, dim1: usize, dim2: usize) -> Result { let rank = self.shape.rank(); if rank <= dim1 || rank <= dim2 { Err(Error::UnexpectedNumberOfDims { @@ -120,7 +120,7 @@ impl Layout { }) } - pub(crate) fn permute(&self, idxs: &[usize]) -> Result { + pub fn permute(&self, idxs: &[usize]) -> Result { let is_permutation = idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i)); if !is_permutation { diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index fcc17afc..1f57ca9b 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -14,7 +14,7 @@ //! //! ## Features //! -//! - Simple syntax (looks and like PyTorch) +//! - Simple syntax (looks and feels like PyTorch) //! - CPU and Cuda backends (and M1 support) //! - Enable serverless (CPU) small and fast deployments //! 
- Model training @@ -37,14 +37,13 @@ mod accelerate; pub mod backend; pub mod backprop; -mod conv; +pub mod conv; mod convert; pub mod cpu; pub mod cpu_backend; #[cfg(feature = "cuda")] pub mod cuda_backend; -#[cfg(feature = "cudnn")] -pub mod cudnn; +mod custom_op; mod device; pub mod display; mod dtype; @@ -58,7 +57,7 @@ pub mod metal_backend; #[cfg(feature = "mkl")] mod mkl; pub mod npy; -mod op; +pub mod op; pub mod pickle; pub mod quantized; pub mod safetensors; @@ -67,17 +66,21 @@ pub mod shape; mod storage; mod strided_index; mod tensor; +mod tensor_cat; pub mod test_utils; pub mod utils; mod variable; +#[cfg(feature = "cudnn")] +pub use cuda_backend::cudnn; + pub use cpu_backend::CpuStorage; +pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; pub use device::{Device, DeviceLocation, NdArray}; -pub use dtype::{DType, FloatDType, IntDType, WithDType}; +pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; pub use error::{Error, Result}; pub use indexer::IndexOp; pub use layout::Layout; -pub use op::{CustomOp1, CustomOp2, CustomOp3}; pub use shape::{Shape, D}; pub use storage::Storage; pub use strided_index::{StridedBlocks, StridedIndex}; diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs new file mode 100644 index 00000000..fdeca13f --- /dev/null +++ b/candle-core/src/metal_backend/device.rs @@ -0,0 +1,287 @@ +use crate::{DType, Result}; +use candle_metal_kernels::Kernels; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::path::Path; +use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard}; + +use super::MetalError; + +/// Unique identifier for cuda devices. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + pub(crate) fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec>>; +type AllocatedBuffers = Arc>; + +#[derive(Clone)] +pub struct MetalDevice { + /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than + /// the device itself. + pub(crate) id: DeviceId, + + /// Raw metal device: + pub(crate) device: metal::Device, + + /// Single command queue for the entire device. + pub(crate) command_queue: CommandQueue, + /// One command buffer at a time. + /// The scheduler works by allowing multiple + /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) + /// on a single command buffer. Using a single command buffer would be fastest on the GPU but + /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed + /// to start to work). + /// Despite what the documentation says, command buffers are NOT ordered. They are ordered + /// for their START time, but there's no guarantee that command buffer1 will finish before + /// command buffer2 starts (or there are metal bugs there) + pub(crate) command_buffer: Arc>, + /// Keeps track of the current amount of compute command encoders on the current + /// command buffer + /// Arc, RwLock because of the interior mutability. 
+ pub(crate) command_buffer_index: Arc>, + /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) + pub(crate) compute_per_buffer: usize, + /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. + /// Heavily used by [`candle_metal_kernels`] + pub(crate) kernels: Arc, + /// Simple allocator struct. + /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. + /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting + /// (could be linked to FFI communication overhead). + /// + /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the + /// graph calculation, and only we the allocator kept a reference to it, therefore it's free + /// to be reused. However, in order for this to work, we need to guarantee the order of + /// operation, so that this buffer is not being used by another kernel at the same time. + /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. + /// + /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers + /// (strong_count = 1). + pub(crate) buffers: AllocatedBuffers, + /// Seed for random number generation. + pub(crate) seed: Arc>, +} + +impl std::fmt::Debug for MetalDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MetalDevice({:?})", self.id) + } +} + +impl std::ops::Deref for MetalDevice { + type Target = metal::DeviceRef; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl MetalDevice { + pub fn id(&self) -> DeviceId { + self.id + } + + pub fn metal_device(&self) -> &metal::Device { + &self.device + } + + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + pub fn command_buffer(&self) -> Result { + let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer = command_buffer_lock.to_owned(); + let mut index = self + .command_buffer_index + .try_write() + .map_err(MetalError::from)?; + if *index > self.compute_per_buffer { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffer_lock = command_buffer.clone(); + *index = 0; + + self.drop_unused_buffers()?; + } + *index += 1; + Ok(command_buffer) + } + + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Already committed"); + } + _ => {} + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffer = self.command_queue.new_command_buffer().to_owned(); + + Ok(()) + } + + pub fn kernels(&self) -> &Kernels { + &self.kernels + } + + pub fn device(&self) -> &metal::Device { + &self.device + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer data cannot be read on the CPU directly. 
+ /// + /// [`name`] is only used to keep track of the resource origin in case of bugs + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result> { + let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer can be read on the CPU but will require manual + /// synchronization when the CPU memory is modified + /// Used as a bridge to gather data back from the GPU + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { + self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + } + + /// Creates a new buffer from data. + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// + /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) + /// allocates the buffer and copies over the existing data before returning the MTLBuffer. + pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { + let size = core::mem::size_of_val(data) as NSUInteger; + let new_buffer = self.device.new_buffer_with_data( + data.as_ptr() as *const c_void, + size, + MTLResourceOptions::StorageModeManaged, + ); + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let subbuffers = buffers + .entry((size, MTLResourceOptions::StorageModeManaged)) + .or_insert(vec![]); + + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + Ok(new_buffer) + } + + pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { + let buffer = self.allocate_buffer( + size_in_bytes as NSUInteger, + MTLResourceOptions::StorageModePrivate, + "allocate_zeros", + )?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("zeros"); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + Ok(buffer) + } + + fn find_available_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + buffers: &RwLockWriteGuard, + ) -> Option> { + let mut best_buffer: Option<&Arc> = None; + let mut best_buffer_size: NSUInteger = NSUInteger::MAX; + for ((buffer_size, buffer_option), subbuffers) in buffers.iter() { + if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option { + for sub in subbuffers { + if Arc::strong_count(sub) == 1 { + best_buffer = Some(sub); + best_buffer_size = *buffer_size; + } + } + } + } + best_buffer.cloned() + } + + fn drop_unused_buffers(&self) -> Result<()> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(*s) > 1) + .map(Arc::clone) + .collect(); + *subbuffers = newbuffers; + } + Ok(()) + } + + /// The critical allocator algorithm + fn allocate_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Result> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + if let Some(b) = self.find_available_buffer(size, option, &buffers) { + // Cloning also ensures we increment the strong count + return Ok(b.clone()); + } + + let size = buf_size(size); + let 
subbuffers = buffers.entry((size, option)).or_insert(vec![]); + + let new_buffer = self.device.new_buffer(size as NSUInteger, option); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + + Ok(new_buffer) + } + + /// Create a metal GPU capture trace on [`path`]. + pub fn capture>(&self, path: P) -> Result<()> { + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(self); + descriptor.set_output_url(path); + + capture + .start_capture(&descriptor) + .map_err(MetalError::from)?; + Ok(()) + } +} + +fn buf_size(size: NSUInteger) -> NSUInteger { + (size - 1).next_power_of_two() as NSUInteger +} diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend/mod.rs similarity index 76% rename from candle-core/src/metal_backend.rs rename to candle-core/src/metal_backend/mod.rs index abd647af..a25f7b7c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -2,15 +2,21 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; -use candle_metal_kernels; -use candle_metal_kernels::Kernels; -use metal; -use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; +use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; -use std::path::Path; -use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError}; +use std::sync::{Arc, Mutex, RwLock, TryLockError}; +mod device; +pub use device::{DeviceId, MetalDevice}; + +fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> { + BufferOffset { + buffer, + offset_in_bytes: l.start_offset() * dtype.size_in_bytes(), + } +} /// Simple way to catch lock error without /// depending on T #[derive(thiserror::Error, Debug)] @@ -37,13 +43,6 @@ pub enum MetalError { Message(String), #[error(transparent)] KernelError(#[from] candle_metal_kernels::MetalKernelError), - - #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] - MatMulNonContiguous { - lhs_stride: Vec, - rhs_stride: Vec, - mnk: (usize, usize, usize), - }, #[error("{0:?}")] LockError(LockError), #[error("{msg}, expected: {expected:?}, got: {got:?}")] @@ -60,263 +59,6 @@ impl From for MetalError { } } -type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec>>; -type AllocatedBuffers = Arc>; - -#[derive(Clone)] -pub struct MetalDevice { - /// Raw metal device: - device: metal::Device, - - /// Single command queue for the entire device. - command_queue: CommandQueue, - /// One command buffer at a time. - /// The scheduler works by allowing multiple - /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) - /// on a single command buffer. Using a single command buffer would be fastest on the GPU but - /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed - /// to start to work). - /// Despite what the documentation says, command buffers are NOT ordered. 
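// Quick illustration of the bucketing done by `buf_size` in device.rs above:
// requested byte sizes are rounded up to the next power of two, so requests of
// similar size land in the same bucket and can reuse each other's buffers.
#[test]
fn buf_size_rounds_up_to_a_power_of_two() {
    assert_eq!(buf_size(1), 1);
    assert_eq!(buf_size(3000), 4096);
    assert_eq!(buf_size(4096), 4096);
    assert_eq!(buf_size(5000), 8192);
}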
They are ordered - /// for their START time, but there's no guarantee that command buffer1 will finish before - /// command buffer2 starts (or there are metal bugs there) - command_buffer: Arc>, - /// Keeps track of the current amount of compute command encoders on the current - /// command buffer - /// Arc, RwLock because of the interior mutability. - command_buffer_index: Arc>, - /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) - compute_per_buffer: usize, - /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. - /// Heavily used by [`candle_metal_kernels`] - kernels: Arc, - /// Simple allocator struct. - /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. - /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting - /// (could be linked to FFI communication overhead). - /// - /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the - /// graph calculation, and only we the allocator kept a reference to it, therefore it's free - /// to be reused. However, in order for this to work, we need to guarantee the order of - /// operation, so that this buffer is not being used by another kernel at the same time. - /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. - /// - /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers - /// (strong_count = 1). - buffers: AllocatedBuffers, - /// Seed for random number generation. - seed: Arc>, -} - -impl std::fmt::Debug for MetalDevice { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "MetalDevice({:?})", self.device.registry_id()) - } -} - -impl std::ops::Deref for MetalDevice { - type Target = metal::DeviceRef; - - fn deref(&self) -> &Self::Target { - &self.device - } -} - -impl MetalDevice { - pub fn id(&self) -> NSUInteger { - self.registry_id() - } - - pub fn metal_device(&self) -> &metal::Device { - &self.device - } - - pub fn command_queue(&self) -> &CommandQueue { - &self.command_queue - } - - pub fn command_buffer(&self) -> Result { - let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; - let mut command_buffer = command_buffer_lock.to_owned(); - let mut index = self - .command_buffer_index - .try_write() - .map_err(MetalError::from)?; - if *index > self.compute_per_buffer { - command_buffer.commit(); - command_buffer = self.command_queue.new_command_buffer().to_owned(); - *command_buffer_lock = command_buffer.clone(); - *index = 0; - - self.drop_unused_buffers()?; - } - *index += 1; - Ok(command_buffer) - } - - pub fn wait_until_completed(&self) -> Result<()> { - let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled - | metal::MTLCommandBufferStatus::Completed => { - panic!("Already committed"); - } - _ => {} - } - command_buffer.commit(); - command_buffer.wait_until_completed(); - *command_buffer = self.command_queue.new_command_buffer().to_owned(); - - Ok(()) - } - - pub fn kernels(&self) -> &Kernels { - &self.kernels - } - - pub fn device(&self) -> &metal::Device { - &self.device - } - - /// Creates a new 
buffer (not necessarily zeroed). - /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) - /// This means the buffer data cannot be read on the CPU directly. - /// - /// [`name`] is only used to keep track of the resource origin in case of bugs - pub fn new_buffer( - &self, - element_count: usize, - dtype: DType, - name: &str, - ) -> Result> { - let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) - } - - /// Creates a new buffer (not necessarily zeroed). - /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) - /// This means the buffer can be read on the CPU but will require manual - /// synchronization when the CPU memory is modified - /// Used as a bridge to gather data back from the GPU - pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { - self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") - } - - /// Creates a new buffer from data. - /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) - /// - /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) - /// allocates the buffer and copies over the existing data before returning the MTLBuffer. - pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { - let size = core::mem::size_of_val(data) as NSUInteger; - let new_buffer = self.device.new_buffer_with_data( - data.as_ptr() as *const c_void, - size, - MTLResourceOptions::StorageModeManaged, - ); - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; - let subbuffers = buffers - .entry((size, MTLResourceOptions::StorageModeManaged)) - .or_insert(vec![]); - - let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); - Ok(new_buffer) - } - - pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { - let buffer = self.allocate_buffer( - size_in_bytes as NSUInteger, - MTLResourceOptions::StorageModePrivate, - "allocate_zeros", - )?; - let command_buffer = self.command_buffer()?; - command_buffer.set_label("zeros"); - let blit = command_buffer.new_blit_command_encoder(); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.end_encoding(); - Ok(buffer) - } - - fn find_available_buffer( - &self, - size: NSUInteger, - option: MTLResourceOptions, - buffers: &RwLockWriteGuard, - ) -> Option> { - let mut best_buffer: Option<&Arc> = None; - let mut best_buffer_size: NSUInteger = NSUInteger::MAX; - for ((buffer_size, buffer_option), subbuffers) in buffers.iter() { - if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option { - for sub in subbuffers { - if Arc::strong_count(sub) == 1 { - best_buffer = Some(sub); - best_buffer_size = *buffer_size; - } - } - } - } - return best_buffer.map(|b| b.clone()); - } - - fn drop_unused_buffers(&self) -> Result<()> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; - for subbuffers in buffers.values_mut() { - let newbuffers = subbuffers - .iter() - .filter(|s| Arc::strong_count(*s) > 1) - .map(Arc::clone) - .collect(); - *subbuffers = newbuffers; - } - Ok(()) - } - - /// The critical allocator algorithm - fn allocate_buffer( - &self, - size: NSUInteger, - option: MTLResourceOptions, - _name: &str, - ) -> Result> { - let mut buffers = 
self.buffers.try_write().map_err(MetalError::from)?; - if let Some(b) = self.find_available_buffer(size, option, &buffers) { - // Cloning also ensures we increment the strong count - return Ok(b.clone()); - } - - let size = buf_size(size); - let subbuffers = buffers.entry((size, option)).or_insert(vec![]); - - let new_buffer = self.device.new_buffer(size as NSUInteger, option); - let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); - - Ok(new_buffer) - } - - /// Create a metal GPU capture trace on [`path`]. - pub fn capture>(&self, path: P) -> Result<()> { - let capture = metal::CaptureManager::shared(); - let descriptor = metal::CaptureDescriptor::new(); - descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); - descriptor.set_capture_device(self); - descriptor.set_output_url(path); - - capture - .start_capture(&descriptor) - .map_err(MetalError::from)?; - Ok(()) - } -} - #[derive(Debug, Clone)] pub struct MetalStorage { /// The actual buffer containing the data. @@ -365,7 +107,8 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el, self.dtype, "affine")?; let command_buffer = self.device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, dtype); + if layout.is_contiguous() { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", @@ -378,7 +121,7 @@ impl BackendStorage for MetalStorage { &device.kernels, name, el, - &self.buffer, + src, &buffer, mul as f32, add as f32, @@ -397,9 +140,8 @@ impl BackendStorage for MetalStorage { &device.kernels, name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * dtype.size_in_bytes(), &buffer, mul as f32, add as f32, @@ -418,10 +160,12 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el, self.dtype, "powf")?; let command_buffer = self.device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, dtype); + if layout.is_contiguous() { let name = match self.dtype { DType::F32 => "powf_f32", DType::F16 => "powf_f16", + DType::BF16 => "powf_bf16", dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"), }; candle_metal_kernels::call_powf( @@ -430,7 +174,7 @@ impl BackendStorage for MetalStorage { &device.kernels, name, el, - &self.buffer, + src, &buffer, pow as f32, ) @@ -439,6 +183,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "powf_f32_strided", DType::F16 => "powf_f16_strided", + DType::BF16 => "powf_bf16_strided", dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"), }; candle_metal_kernels::call_powf_strided( @@ -447,9 +192,8 @@ impl BackendStorage for MetalStorage { &device.kernels, name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * dtype.size_in_bytes(), &buffer, pow as f32, ) @@ -467,10 +211,12 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el, self.dtype, "elu")?; let command_buffer = self.device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { let name = match self.dtype { DType::F32 => "elu_f32", DType::F16 => "elu_f16", + DType::BF16 => "elu_bf16", dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"), }; candle_metal_kernels::call_elu( @@ -479,7 +225,7 @@ impl BackendStorage for 
MetalStorage { &device.kernels, name, el, - &self.buffer, + src, &buffer, alpha as f32, ) @@ -488,6 +234,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "elu_f32_strided", DType::F16 => "elu_f16_strided", + DType::BF16 => "elu_bf16_strided", dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"), }; candle_metal_kernels::call_elu_strided( @@ -496,9 +243,8 @@ impl BackendStorage for MetalStorage { &device.kernels, name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * dtype.size_in_bytes(), &buffer, alpha as f32, ) @@ -568,6 +314,7 @@ impl BackendStorage for MetalStorage { let dtype = if return_index { DType::U32 } else { self.dtype }; let buffer = device.new_buffer(dst_el, dtype, "reduce")?; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, self.dtype); candle_metal_kernels::call_reduce_strided( &device.device, &command_buffer, @@ -576,8 +323,7 @@ impl BackendStorage for MetalStorage { &dims, &stride, dst_el, - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &buffer, ) .map_err(MetalError::from)?; @@ -603,30 +349,44 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype")?; let command_buffer = device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { - (DType::U32, DType::F32) => "cast_u32_f32", - (DType::U32, DType::U8) => "cast_u32_u8", - (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U32, DType::F16) => "cast_u32_f16", + (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::U8) => "cast_u32_u8", - (DType::U8, DType::U32) => "cast_u8_u32", + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::I64) => "cast_u8_i64", - (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::U8, DType::U32) => "cast_u8_u32", - (DType::F32, DType::F16) => "cast_f32_f16", (DType::F32, DType::BF16) => "cast_f32_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I64) => "cast_f32_i64", + (DType::F32, DType::U32) => "cast_f32_u32", + (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I64, DType::BF16) => "cast_i64_bf16", + (DType::I64, DType::F16) => "cast_i64_f16", (DType::I64, DType::F32) => "cast_i64_f32", + (DType::I64, DType::U32) => "cast_i64_u32", + (DType::I64, DType::U8) => "cast_i64_u8", (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I64) => "cast_f16_i64", + (DType::F16, DType::U32) => "cast_f16_u32", + (DType::F16, DType::U8) => "cast_f16_u8", - (DType::BF16, DType::U8) => "cast_bf16_u8", - (DType::BF16, DType::U32) => "cast_bf16_u32", (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I64) => "cast_bf16_i64", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::U8) => "cast_bf16_u8", (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") @@ -638,8 +398,7 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, - layout.start_offset() * 
self.dtype.size_in_bytes(), + src, &buffer, ) .map_err(MetalError::from)?; @@ -666,9 +425,8 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; @@ -685,46 +443,69 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; let command_buffer = device.command_buffer()?; command_buffer.set_label(B::KERNEL); - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { use candle_metal_kernels::unary::contiguous; let kernel_name = match (B::KERNEL, dtype) { - ("ucos", DType::F32) => contiguous::cos::FLOAT, - ("usin", DType::F32) => contiguous::sin::FLOAT, - ("usqr", DType::F32) => contiguous::sqr::FLOAT, - ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, - ("uneg", DType::F32) => contiguous::neg::FLOAT, - ("uexp", DType::F32) => contiguous::exp::FLOAT, - ("ulog", DType::F32) => contiguous::log::FLOAT, - ("ugelu", DType::F32) => contiguous::gelu::FLOAT, - ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, - ("uerf", DType::F32) => contiguous::erf::FLOAT, - ("usilu", DType::F32) => contiguous::silu::FLOAT, - ("uabs", DType::F32) => contiguous::abs::FLOAT, - ("uceil", DType::F32) => contiguous::ceil::FLOAT, - ("ufloor", DType::F32) => contiguous::floor::FLOAT, - ("uround", DType::F32) => contiguous::round::FLOAT, - ("urecip", DType::F32) => contiguous::recip::FLOAT, - ("utanh", DType::F32) => contiguous::tanh::FLOAT, - ("urelu", DType::F32) => contiguous::relu::FLOAT, - ("ucos", DType::F16) => contiguous::cos::HALF, - ("usin", DType::F16) => contiguous::sin::HALF, - ("usqr", DType::F16) => contiguous::sqr::HALF, - ("usqrt", DType::F16) => contiguous::sqrt::HALF, - ("uneg", DType::F16) => contiguous::neg::HALF, - ("uexp", DType::F16) => contiguous::exp::HALF, - ("ulog", DType::F16) => contiguous::log::HALF, - ("ugelu", DType::F16) => contiguous::gelu::HALF, - ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, - ("uerf", DType::F16) => contiguous::erf::HALF, - ("usilu", DType::F16) => contiguous::silu::HALF, ("uabs", DType::F16) => contiguous::abs::HALF, + ("uabs", DType::F32) => contiguous::abs::FLOAT, + ("uabs", DType::BF16) => contiguous::abs::BFLOAT, ("uceil", DType::F16) => contiguous::ceil::HALF, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("uceil", DType::BF16) => contiguous::ceil::BFLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("ucos", DType::BF16) => contiguous::cos::BFLOAT, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uerf", DType::BF16) => contiguous::erf::BFLOAT, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + ("uexp", DType::BF16) => contiguous::exp::BFLOAT, ("ufloor", DType::F16) => contiguous::floor::HALF, - ("uround", DType::F16) => contiguous::round::HALF, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("ufloor", DType::BF16) => contiguous::floor::BFLOAT, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT, + 
("ulog", DType::F16) => contiguous::log::HALF, + ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ulog", DType::BF16) => contiguous::log::BFLOAT, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uneg", DType::BF16) => contiguous::neg::BFLOAT, ("urecip", DType::F16) => contiguous::recip::HALF, - ("utanh", DType::F16) => contiguous::tanh::HALF, + ("urecip", DType::F32) => contiguous::recip::FLOAT, + ("urecip", DType::BF16) => contiguous::recip::BFLOAT, ("urelu", DType::F16) => contiguous::relu::HALF, + ("urelu", DType::F32) => contiguous::relu::FLOAT, + ("urelu", DType::BF16) => contiguous::relu::BFLOAT, + ("uround", DType::F16) => contiguous::round::HALF, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("uround", DType::BF16) => contiguous::round::BFLOAT, + ("usilu", DType::F16) => contiguous::silu::HALF, + ("usilu", DType::F32) => contiguous::silu::FLOAT, + ("usilu", DType::BF16) => contiguous::silu::BFLOAT, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usin", DType::BF16) => contiguous::sin::BFLOAT, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqr", DType::BF16) => contiguous::sqr::BFLOAT, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT, + ("utanh", DType::F16) => contiguous::tanh::HALF, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, + ("usign", DType::F16) => contiguous::sign::HALF, + ("usign", DType::F32) => contiguous::sign::FLOAT, + ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") } @@ -735,7 +516,7 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, + src, &buffer, ) .map_err(MetalError::from)?; @@ -780,17 +561,16 @@ impl BackendStorage for MetalStorage { crate::bail!("Metal strided unary {name} {dtype:?} not implemented") } }; + let dst = BufferOffset::zero_offset(&buffer); candle_metal_kernels::call_unary_strided( &device.device, &command_buffer, &device.kernels, kernel_name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), - &buffer, - 0, + dst, ) .map_err(MetalError::from)?; } @@ -837,21 +617,21 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::U8) => "where_u8_u8", (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), }; + let src = buffer_o(&self.buffer, layout, self.dtype); + let t = buffer_o(&t.buffer, t_l, t.dtype); + let f = buffer_o(&f.buffer, f_l, f.dtype); candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, &device.kernels, name, dims, - &self.buffer, - ( - layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), - ), - &t.buffer, - (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), - &f.buffer, - (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), + src, + layout.stride(), + t, + t_l.stride(), + f, + f_l.stride(), &buffer, ) .map_err(MetalError::from)?; @@ -884,6 +664,7 @@ impl BackendStorage for MetalStorage { DType::F32 => "im2col1d_f32", dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"), }; + let src = buffer_o(&self.buffer, layout, self.dtype); 
candle_metal_kernels::call_im2col1d_strided( &self.device.device, &command_buffer, @@ -892,8 +673,7 @@ impl BackendStorage for MetalStorage { layout.shape().dims(), strides, (k_size, stride, padding, dilation), - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &dst, ) .map_err(MetalError::from)?; @@ -931,12 +711,50 @@ impl BackendStorage for MetalStorage { fn conv_transpose1d( &self, - _l: &Layout, - _kernel: &Self, - _kernel_l: &Layout, - _params: &ParamsConvTranspose1D, + layout: &Layout, + k: &Self, + k_layout: &Layout, + params: &ParamsConvTranspose1D, ) -> Result { - crate::bail!("Metal conv_transpose1d not implemented") + let l_out = params.l_out(); + let dst_el = params.c_out * l_out * params.b_size; + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose1d")?; + + let command_buffer = self.device.command_buffer()?; + let name = match self.dtype { + DType::F32 => "conv_transpose1d_f32", + DType::F16 => "conv_transpose1d_f16", + DType::BF16 => "conv_transpose1d_bf16", + DType::U32 => "conv_transpose1d_u32", + DType::U8 => "conv_transpose1d_u8", + dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"), + }; + candle_metal_kernels::call_conv_transpose1d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + params.dilation, + params.stride, + params.padding, + params.output_padding, + params.c_out, + l_out, + params.b_size, + layout.dims(), + layout.stride(), + k_layout.dims(), + k_layout.stride(), + &self.buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &k.buffer, + k_layout.start_offset() * k.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } fn conv2d( @@ -967,8 +785,13 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer()?; let name = match self.dtype { DType::F32 => "im2col_f32", + DType::F16 => "im2col_f16", + DType::BF16 => "im2col_bf16", + DType::U8 => "im2col_u8", + DType::U32 => "im2col_u32", dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"), }; + let src = buffer_o(&self.buffer, layout, self.dtype); candle_metal_kernels::call_im2col_strided( &self.device.device, &command_buffer, @@ -977,8 +800,7 @@ impl BackendStorage for MetalStorage { layout.shape().dims(), layout.stride(), (h_k, w_k, stride, padding, dilation), - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &dst, ) .map_err(MetalError::from)?; @@ -1019,20 +841,150 @@ impl BackendStorage for MetalStorage { fn conv_transpose2d( &self, - _l: &Layout, - _kernel: &Self, - _kernel_l: &Layout, - _params: &ParamsConvTranspose2D, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &ParamsConvTranspose2D, ) -> Result { - crate::bail!("Metal conv_tranpose2d not implemented") + // Kernel shape: (c_in_k, c_out, h_k, w_k) + // Input shape: (b_size, c_in, h_in, w_in) + let (out_w, out_h) = (params.out_w(), params.out_h()); + let dst_el = params.c_out * out_w * out_h * params.b_size; + + let dims = l.dims(); + if dims.len() != 4 { + crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4") + } + + let k_dims = kernel_l.dims(); + if k_dims.len() != 4 { + crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4") + } + + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose2d")?; + + let command_buffer = self.device.command_buffer()?; + + let name = match self.dtype { + DType::F32 => 
"conv_transpose2d_f32", + DType::F16 => "conv_transpose2d_f16", + DType::BF16 => "conv_transpose2d_bf16", + dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"), + }; + + candle_metal_kernels::call_conv_transpose2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + CallConvTranspose2dCfg { + dilation: params.dilation, + stride: params.stride, + padding: params.padding, + output_padding: params.output_padding, + c_out: params.c_out, + out_h, + out_w, + b_size: params.b_size, + input_dims: l.dims(), + input_stride: l.stride(), + kernel_dims: kernel_l.dims(), + kernel_stride: kernel_l.stride(), + input_offset: l.start_offset() * self.dtype.size_in_bytes(), + kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(), + }, + &self.buffer, + &kernel.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } - fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - crate::bail!("Metal avg_pool2d not implemented") + fn avg_pool2d( + &self, + inp_l: &Layout, + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + ) -> Result { + let shape = inp_l.shape(); + let (b_size, channels, width, height) = shape.dims4()?; + let strides = inp_l.stride(); + let name = match self.dtype { + DType::F32 => "avg_pool2d_f32", + DType::F16 => "avg_pool2d_f16", + DType::BF16 => "avg_pool2d_bf16", + DType::U8 => "avg_pool2d_u8", + DType::U32 => "avg_pool2d_u32", + dtype => crate::bail!("Metal avg_pool2d {dtype:?} not implemented"), + }; + let out_w = (width - w_k) / w_stride + 1; + let out_h = (height - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; + let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?; + let command_buffers = self.device.command_buffer()?; + candle_metal_kernels::call_pool2d( + &self.device.device, + &command_buffers, + &self.device.kernels, + name, + inp_l.dims(), + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &self.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { - crate::bail!("Metal max_pool2d not implemented") + fn max_pool2d( + &self, + inp_l: &Layout, + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + ) -> Result { + let shape = inp_l.shape(); + let (b_size, channels, width, height) = shape.dims4()?; + let strides = inp_l.stride(); + let name = match self.dtype { + DType::F32 => "max_pool2d_f32", + DType::F16 => "max_pool2d_f16", + DType::BF16 => "max_pool2d_bf16", + DType::U8 => "max_pool2d_u8", + DType::U32 => "max_pool2d_u32", + dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"), + }; + let out_w = (width - w_k) / w_stride + 1; + let out_h = (height - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; + let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; + let command_buffers = self.device.command_buffer()?; + candle_metal_kernels::call_pool2d( + &self.device.device, + &command_buffers, + &self.device.kernels, + name, + inp_l.dims(), + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &self.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { @@ -1049,6 +1001,10 @@ impl BackendStorage for 
MetalStorage { } let name = match self.dtype { DType::F32 => "upsample_nearest2d_f32", + DType::F16 => "upsample_nearest2d_f16", + DType::BF16 => "upsample_nearest2d_bf16", + DType::U8 => "upsample_nearest2d_u8", + DType::U32 => "upsample_nearest2d_u32", dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"), }; @@ -1057,6 +1013,7 @@ impl BackendStorage for MetalStorage { .device .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, inp_l, self.dtype); candle_metal_kernels::call_upsample_nearest_2d( &self.device.device, &command_buffer, @@ -1066,8 +1023,7 @@ impl BackendStorage for MetalStorage { strides, out_w, out_h, - &self.buffer, - inp_l.start_offset() * self.dtype.size_in_bytes(), + src, &buffer, ) .map_err(MetalError::from)?; @@ -1075,9 +1031,8 @@ impl BackendStorage for MetalStorage { } fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { - let (ids_o1, _) = match ids_l.contiguous_offsets() { - Some(o12) => o12, - None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + if !ids_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "gather" }.bt()); }; let ids_el = ids_l.dims()[dim]; let dst_el = ids_l.shape().elem_count(); @@ -1087,9 +1042,12 @@ impl BackendStorage for MetalStorage { let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F16) => "gather_u32_f16", + (DType::U32, DType::BF16) => "gather_u32_bf16", (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, src_l, dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_gather( &device.device, &command_buffer, @@ -1098,10 +1056,8 @@ impl BackendStorage for MetalStorage { src_l.dims(), ids_el, dim, - &self.buffer, - src_l.start_offset() * dtype.size_in_bytes(), - &ids.buffer, - ids_o1 * ids.dtype.size_in_bytes(), + src, + ids, &buffer, ) .map_err(MetalError::from)?; @@ -1119,16 +1075,19 @@ impl BackendStorage for MetalStorage { ) -> Result { let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; self.copy_strided_src(&mut acc, 0, l)?; - let (ids_offset, _) = match ids_l.contiguous_offsets() { - Some(o12) => o12, - None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, - }; - let src_offset = match src_l.contiguous_offsets() { - Some((o1, _)) => o1, - None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + if !ids_l.is_contiguous() || !src_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::F32) => "sa_u8_f32", + (DType::U8, DType::F16) => "sa_u8_f16", + (DType::U8, DType::BF16) => "sa_u8_bf16", (DType::U32, DType::F32) => "sa_u32_f32", + (DType::U32, DType::F16) => "sa_u32_f16", + (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I64, DType::F32) => "sa_i64_f32", + (DType::I64, DType::F16) => "sa_i64_f16", + (DType::I64, DType::BF16) => "sa_i64_bf16", _ => Err(MetalError::UnexpectedDType { msg: "scatter-add ids should be u8/u32/i64", expected: DType::U32, @@ -1136,6 +1095,8 @@ impl BackendStorage for MetalStorage { })?, }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); 
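+        // Both operands are passed as BufferOffset values: scatter-add only
+        // requires `ids` and `src` to be contiguous (checked above), and their
+        // start offsets now travel inside the BufferOffset instead of being
+        // separate arguments.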
candle_metal_kernels::call_scatter_add( &self.device.device, &command_buffer, @@ -1144,10 +1105,8 @@ impl BackendStorage for MetalStorage { src_l.dims(), l.dims(), dim, - &src.buffer, - src_offset * src.dtype.size_in_bytes(), - &ids.buffer, - ids_offset * ids.dtype.size_in_bytes(), + src, + ids, &acc.buffer, ) .map_err(MetalError::from)?; @@ -1155,12 +1114,8 @@ impl BackendStorage for MetalStorage { } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - if !(src_l.is_contiguous() - && src_l.start_offset() == 0 - && ids_l.is_contiguous() - && ids_l.start_offset() == 0) - { - crate::bail!("Metal strided index_select not implemented"); + if !ids_l.is_contiguous() { + crate::bail!("Metal index_select requires contiguous ids") } let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); @@ -1171,16 +1126,24 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U8, DType::F32) => "is_u8_f32", + (DType::U8, DType::F16) => "is_u8_f16", (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I64, DType::F32) => "is_i64_f32", + (DType::I64, DType::F16) => "is_i64_f16", + (DType::I64, DType::BF16) => "is_i64_bf16", + (left, right) => { crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") } }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, src_l, dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -1189,8 +1152,11 @@ impl BackendStorage for MetalStorage { src_l.dims(), ids_el, dim, - &self.buffer, - &ids.buffer, + src_l.is_contiguous(), + src_l.dims(), + src_l.stride(), + src, + ids, &buffer, ) .map_err(MetalError::from)?; @@ -1208,23 +1174,40 @@ impl BackendStorage for MetalStorage { ) -> Result { let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; self.copy_strided_src(&mut acc, 0, l)?; - let (ids_offset, _) = match ids_l.contiguous_offsets() { - Some(o12) => o12, - None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, - }; - let src_offset = match src_l.contiguous_offsets() { - Some((o1, _)) => o1, - None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + if !ids_l.is_contiguous() || !src_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::I64, DType::BF16) => "ia_i64_bf16", + (DType::I64, DType::F16) => "ia_i64_f16", + (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I64) => "ia_i64_i64", + (DType::I64, DType::U32) => "ia_i64_u32", + (DType::I64, DType::U8) => "ia_i64_u8", + + (DType::U32, DType::BF16) => "ia_u32_bf16", + (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I64) => "ia_u32_i64", + (DType::U32, DType::U32) => "ia_u32_u32", + (DType::U32, DType::U8) => "ia_u32_u8", + + (DType::U8, DType::BF16) => "ia_u8_bf16", + (DType::U8, DType::F16) => "ia_u8_f16", + (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I64) => "ia_u8_i64", + (DType::U8, DType::U32) => "ia_u8_u32", + (DType::U8, DType::U8) => "ia_u8_u8", + _ => Err(MetalError::UnexpectedDType { - msg: "index-add ids should be 
u32", + msg: "index-add ids should be u8/u32/i64", expected: DType::U32, got: ids.dtype(), })?, }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_index_add( &self.device.device, &command_buffer, @@ -1234,10 +1217,8 @@ impl BackendStorage for MetalStorage { l.dims(), ids_l.dims(), dim, - &src.buffer, - src_offset * src.dtype.size_in_bytes(), - &ids.buffer, - ids_offset * ids.dtype.size_in_bytes(), + src, + ids, &acc.buffer, ) .map_err(MetalError::from)?; @@ -1285,6 +1266,67 @@ impl BackendStorage for MetalStorage { )) } + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + if self.dtype() != dst.dtype() { + crate::bail!( + "copy2d with inconsistent dtypes {:?} {:?}", + self.dtype(), + dst.dtype() + ) + } + let command_buffer = self.device.command_buffer()?; + if src_s == d2 && dst_s == d2 { + command_buffer.set_label("copy2d_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy2d_contiguous"); + let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger; + let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.end_encoding(); + } else { + let el_count = d1 * d2; + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::copy2d::FLOAT, + DType::F16 => candle_metal_kernels::copy2d::HALF, + DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I64 => candle_metal_kernels::copy2d::I64, + DType::U32 => candle_metal_kernels::copy2d::U32, + DType::U8 => candle_metal_kernels::copy2d::U8, + dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"), + }; + candle_metal_kernels::call_copy2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + &self.buffer, + &dst.buffer, + d1, + d2, + src_s, + dst_s, + src_o * self.dtype.size_in_bytes(), + dst_o * self.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy2d"); + } + Ok(()) + } + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let command_buffer = self.device.command_buffer()?; if src_l.is_contiguous() && self.dtype == dst.dtype() { @@ -1311,17 +1353,20 @@ impl BackendStorage for MetalStorage { DType::U8 => candle_metal_kernels::unary::strided::copy::U8, dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"), }; + let src = buffer_o(&self.buffer, src_l, self.dtype); + let dst = BufferOffset { + buffer: &dst.buffer, + offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(), + }; candle_metal_kernels::call_unary_strided( &self.device.device, &command_buffer, &self.device.kernels, kernel_name, src_l.dims(), - &self.buffer, + src, src_l.stride(), - src_l.start_offset() * self.dtype.size_in_bytes(), - &dst.buffer, - dst_offset * dst.dtype.size_in_bytes(), + dst, ) .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); @@ -1355,10 +1400,9 @@ impl MetalStorage { let shape = lhs_l.shape(); let el_count = shape.elem_count(); let command_buffer = device.command_buffer()?; - let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) - && (rhs_l.is_contiguous() && 
rhs_l.start_offset() == 0) - && &op[..1] != "b" - { + let lhs = buffer_o(&self.buffer, lhs_l, self.dtype); + let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype); + let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" { use candle_metal_kernels::binary::contiguous; let (kernel_name, dtype) = match (op, self.dtype) { @@ -1439,8 +1483,8 @@ impl MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, - &rhs.buffer, + lhs, + rhs, &buffer, ) .map_err(MetalError::from)?; @@ -1538,12 +1582,10 @@ impl MetalStorage { &device.kernels, kernel_name, lhs_l.dims(), - &self.buffer, + lhs, lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &rhs.buffer, + rhs, rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; @@ -1592,6 +1634,7 @@ impl BackendDevice for MetalDevice { MTLResourceOptions::StorageModeManaged, ))); Ok(Self { + id: DeviceId::new(), device, command_queue, command_buffer, @@ -1610,7 +1653,17 @@ impl BackendDevice for MetalDevice { } fn same_device(&self, rhs: &Self) -> bool { - self.device.registry_id() == rhs.device.registry_id() + self.id == rhs.id + } + + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?; + Ok(MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) } fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { @@ -1648,6 +1701,10 @@ impl BackendDevice for MetalDevice { )) } + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result { + self.storage_from_cpu_storage(&storage) + } + fn rand_uniform( &self, shape: &Shape, @@ -1728,7 +1785,7 @@ impl BackendDevice for MetalDevice { let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?; let contents = seed_buffer.contents(); unsafe { - std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4); + std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1); } seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); @@ -1736,10 +1793,6 @@ impl BackendDevice for MetalDevice { } } -fn buf_size(size: NSUInteger) -> NSUInteger { - (size - 1).next_power_of_two() as NSUInteger -} - fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; assert!(!ptr.is_null()); diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 022b4fc3..49ba44be 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,5 +1,5 @@ #![allow(clippy::redundant_closure_call)] -use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor}; +use crate::Tensor; use half::{bf16, f16}; use num_traits::float::Float; @@ -66,6 +66,7 @@ pub enum UnaryOp { Floor, Ceil, Round, + Sign, } #[derive(Clone)] @@ -161,168 +162,23 @@ pub enum Op { Permute(Tensor, Vec), Elu(Tensor, f64), Powf(Tensor, f64), - CustomOp1(Tensor, std::sync::Arc>), + CustomOp1( + Tensor, + std::sync::Arc>, + ), CustomOp2( Tensor, Tensor, - std::sync::Arc>, + std::sync::Arc>, ), CustomOp3( Tensor, Tensor, Tensor, - std::sync::Arc>, + std::sync::Arc>, ), } -/// Unary ops that can be defined in user-land. -pub trait CustomOp1 { - // Box does not support const yet, so use a function to get the name. - fn name(&self) -> &'static str; - - /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. 
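// Hedged sketch of a user-land op written against the trait that is being
// moved out of op.rs here (it now lives in the new `custom_op` module and is
// re-exported from the crate root alongside the new InplaceOp1/2/3 traits).
// `Clamp01` and the f32-only, contiguous-only handling are illustrative
// assumptions, not part of this patch.
struct Clamp01;

impl candle_core::CustomOp1 for Clamp01 {
    fn name(&self) -> &'static str {
        "clamp01"
    }

    // Only the CPU path is filled in; the CUDA and Metal methods keep their
    // default "not implemented" errors.
    fn cpu_fwd(
        &self,
        storage: &candle_core::CpuStorage,
        layout: &candle_core::Layout,
    ) -> candle_core::Result<(candle_core::CpuStorage, candle_core::Shape)> {
        let slice = storage.as_slice::<f32>()?;
        // Keep the sketch simple by requiring a contiguous input; a full
        // implementation would walk the strides from the layout instead.
        let data = match layout.contiguous_offsets() {
            Some((o1, o2)) => &slice[o1..o2],
            None => candle_core::bail!("clamp01 expects a contiguous input"),
        };
        let out: Vec<f32> = data.iter().map(|v| v.clamp(0.0, 1.0)).collect();
        Ok((candle_core::CpuStorage::F32(out), layout.shape().clone()))
    }
}
// A tensor would then invoke it with `tensor.apply_op1(Clamp01)`.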
- fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>; - - /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> { - Err(crate::Error::Cuda( - format!("no cuda implementation for {}", self.name()).into(), - )) - } - - /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn metal_fwd( - &self, - _storage: &MetalStorage, - _layout: &Layout, - ) -> Result<(MetalStorage, Shape)> { - Err(crate::Error::Metal( - format!("no metal implementation for {}", self.name()).into(), - )) - } - - /// This function takes as argument the argument `arg` used in the forward pass, the result - /// produced by the forward operation `res` and the gradient of the result `grad_res`. - /// The function should return the gradient of the argument. - fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result> { - Err(crate::Error::BackwardNotSupported { op: self.name() }) - } -} - -pub trait CustomOp2 { - fn name(&self) -> &'static str; - - /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn cpu_fwd( - &self, - s1: &CpuStorage, - l1: &Layout, - s2: &CpuStorage, - l2: &Layout, - ) -> Result<(CpuStorage, Shape)>; - - /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn cuda_fwd( - &self, - _: &CudaStorage, - _: &Layout, - _: &CudaStorage, - _: &Layout, - ) -> Result<(CudaStorage, Shape)> { - Err(crate::Error::Cuda( - format!("no cuda implementation for {}", self.name()).into(), - )) - } - - /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn metal_fwd( - &self, - _: &MetalStorage, - _: &Layout, - _: &MetalStorage, - _: &Layout, - ) -> Result<(MetalStorage, Shape)> { - Err(crate::Error::Metal( - format!("no metal implementation for {}", self.name()).into(), - )) - } - - fn bwd( - &self, - _arg1: &Tensor, - _arg2: &Tensor, - _res: &Tensor, - _grad_res: &Tensor, - ) -> Result<(Option, Option)> { - Err(crate::Error::BackwardNotSupported { op: self.name() }) - } -} - -pub trait CustomOp3 { - fn name(&self) -> &'static str; - - /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn cpu_fwd( - &self, - s1: &CpuStorage, - l1: &Layout, - s2: &CpuStorage, - l2: &Layout, - s3: &CpuStorage, - l3: &Layout, - ) -> Result<(CpuStorage, Shape)>; - - /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn cuda_fwd( - &self, - _: &CudaStorage, - _: &Layout, - _: &CudaStorage, - _: &Layout, - _: &CudaStorage, - _: &Layout, - ) -> Result<(CudaStorage, Shape)> { - Err(crate::Error::Cuda( - format!("no cuda implementation for {}", self.name()).into(), - )) - } - - /// The forward pass, as run on a metal gpu device. 
Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn metal_fwd( - &self, - _: &MetalStorage, - _: &Layout, - _: &MetalStorage, - _: &Layout, - _: &MetalStorage, - _: &Layout, - ) -> Result<(MetalStorage, Shape)> { - Err(crate::Error::Metal( - format!("no metal implementation for {}", self.name()).into(), - )) - } - - fn bwd( - &self, - _arg1: &Tensor, - _arg2: &Tensor, - _arg3: &Tensor, - _res: &Tensor, - _grad_res: &Tensor, - ) -> Result<(Option, Option, Option)> { - Err(crate::Error::BackwardNotSupported { op: self.name() }) - } -} - pub trait UnaryOpT { const NAME: &'static str; const KERNEL: &'static str; @@ -399,6 +255,7 @@ pub(crate) struct Tanh; pub(crate) struct Floor; pub(crate) struct Ceil; pub(crate) struct Round; +pub(crate) struct Sign; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { @@ -602,6 +459,13 @@ unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); +// Hardcode the value for sqrt(2/pi) +// https://github.com/huggingface/candle/issues/1982 +#[allow(clippy::excessive_precision)] +const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373; +#[allow(clippy::excessive_precision)] +const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373; + /// Tanh based approximation of the `gelu` operation /// GeluErf is the more precise one. /// @@ -614,7 +478,7 @@ impl UnaryOpT for Gelu { * v * (bf16::ONE + bf16::tanh( - (bf16::from_f32_const(2.0) / bf16::PI).sqrt() + bf16::from_f32_const(SQRT_TWO_OVER_PI_F32) * v * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), )) @@ -625,22 +489,18 @@ impl UnaryOpT for Gelu { * v * (f16::ONE + f16::tanh( - (f16::from_f32_const(2.0) / f16::PI).sqrt() + f16::from_f32_const(SQRT_TWO_OVER_PI_F32) * v * (f16::ONE + f16::from_f32_const(0.044715) * v * v), )) } #[inline(always)] fn f32(v: f32) -> f32 { - 0.5 * v - * (1.0 - + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) + 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } #[inline(always)] fn f64(v: f64) -> f64 { - 0.5 * v - * (1.0 - + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) + 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v))) } #[inline(always)] fn u8(_: u8) -> u8 { @@ -1067,3 +927,37 @@ impl std::ops::Deref for BackpropOp { &self.0 } } + +impl UnaryOpT for Sign { + const NAME: &'static str = "sign"; + const KERNEL: &'static str = "usign"; + const V: Self = Sign; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + f32::from(v > 0.) - f32::from(v < 0.) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + f64::from(v > 0.) - f64::from(v < 0.) 
+ } + #[inline(always)] + fn u8(v: u8) -> u8 { + u8::min(1, v) + } + #[inline(always)] + fn u32(v: u32) -> u32 { + u32::min(1, v) + } + #[inline(always)] + fn i64(v: i64) -> i64 { + (v > 0) as i64 - (v < 0) as i64 + } +} diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 5b684573..07f8c13e 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -1,22 +1,62 @@ use super::{GgmlDType, QStorage}; +use crate::quantized::k_quants::GgmlType; use crate::{backend::BackendDevice, cuda_backend::WrapErr}; use crate::{CudaDevice, CudaStorage, Result}; -use cudarc::driver::{CudaSlice, DeviceSlice}; +use cudarc::driver::{CudaSlice, CudaView, DeviceSlice}; +#[derive(Clone, Debug)] pub struct QCudaStorage { data: CudaSlice, dtype: GgmlDType, device: CudaDevice, } +static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); + +pub fn set_force_dmmv(f: bool) { + FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed) +} + pub const WARP_SIZE: usize = 32; pub const MMQ_X_Q4_0_AMPERE: usize = 4; pub const MMQ_Y_Q4_0_AMPERE: usize = 32; pub const NWARPS_Q4_0_AMPERE: usize = 4; pub const GGML_CUDA_MMV_X: usize = 32; pub const GGML_CUDA_MMV_Y: usize = 1; +pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256; pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256; +pub const MATRIX_ROW_PADDING: usize = 512; + +fn ceil_div(p: usize, q: usize) -> usize { + (p + q - 1) / q +} + +fn pad(p: usize, q: usize) -> usize { + ceil_div(p, q) * q +} + +fn quantize_q8_1( + src: &CudaView, + dst: &mut CudaSlice, + elem_count: usize, + dev: &CudaDevice, +) -> Result<()> { + use cudarc::driver::LaunchAsync; + + let kx = elem_count; + let kx_padded = pad(kx, MATRIX_ROW_PADDING); + let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE); + let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (num_blocks as u32, 1, 1), + block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1), + shared_mem_bytes: 0, + }; + let params = (src, dst, kx as i32, kx_padded as i32); + unsafe { func.launch(cfg, params) }.w()?; + Ok(()) +} fn dequantize( data: &CudaSlice, @@ -30,26 +70,18 @@ fn dequantize( let (kernel_name, is_k, block_dim, num_blocks) = match dtype { GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb), GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb), - GgmlDType::Q5_0 => { - let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) - / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE); - ( - "dequantize_block_q5_0", - false, - CUDA_DEQUANTIZE_BLOCK_SIZE, - nb, - ) - } - GgmlDType::Q5_1 => { - let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) - / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE); - ( - "dequantize_block_q5_1", - false, - CUDA_DEQUANTIZE_BLOCK_SIZE, - nb, - ) - } + GgmlDType::Q5_0 => ( + "dequantize_block_q5_0", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q5_1 => ( + "dequantize_block_q5_1", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb), GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb), GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb), @@ -60,7 +92,7 @@ fn dequantize( _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; - let dst = dev.alloc_zeros::(elem_count).w()?; + 
let dst = unsafe { dev.alloc::(elem_count).w()? }; // See e.g. // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 let cfg = cudarc::driver::LaunchConfig { @@ -83,9 +115,9 @@ fn dequantize( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } -fn dequantize_mut_mal_vec( +fn dequantize_mul_mat_vec( data: &CudaSlice, - y: &cudarc::driver::CudaView, + y: &CudaView, dtype: GgmlDType, ncols: usize, nrows: usize, @@ -93,6 +125,13 @@ fn dequantize_mut_mal_vec( ) -> Result { use cudarc::driver::LaunchAsync; + let data_elems = data.len() / dtype.type_size() * dtype.block_size(); + if data_elems < ncols * nrows { + crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) + } + if y.len() != ncols { + crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len()) + } let kernel_name = match dtype { GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda", GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda", @@ -107,8 +146,8 @@ fn dequantize_mut_mal_vec( _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; - let dst = dev.alloc_zeros::(nrows).w()?; - let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + let dst = unsafe { dev.alloc::(nrows).w()? }; + let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y); let cfg = cudarc::driver::LaunchConfig { grid_dim: (block_num_y as u32, 1, 1), block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1), @@ -120,9 +159,66 @@ fn dequantize_mut_mal_vec( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } +fn mul_mat_vec_via_q8_1( + data: &CudaSlice, + y: &CudaView, + dtype: GgmlDType, + ncols: usize, + nrows: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let data_elems = data.len() / dtype.type_size() * dtype.block_size(); + if data_elems < ncols * nrows { + crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) + } + if y.len() != ncols { + crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len()) + } + // Start by quantizing y + let ncols_padded = pad(ncols, MATRIX_ROW_PADDING); + let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? }; + quantize_q8_1(y, &mut y_q8_1, ncols, dev)?; + + let kernel_name = match dtype { + GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda", + GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda", + GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda", + GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda", + GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda", + GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda", + GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda", + GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda", + GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda", + GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda", + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(nrows).w()? 
}; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (nrows as u32, 1, 1), + block_dim: (WARP_SIZE as u32, 4, 1), + shared_mem_bytes: 0, + }; + + let params = ( + data, + &y_q8_1, + &dst, + /* ncols_x */ ncols as i32, + /* nrows_x */ nrows as i32, + /* nrows_y */ ncols as i32, + /* nrows_dst */ nrows as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + impl QCudaStorage { pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result { - let size_in_bytes = el_count * dtype.type_size() / dtype.block_size(); + let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size(); let data = device.alloc_zeros::(size_in_bytes).w()?; Ok(QCudaStorage { data, @@ -140,6 +236,12 @@ impl QCudaStorage { } pub fn dequantize(&self, elem_count: usize) -> Result { + fn deq(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> { + let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) }; + let vec = slice.to_vec(); + T::to_float(&vec, dst) + } + let fast_kernel = matches!( self.dtype, GgmlDType::Q4_0 @@ -158,69 +260,25 @@ impl QCudaStorage { return dequantize(&self.data, self.dtype, elem_count, self.device()); } // Run the dequantization on cpu. - use crate::quantized::k_quants::GgmlType; let buffer = self.device.dtoh_sync_copy(&self.data).w()?; let mut out = vec![0.0; elem_count]; let block_len = elem_count / self.dtype.block_size(); match self.dtype { - GgmlDType::F32 => { - let slice = - unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) }; - out.copy_from_slice(slice) - } - GgmlDType::F16 => { - let vec: Vec = read_to_vec(&buffer, block_len); - half::f16::to_float(&vec, &mut out)?; - } - GgmlDType::Q4_0 => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; - } - GgmlDType::Q4_1 => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; - } - GgmlDType::Q5_0 => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; - } - GgmlDType::Q5_1 => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; - } - GgmlDType::Q8_0 => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; - } - GgmlDType::Q8_1 => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; - } - GgmlDType::Q2K => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; - } - GgmlDType::Q3K => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; - } - GgmlDType::Q4K => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; - } - GgmlDType::Q5K => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; - } - GgmlDType::Q6K => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; - } - GgmlDType::Q8K => { - let vec: Vec = read_to_vec(&buffer, block_len); - crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; - } + GgmlDType::F32 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::F16 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q4_0 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q4_1 => 
deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q5_0 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q5_1 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q8_0 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q8_1 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q2K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q3K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q4K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q5K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q6K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q8K => deq::(&buffer, block_len, &mut out)?, } self.device @@ -285,8 +343,11 @@ impl QCudaStorage { crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape()) } - let out = - dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?; + let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { + dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? + } else { + mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? + }; let out_shape = if with_batch { vec![1, 1, nrows] } else { @@ -313,7 +374,7 @@ impl QCudaStorage { } let data_f32 = self.dequantize(n * k)?; - let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0); + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?; let mut out_shape = layout.shape().dims().to_vec(); out_shape.pop(); @@ -322,11 +383,6 @@ impl QCudaStorage { } } -fn read_to_vec(buffer: &[u8], n: usize) -> Vec { - let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) }; - slice.to_vec() -} - pub fn load_quantized( device: &CudaDevice, data: &[T], @@ -341,3 +397,60 @@ pub fn load_quantized( dtype: T::DTYPE, })) } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn cuda_quantize_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let el = 256; + let el_padded = pad(el, MATRIX_ROW_PADDING); + let y_size_in_bytes = + el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? }; + let vs: Vec = (0..el).map(|v| v as f32).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?; + Ok(()) + } + + #[test] + fn cuda_mmv_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let ncols = 256; + let vs: Vec = (0..ncols).map(|v| v as f32).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?; + xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?; + let cuda_storage = mul_mat_vec_via_q8_1( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* ncols */ ncols, + /* nrows */ 1, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + assert_eq!(vs.len(), 1); + // for n = 255, n.(n+1).(2n+1) / 6 = 5559680 + // Q8 means 1/256 precision. 
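+        // Hedged check on the expected magnitude: the exact integer sum 5559680 noted above and
+        // the observed 5561664.5 below differ by roughly 0.04%, which is consistent with the
+        // combined Q4_0 weight rounding and Q8_1 activation rounding (a rough estimate, not a bound).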
+ assert_eq!(vs[0], 5561664.5); + + let cuda_storage = dequantize_mul_mat_vec( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* ncols */ ncols, + /* nrows */ 1, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + assert_eq!(vs.len(), 1); + assert_eq!(vs[0], 5561851.0); + Ok(()) + } +} diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 7be0f74e..c310d766 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -149,8 +149,11 @@ impl QMetalStorage { let (n, k) = self_shape.dims2()?; let mut dst_shape = src_shape.dims().to_vec(); + // We always use a single batch dimension and stack all the tensors in the batch on the + // second dimension as the implementation in candle-metal-kernels doesn't handle batch + // properly. let (b, m) = match dst_shape.len() { - 3 => (dst_shape[0], dst_shape[1]), + 3 => (1, dst_shape[0] * dst_shape[1]), 2 => (1, dst_shape[0]), n => crate::bail!("Invalid rank {n} for quantized matmul metal"), }; diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index f7abcd93..47307f2e 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -398,7 +398,7 @@ impl QMatMul { _ => DEQUANTIZE_ALL.with(|b| *b), }; let t = if dequantize { - let tensor = qtensor.dequantize(&Device::Cpu)?; + let tensor = qtensor.dequantize(&qtensor.device())?; Self::Tensor(tensor) } else { Self::QTensor(qtensor) diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 32ebb23f..567a711b 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -171,7 +171,7 @@ impl Shape { } let mut acc = 1; for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() { - if stride != acc { + if dim > 1 && stride != acc { return false; } acc *= dim; @@ -186,7 +186,7 @@ impl Shape { } let mut acc = 1; for (&stride, &dim) in stride.iter().zip(self.0.iter()) { - if stride != acc { + if dim > 1 && stride != acc { return false; } acc *= dim; diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 65bcc6aa..8a0637e3 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,6 +1,7 @@ use crate::backend::BackendStorage; -use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp}; +use crate::op::{self, CmpOp, ReduceOp}; use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape}; +use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; // We do not want to implement Clone on Storage as cloning may fail because of // out of memory. Instead try_clone should be used. @@ -43,9 +44,19 @@ impl Storage { } pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { - let lhs = self.device().location(); - let rhs = rhs.device().location(); - if lhs != rhs { + let lhs_device = self.device(); + let rhs_device = rhs.device(); + let lhs = lhs_device.location(); + let rhs = rhs_device.location(); + let same_device = if self.device().is_metal() { + // On metal, we require the device to be exactly the same rather than + // having the same location. In cuda this is not necessary as all CudaDevice on the + // same GPU will use the same cuda stream. 
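+            // Assumption behind this stricter check: each `Device::new_metal` instance owns its own
+            // command queue and buffer cache, so storages created on two distinct instances should
+            // not be mixed even though their locations compare equal.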
+ lhs_device.same_device(&rhs_device) + } else { + lhs == rhs + }; + if !same_device { Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt()) } else { Ok(()) @@ -252,6 +263,51 @@ impl Storage { } } + pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> { + match self { + Self::Cpu(storage) => c.cpu_fwd(storage, l), + Self::Cuda(storage) => c.cuda_fwd(storage, l), + Self::Metal(storage) => c.metal_fwd(storage, l), + } + } + + pub(crate) fn inplace_op2( + &mut self, + l1: &Layout, + t2: &Self, + l2: &Layout, + c: &dyn InplaceOp2, + ) -> Result<()> { + self.same_device(t2, c.name())?; + match (self, t2) { + (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2), + (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2), + (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2), + _ => unreachable!(), + } + } + + pub(crate) fn inplace_op3( + &mut self, + l1: &Layout, + t2: &Self, + l2: &Layout, + t3: &Self, + l3: &Layout, + c: &dyn InplaceOp3, + ) -> Result<()> { + self.same_device(t2, c.name())?; + self.same_device(t3, c.name())?; + match (self, t2, t3) { + (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3), + (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3), + (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => { + c.metal_fwd(s1, l1, s2, l2, s3, l3) + } + _ => unreachable!(), + } + } + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { match self { Storage::Cpu(storage) => { @@ -701,4 +757,32 @@ impl Storage { .bt()), } } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + match (self, dst) { + (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::Cuda(src), Self::Cuda(dst)) => { + Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?) + } + (Self::Metal(src), Self::Metal(dst)) => { + Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "copy2d", + } + .bt()), + } + } } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 0e2c3e8f..a5a9dbb1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,9 +1,7 @@ //! Tensors are N-dimensional matrixes of elements using a single data type. #![allow(clippy::redundant_closure_call)] use crate::backend::{BackendDevice, BackendStorage}; -use crate::op::{ - BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp, -}; +use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp}; use crate::scalar::TensorOrScalar; use crate::shape::{Dim, Dims}; use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; @@ -512,6 +510,7 @@ impl Tensor { unary_op!(ceil, Ceil); unary_op!(floor, Floor); unary_op!(round, Round); + unary_op!(sign, Sign); /// Round element of the input tensor to the nearest integer. /// @@ -666,7 +665,7 @@ impl Tensor { Ok(from_storage(storage, self.shape(), op, false)) } - fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { + pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { if dim >= self.dims().len() { Err(Error::DimOutOfRange { shape: self.shape().clone(), @@ -1351,7 +1350,7 @@ impl Tensor { } .bt())? 
} - let mut storage = self.device().zeros(self.shape(), self.dtype())?; + let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; self.storage() .copy_strided_src(&mut storage, 0, self.layout())?; let offset = start * src.dims()[1..].iter().product::(); @@ -2001,7 +2000,7 @@ impl Tensor { Ok(self.clone()) } else { let shape = self.shape(); - let mut storage = self.device().zeros(shape, self.dtype())?; + let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; self.storage() .copy_strided_src(&mut storage, 0, self.layout())?; let op = BackpropOp::new1(self, Op::Copy); @@ -2009,11 +2008,21 @@ impl Tensor { } } + /// Returns a tensor that is in row major order. This always makes a copy. + pub fn force_contiguous(&self) -> Result { + let shape = self.shape(); + let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let op = BackpropOp::new1(self, Op::Copy); + Ok(from_storage(storage, shape.clone(), op, false)) + } + /// Create a variable based on the values currently stored in a tensor. The storage is always /// copied. pub(crate) fn make_var(&self) -> Result { let shape = self.shape().clone(); - let mut storage = self.device().zeros(&shape, self.dtype())?; + let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? }; self.storage() .copy_strided_src(&mut storage, 0, self.layout())?; Ok(from_storage(storage, shape, BackpropOp::none(), true)) @@ -2066,7 +2075,7 @@ impl Tensor { }; Ok(Tensor(Arc::new(tensor_))) } else { - let mut storage = self.device().zeros(&shape, self.dtype())?; + let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? }; self.storage() .copy_strided_src(&mut storage, 0, self.layout())?; Ok(from_storage(storage, shape, op, false)) @@ -2093,8 +2102,19 @@ impl Tensor { let dim = dim.to_index(self.shape(), "squeeze")?; if dims[dim] == 1 { let mut dims = dims.to_vec(); + let mut strides = self.stride().to_vec(); dims.remove(dim); - self.reshape(dims) + strides.remove(dim); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(dims.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) } else { Ok(self.clone()) } @@ -2115,10 +2135,24 @@ impl Tensor { /// ``` pub fn unsqueeze(&self, dim: D) -> Result { let mut dims = self.dims().to_vec(); + let mut strides = self.stride().to_vec(); let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?; // Cannot panic because to_index_plus_one already checks dimensions dims.insert(dim, 1); - self.reshape(dims) + // Any stride would work here, but we pick one so as to maximize the probability to remain + // C contiguous. + let stride = if dim < strides.len() { strides[dim] } else { 1 }; + strides.insert(dim, stride); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(dims.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) } /// Stacks two or more tensors along a particular dimension. @@ -2149,152 +2183,6 @@ impl Tensor { Self::cat(&args, dim) } - /// Concatenates two or more tensors along a particular dimension. 
- /// - /// All tensors must of the same rank, and the output will have - /// the same rank - /// - /// ```rust - /// # use candle_core::{Tensor, DType, Device}; - /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; - /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; - /// - /// let c = Tensor::cat(&[&a, &b], 0)?; - /// assert_eq!(c.shape().dims(), &[4, 3]); - /// - /// let c = Tensor::cat(&[&a, &b], 1)?; - /// assert_eq!(c.shape().dims(), &[2, 6]); - /// # Ok::<(), candle_core::Error>(()) - /// ``` - pub fn cat, D: Dim>(args: &[A], dim: D) -> Result { - if args.is_empty() { - Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? - } - let arg0 = args[0].as_ref(); - if args.len() == 1 { - return Ok(arg0.clone()); - } - let dim = dim.to_index(arg0.shape(), "cat")?; - for arg in args { - arg.as_ref().check_dim(dim, "cat")?; - } - for (arg_idx, arg) in args.iter().enumerate() { - let arg = arg.as_ref(); - if arg0.rank() != arg.rank() { - Err(Error::UnexpectedNumberOfDims { - expected: arg0.rank(), - got: arg.rank(), - shape: arg.shape().clone(), - } - .bt())? - } - for (dim_idx, (v1, v2)) in arg0 - .shape() - .dims() - .iter() - .zip(arg.shape().dims().iter()) - .enumerate() - { - if dim_idx != dim && v1 != v2 { - Err(Error::ShapeMismatchCat { - dim: dim_idx, - first_shape: arg0.shape().clone(), - n: arg_idx + 1, - nth_shape: arg.shape().clone(), - } - .bt())? - } - } - } - if dim == 0 { - Self::cat0(args) - } else { - // TODO: Avoid these transpositions and have an implementation that works - // for dim != 0... - let args: Vec = args - .iter() - .map(|a| a.as_ref().transpose(0, dim)) - .collect::>>()?; - let cat = Self::cat0(&args)?; - cat.transpose(0, dim) - } - } - - fn cat0>(args: &[A]) -> Result { - if args.is_empty() { - Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? - } - let arg0 = args[0].as_ref(); - if args.len() == 1 { - return Ok(arg0.clone()); - } - let rank = arg0.rank(); - let device = arg0.device(); - let dtype = arg0.dtype(); - let first_dims = arg0.shape().dims(); - let mut cat_dims = first_dims.to_vec(); - cat_dims[0] = 0; - let mut offsets = vec![0usize]; - for (arg_idx, arg) in args.iter().enumerate() { - let arg = arg.as_ref(); - if arg.dtype() != dtype { - Err(Error::DTypeMismatchBinaryOp { - lhs: dtype, - rhs: arg.dtype(), - op: "cat", - } - .bt())? - } - if arg.device().location() != device.location() { - Err(Error::DeviceMismatchBinaryOp { - lhs: device.location(), - rhs: arg.device().location(), - op: "cat", - } - .bt())? - } - if rank != arg.rank() { - Err(Error::UnexpectedNumberOfDims { - expected: rank, - got: arg.rank(), - shape: arg.shape().clone(), - } - .bt())? - } - for (dim_idx, (v1, v2)) in arg0 - .shape() - .dims() - .iter() - .zip(arg.shape().dims().iter()) - .enumerate() - { - if dim_idx == 0 { - cat_dims[0] += v2; - } - if dim_idx != 0 && v1 != v2 { - Err(Error::ShapeMismatchCat { - dim: dim_idx, - first_shape: arg0.shape().clone(), - n: arg_idx + 1, - nth_shape: arg.shape().clone(), - } - .bt())? - } - } - let next_offset = offsets.last().unwrap() + arg.elem_count(); - offsets.push(next_offset); - } - let shape = Shape::from(cat_dims); - let op = BackpropOp::new(args, |args| Op::Cat(args, 0)); - let mut storage = device.zeros(&shape, dtype)?; - for (arg, &offset) in args.iter().zip(offsets.iter()) { - let arg = arg.as_ref(); - arg.storage() - .copy_strided_src(&mut storage, offset, arg.layout())?; - } - Ok(from_storage(storage, shape, op, false)) - } - /// Pad the input tensor using 0s along dimension `dim`. 
This adds `left` elements before the /// input tensor values and `right` elements after. pub fn pad_with_zeros(&self, dim: D, left: usize, right: usize) -> Result { @@ -2377,6 +2265,10 @@ impl Tensor { self.storage.read().unwrap() } + pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> { + self.storage.write().unwrap() + } + // If we extend the visibility of this function to be usable outside of this crate, we should // make it unsafe. pub(crate) fn storage_mut_and_layout( @@ -2398,96 +2290,6 @@ impl Tensor { std::ptr::eq(lhs, rhs) } - /// Applies a unary custom op without backward support - pub fn apply_op1_no_bwd(&self, c: &C) -> Result { - let (storage, shape) = self.storage().apply_op1(self.layout(), c)?; - Ok(from_storage(storage, shape, BackpropOp::none(), false)) - } - - /// Applies a binary custom op without backward support - pub fn apply_op2_no_bwd(&self, rhs: &Self, c: &C) -> Result { - let (storage, shape) = - self.storage() - .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?; - Ok(from_storage(storage, shape, BackpropOp::none(), false)) - } - - /// Applies a ternary custom op without backward support - pub fn apply_op3_no_bwd(&self, t2: &Self, t3: &Self, c: &C) -> Result { - let (storage, shape) = self.storage().apply_op3( - self.layout(), - &t2.storage(), - t2.layout(), - &t3.storage(), - t3.layout(), - c, - )?; - Ok(from_storage(storage, shape, BackpropOp::none(), false)) - } - - /// Applies a unary custom op. - pub fn apply_op1_arc(&self, c: Arc>) -> Result { - let (storage, shape) = self - .storage() - .apply_op1(self.layout(), c.as_ref().as_ref())?; - let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone())); - Ok(from_storage(storage, shape, op, false)) - } - - pub fn apply_op1(&self, c: C) -> Result { - self.apply_op1_arc(Arc::new(Box::new(c))) - } - - /// Applies a binary custom op. - pub fn apply_op2_arc( - &self, - rhs: &Self, - c: Arc>, - ) -> Result { - let (storage, shape) = self.storage().apply_op2( - self.layout(), - &rhs.storage(), - rhs.layout(), - c.as_ref().as_ref(), - )?; - let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone())); - Ok(from_storage(storage, shape, op, false)) - } - - pub fn apply_op2(&self, r: &Self, c: C) -> Result { - self.apply_op2_arc(r, Arc::new(Box::new(c))) - } - - /// Applies a ternary custom op. - pub fn apply_op3_arc( - &self, - t2: &Self, - t3: &Self, - c: Arc>, - ) -> Result { - let (storage, shape) = self.storage().apply_op3( - self.layout(), - &t2.storage(), - t2.layout(), - &t3.storage(), - t3.layout(), - c.as_ref().as_ref(), - )?; - let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| { - Op::CustomOp3(t1, t2, t3, c.clone()) - }); - Ok(from_storage(storage, shape, op, false)) - } - - pub fn apply_op3( - &self, - t2: &Self, - t3: &Self, - c: C, - ) -> Result { - self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) - } - /// Normalize a 'relative' axis value: positive values are kept, negative /// values means counting the dimensions from the back. pub fn normalize_axis(&self, axis: i64) -> Result { diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs new file mode 100644 index 00000000..27ff7851 --- /dev/null +++ b/candle-core/src/tensor_cat.rs @@ -0,0 +1,238 @@ +use crate::{shape::Dim, Error, Result, Shape, Tensor}; + +impl Tensor { + /// Concatenates two or more tensors along a particular dimension. 
+ /// + /// All tensors must of the same rank, and the output will have + /// the same rank + /// + /// ```rust + /// # use candle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::cat(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[4, 3]); + /// + /// let c = Tensor::cat(&[&a, &b], 1)?; + /// assert_eq!(c.shape().dims(), &[2, 6]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn cat, D: Dim>(args: &[A], dim: D) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let dim = dim.to_index(arg0.shape(), "cat")?; + for arg in args { + arg.as_ref().check_dim(dim, "cat")?; + } + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg0.rank() != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: arg0.rank(), + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + } + let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous()); + if all_contiguous { + Self::cat_contiguous(args, dim) + } else if dim == 0 { + Self::cat0(args) + } else { + let args: Vec = args + .iter() + .map(|a| a.as_ref().transpose(0, dim)) + .collect::>>()?; + let cat = Self::cat0(&args)?; + cat.transpose(0, dim) + } + } + + fn cat0>(args: &[A]) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let rank = arg0.rank(); + let device = arg0.device(); + let dtype = arg0.dtype(); + let first_dims = arg0.shape().dims(); + let mut cat_dims = first_dims.to_vec(); + cat_dims[0] = 0; + let mut offsets = vec![0usize]; + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg.dtype() != dtype { + Err(Error::DTypeMismatchBinaryOp { + lhs: dtype, + rhs: arg.dtype(), + op: "cat", + } + .bt())? + } + if arg.device().location() != device.location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: device.location(), + rhs: arg.device().location(), + op: "cat", + } + .bt())? + } + if rank != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: rank, + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx == 0 { + cat_dims[0] += v2; + } + if dim_idx != 0 && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + let next_offset = offsets.last().unwrap() + arg.elem_count(); + offsets.push(next_offset); + } + let shape = Shape::from(cat_dims); + let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0)); + let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? 
}; + for (arg, &offset) in args.iter().zip(offsets.iter()) { + let arg = arg.as_ref(); + arg.storage() + .copy_strided_src(&mut storage, offset, arg.layout())?; + } + Ok(crate::tensor::from_storage(storage, shape, op, false)) + } + + fn cat_contiguous>(args: &[A], dim: usize) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let rank = arg0.rank(); + let device = arg0.device(); + let dtype = arg0.dtype(); + let first_dims = arg0.shape().dims(); + let mut cat_dims = first_dims.to_vec(); + cat_dims[dim] = 0; + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg.dtype() != dtype { + Err(Error::DTypeMismatchBinaryOp { + lhs: dtype, + rhs: arg.dtype(), + op: "cat", + } + .bt())? + } + if arg.device().location() != device.location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: device.location(), + rhs: arg.device().location(), + op: "cat", + } + .bt())? + } + if rank != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: rank, + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx == dim { + cat_dims[dim] += v2; + } + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + } + let cat_target_dim_len = cat_dims[dim]; + let block_size: usize = cat_dims.iter().skip(1 + dim).product(); + let shape = Shape::from(cat_dims); + let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim)); + let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? 
}; + let mut dst_o = 0; + for arg in args.iter() { + let arg = arg.as_ref(); + let arg_dims = arg.shape().dims(); + let d1: usize = arg_dims.iter().take(dim).product(); + let d2 = block_size * arg_dims[dim]; + let dst_s = block_size * cat_target_dim_len; + let src_o = arg.layout().start_offset(); + arg.storage().copy2d( + &mut storage, + d1, + d2, + /* src_s */ d2, + dst_s, + src_o, + dst_o, + )?; + dst_o += d2; + } + Ok(crate::tensor::from_storage(storage, shape, op, false)) + } +} diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index b967515d..3762e02f 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -53,26 +53,31 @@ fn conv1d(dev: &Device) -> Result<()> { test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] ); - let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?; - assert_eq!(res.dims(), [1, 2, 7]); - assert_eq!( - test_utils::to_vec1_round(&res.flatten_all()?, 4)?, - [ - 0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538, - 4.7076, -5.9745, -0.8276, 1.621 - ], - ); - let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?; - assert_eq!(res.dims(), [1, 4, 7]); - assert_eq!( - test_utils::to_vec2_round(&res.squeeze(0)?, 4)?, - [ - [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819], - [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721], - [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113], - [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488] - ] - ); + + let w = w.transpose(0, 1)?; + // The CPU kernels applied in the contiguous and non contiguous cases are different. + for w in [w.clone(), w.contiguous()?] { + let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 7]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [ + 0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538, + 4.7076, -5.9745, -0.8276, 1.621 + ], + ); + let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?; + assert_eq!(res.dims(), [1, 4, 7]); + assert_eq!( + test_utils::to_vec2_round(&res.squeeze(0)?, 4)?, + [ + [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819], + [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721], + [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113], + [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488] + ] + ); + } Ok(()) } @@ -130,7 +135,7 @@ fn conv2d(dev: &Device) -> Result<()> { 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, - -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, + -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, ], dev, )?; @@ -158,7 +163,9 @@ fn conv2d(dev: &Device) -> Result<()> { 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 ] ); + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 7, 7]); assert_eq!( test_utils::to_vec3_round(&res.i(0)?, 4)?, @@ -183,6 +190,7 @@ fn conv2d(dev: &Device) -> Result<()> { ] ] ); + // Dilations. 
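+    // With the 3x3 kernel used earlier in this test, padding 0, stride 1 and dilation 2, the
+    // usual output-size formula (in + 2*pad - dilation*(k - 1) - 1) / stride + 1 gives
+    // (5 - 4 - 1) + 1 = 1, hence the [1, 2, 1, 1] shape asserted below.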
let res = t.conv2d(&w, 0, 1, 2, 1)?; assert_eq!(res.dims(), [1, 2, 1, 1]); @@ -221,6 +229,7 @@ fn conv2d(dev: &Device) -> Result<()> { ] ] ); + Ok(()) } @@ -267,13 +276,13 @@ fn conv2d_small(dev: &Device) -> Result<()> { assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [ - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000, - 0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855, - -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640, + -0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0, + 3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0 ] ); + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!( @@ -375,6 +384,7 @@ print(w.grad.shape) print(w.grad[0]) */ fn conv2d_grad(dev: &Device) -> Result<()> { + // conv-transposes are not implemented for metal use candle_core::Var; let t = Var::from_slice( &[ @@ -387,7 +397,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> { 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, - -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, + -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, ], (1, 4, 5, 5), dev, @@ -572,6 +582,154 @@ fn conv2d_grad(dev: &Device) -> Result<()> { ] ); + // Conv Transpose 2d Test + //tested against following python + + // import torch + // torch.manual_seed(4242) + // padding = 4 + // outpadding = 2 + // dilation = 3 + // stride = 3 + // input = torch.randn((1, 4, 7, 5), requires_grad=True) + // kernel = torch.randn((4, 2, 3, 5), requires_grad=True) + // print("input", input.flatten()) + // print("kernel", kernel.flatten()) + // res = torch.nn.functional.conv_transpose2d( + // input, + // kernel, + // stride=stride, + // padding=padding, + // dilation=dilation, + // output_padding=outpadding, + // ) + // res.retain_grad() + // print(res.shape) + // loss = (res**2).sum() + // print(loss) + // loss.backward() + // print(input.grad.shape) + // print("input grad", torch.round(input.grad, decimals=1)) + // print(kernel.grad.shape) + // print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1)) + + let padding = 4; + let outpadding = 2; + let dilation = 3; + let stride = 3; + + let t = Var::from_slice( + &[ + 0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, + 3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, + 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, + -0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, + 1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, + 1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, + 0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, + -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, 
-0.3742, + 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912, + -0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465, + -0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264, + 1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451, + -0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258, + -2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186, + 1.6475, 0.2219, + ], + (1, 4, 7, 5), + dev, + )?; + + #[rustfmt::skip] + let w = Var::from_slice( + &[ + -1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234, + -0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762, + 0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204, + 0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555, + 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990, + 0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181, + 0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481, + 0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509, + 0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732, + -0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071, + -1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604, + 0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478, + ], + (4, 2, 3, 5), + dev, + )?; + let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0); + let grads = loss.backward()?; + + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 7, 5]); + assert_eq!(grad_w.dims(), [4, 2, 3, 5]); + + assert_eq!( + test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?, + [ + // torch gets 89.1 + -89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0, + -15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9, + 52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2, + 106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6, + -27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5, + -10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0, + -52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9, + -20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5, + 92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5, + -28.4, 85.0, -18.3, 107.0, 28.3, -71.8 + ] + ); + + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 1)?, + [ + [ + [32.3, -41.6, -24.0, 14.1, 17.6], + [-11.8, 72.5, 87.6, 46.4, 61.5], + [115.0, 108.5, -48.6, -63.4, -50.0], + [51.3, 5.4, 31.3, 91.1, -30.9], + [52.7, 92.8, -68.0, -47.0, 83.0], + // pytorch gets -107.1 + [-10.2, -107.0, -5.4, 213.1, -31.4], + [-2.4, 65.1, 9.2, -146.2, -24.2] + ], + [ + [-72.6, -63.9, -61.9, 45.3, 33.0], + [79.3, -0.5, -26.2, 78.2, 42.7], + [90.9, 141.6, 40.1, -62.7, 37.0], + [32.8, 198.2, -0.8, -31.1, 27.3], + // torch gets 48.0 + [34.5, 34.9, -47.9, 127.6, 
-12.3], + [-61.4, -3.2, -2.9, -10.9, -16.6], + [74.6, 60.1, -68.9, 34.5, -50.4] + ], + [ + [37.5, -56.9, -43.6, -13.5, -9.9], + [40.0, 97.3, 28.6, 14.2, -30.1], + [-22.3, -126.3, -68.8, -8.2, 26.1], + [-32.9, 37.3, 108.5, -54.8, 29.6], + [34.9, -176.9, -125.0, -28.3, -13.9], + [-54.9, 142.6, 62.1, -80.4, -65.6], + [7.4, -91.1, -67.6, 35.0, 39.7] + ], + [ + [-57.2, -40.9, -10.1, 32.6, 29.4], + [18.7, -18.0, 29.5, -1.2, 59.2], + [-14.0, -74.4, 19.8, -117.0, 58.2], + [-21.8, 163.5, -71.1, -99.0, 80.9], + [-58.9, -10.9, 93.8, -139.6, 98.0], + // torch gets 54.5 + [-54.4, 135.3, 6.0, -79.1, 134.6], + [27.5, -76.0, 43.4, -2.8, -7.8] + ] + ] + ); Ok(()) } diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index cff0aebe..be59e0c0 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -112,3 +112,34 @@ fn custom_op1_with_backward() -> Result<()> { Ok(()) } + +impl candle_core::InplaceOp1 for Elu { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> { + let alpha = self.alpha; + match s { + CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + _ => candle_core::bail!("unsupported dtype for inplace elu"), + } + Ok(()) + } +} + +#[test] +fn inplace_op1() -> Result<()> { + let cpu = &Device::Cpu; + let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?; + let t = (t - 5.)?; + t.inplace_op1(&Elu { alpha: 1. })?; + assert_eq!( + to_vec1_round(&t, 4)?, + &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ); + Ok(()) +} diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index a4d81618..b8b6be8d 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -1,3 +1,4 @@ +#![allow(clippy::approx_constant)] use anyhow::{Context, Result}; use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var}; @@ -96,24 +97,24 @@ fn unary_grad(device: &Device) -> Result<()> { let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::()?, - [20.085537, 2.7182817, 54.59815, 1.1618342] + test_utils::to_vec1_round(&y, 4)?, + [20.0855, 2.7183, 54.5982, 1.1618] ); assert_eq!( - grad_x.to_vec1::()?, - [20.085537, 2.7182817, 54.59815, 1.1618342] + test_utils::to_vec1_round(grad_x, 4)?, + [20.0855, 2.7183, 54.5982, 1.1618] ); let y = x.exp()?.sqr()?; let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::()?, - [403.4288, 7.3890557, 2980.9578, 1.3498588] + test_utils::to_vec1_round(&y, 3)?, + [403.429, 7.389, 2980.958, 1.35] ); // exp(x)^2 = exp(2*x) assert_eq!( - grad_x.to_vec1::()?, - [806.8576, 14.778111, 5961.9155, 2.6997175] + test_utils::to_vec1_round(grad_x, 2)?, + [806.86, 14.78, 5961.92, 2.7] ); let y = x.sin()?; let grads = y.backward()?; @@ -261,6 +262,7 @@ fn unary_grad(device: &Device) -> Result<()> { let y = elu_x.elu(2.)?; let grads = y.backward()?; let grad_x = grads.get(&elu_x).context("no grad for x")?; + assert_eq!( test_utils::to_vec1_round(&y, 4)?, [-1.2642, 0.0000, -1.7293, 3.0000] diff --git a/candle-core/tests/layout_tests.rs b/candle-core/tests/layout_tests.rs index e0618850..bc67f7de 100644 --- a/candle-core/tests/layout_tests.rs +++ 
b/candle-core/tests/layout_tests.rs @@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> { } }; let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; - let tensor = tensor.i((.., 1))?; + let tensor = tensor.i((.., 1))?.contiguous()?; match tensor.strided_blocks() { candle::StridedBlocks::SingleBlock { start_offset, len } => { assert_eq!(start_offset, 0); @@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> { } }; let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + let tensor = tensor.i((.., 1))?; + match tensor.strided_blocks() { + candle::StridedBlocks::SingleBlock { .. } => { + panic!("unexpected block structure") + } + candle::StridedBlocks::MultipleBlocks { + block_len, + block_start_index, + } => { + assert_eq!(block_len, 4); + assert_eq!(block_start_index.collect::>(), &[4, 16]) + } + }; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; match tensor.t()?.strided_blocks() { candle::StridedBlocks::SingleBlock { .. } => { panic!("unexpected block structure") diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs new file mode 100644 index 00000000..e3e18107 --- /dev/null +++ b/candle-core/tests/matmul_tests.rs @@ -0,0 +1,106 @@ +use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor}; + +fn matmul(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + + let data = vec![1.0f32, 2.0]; + let a = Tensor::from_slice(&data, (2, 1), device)?; + let data = vec![3.0f32, 4.0]; + let b = Tensor::from_slice(&data, (1, 2), device)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); + + let data: Vec<_> = (0..6).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 3), device)?; + let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (3, 2), device)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); + + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 2, 3), device)?; + let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (2, 3, 2), device)?; + let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]]; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec3::()?, &expected); + + // Also perform the matmul on contiguous transposed versions. 
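+    // `t()?.contiguous()?.t()?` keeps the original dims but stores the data with the last two
+    // axes swapped, so the tensors built below are valid yet non-contiguous, as the stride
+    // assertions verify.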
+ let a_tt = a.t()?.contiguous()?.t()?; + assert!(!a_tt.is_contiguous()); + assert_eq!(a.dims(), a_tt.dims()); + assert_eq!(a_tt.stride(), &[6, 1, 2]); + + let b_tt = b.t()?.contiguous()?.t()?; + assert!(!b_tt.is_contiguous()); + assert_eq!(b.dims(), b_tt.dims()); + assert_eq!(b_tt.stride(), &[6, 1, 3]); + + assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected); + assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected); + assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::()?, &expected); + Ok(()) +} + +fn broadcast_matmul(device: &Device) -> Result<()> { + let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?; + let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?; + let out = lhs.broadcast_matmul(&rhs)?; + assert_eq!(out.dims(), &[3, 6, 4, 2]); + for idx1 in 0..3 { + for idx2 in 0..6 { + let out = out.i((idx1, idx2))?; + let lhs = lhs.i((idx1, 0))?; + let rhs = rhs.i(idx2)?; + let out2 = lhs.matmul(&rhs); + let sum_diff2 = (out - out2)?.sqr()?.sum_all()?; + // With cuda, we see errors of up to ~1e-12. + assert!(sum_diff2.to_vec0::()? < 1e-6) + } + } + Ok(()) +} + +// https://github.com/huggingface/candle/issues/1948 +fn squeeze_mm(device: &Device) -> Result<()> { + let seq_len = 8_usize; + let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?; + let x = a.i((.., seq_len - 1, ..))?; + let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?; + let x = x.matmul(&w)?; + assert_eq!(x.dims(), &[1, 32]); + Ok(()) +} + +// https://github.com/huggingface/candle/issues/1992 +fn mm_layout(device: &Device) -> Result<()> { + let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?; + let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?; + let mm1 = a.matmul(&b)?; + // Forces the layout to be: + // shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0 + // This is still a contiguous matrix but matmul checks are only the two last dimensions have + // non 1 sizes but matmul check may be reluctant to handle it. + let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?; + let mm2 = a.matmul(&b)?; + let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0.); + Ok(()) +} + +test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); +test_device!( + broadcast_matmul, + broadcast_matmul_cpu, + broadcast_matmul_gpu, + broadcast_matmul_metal +); +test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal); +test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal); diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index a3708ec4..1edb7d35 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -43,6 +43,9 @@ res = torch.nn.functional.avg_pool2d(t, 2) print(res) */ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { + if dev.is_metal() { + return Ok(()); + } let t = Tensor::new( &[ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 40737e7b..78841779 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -106,6 +106,9 @@ fn unary_op(device: &Device) -> Result<()> { [2.6911, -0.0647, -0.1091, 1.7353, 2.7933] ] ); + let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?; + let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?; + assert!(max_diff.to_vec0::()? 
< 5e-3); assert_eq!( test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?, [ @@ -148,6 +151,14 @@ fn unary_op(device: &Device) -> Result<()> { test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?, [3000.0, 300.] ); + let tensor = Tensor::new( + &[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1], + device, + )?; + assert_eq!( + tensor.sign()?.to_vec1::()?, + [-1., -1., -1., 0., 0., 1., 1., 1., 1.] + ); Ok(()) } @@ -672,6 +683,31 @@ fn cat(device: &Device) -> Result<()> { [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0] ] ); + + // 3D + let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?; + let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?; + let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?; + + let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?; + + let t1 = t1.t()?.contiguous()?.t()?; + let t2 = t2.t()?.contiguous()?.t()?; + let t3 = t3.t()?.contiguous()?.t()?; + let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?; + + let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?; + assert_eq!(diff.to_vec0::()?, 104.0); + assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::()?, 0); + assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::()?, 16); + assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::()?, 20); + assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::()?, 44); + assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::()?, 100); + assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::()?, 112); + assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::()?, 101); + assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::()?, 105); + assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::()?, 10013); + assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::()?, 10031); Ok(()) } @@ -682,6 +718,8 @@ fn embeddings(device: &Device) -> Result<()> { assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); let hs = t.index_select(&ids, 0)?; assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); Ok(()) } @@ -709,44 +747,47 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - let hs = t.index_select(&ids, 1)?; - assert_eq!( - hs.to_vec2::()?, - &[ - [0.0, 2.0, 1.0], - [3.0, 5.0, 4.0], - [6.0, 8.0, 7.0], - [9.0, 11.0, 10.0] - ] - ); - let hs = t.index_select(&ids, 0)?; - assert_eq!( - hs.to_vec2::()?, - &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] - ); - // Prior to https://github.com/huggingface/candle/pull/1022 - // There would be a bug where the last values in the result tensor would be set to 0. - let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; - let hs = t.index_select(&ids, 0)?; - assert_eq!( - hs.to_vec2::()?, - &[ - [0.0, 1.0, 2.0], - [6.0, 7.0, 8.0], - [3.0, 4.0, 5.0], - [0.0, 1.0, 2.0], - [6.0, 7.0, 8.0], - [3.0, 4.0, 5.0], - ] - ); + for dtype in [DType::U8, DType::U32, DType::I64] { + let ids = ids.to_dtype(dtype)?; + let hs = t.index_select(&ids, 1)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [0.0, 2.0, 1.0], + [3.0, 5.0, 4.0], + [6.0, 8.0, 7.0], + [9.0, 11.0, 10.0] + ] + ); + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] + ); + // Prior to https://github.com/huggingface/candle/pull/1022 + // There would be a bug where the last values in the result tensor would be set to 0. 
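+        // The repeated-index case below selects six rows out of a four-row tensor, so the output
+        // is longer than the input; that is exactly the pattern where the trailing rows used to
+        // come back as zeros.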
+ let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + ] + ); - // Test when selecting dim > 0 with ids size different from elem count of - // target dim in source/input. - let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; - let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; - assert_eq!(t.to_vec2::()?, &[[1.0, 2.0], [3.0, 4.0]]); - let hs = t.index_select(&ids, 1)?; - assert_eq!(hs.to_vec2::()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + // Test when selecting dim > 0 with ids size different from elem count of + // target dim in source/input. + let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; + let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; + assert_eq!(t.to_vec2::()?, &[[1.0, 2.0], [3.0, 4.0]]); + let hs = t.index_select(&ids, 1)?; + assert_eq!(hs.to_vec2::()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + } Ok(()) } @@ -908,74 +949,6 @@ fn gather(device: &Device) -> Result<()> { Ok(()) } -fn matmul(device: &Device) -> Result<()> { - let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let a = Tensor::from_slice(&data, (2, 2), device)?; - let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let b = Tensor::from_slice(&data, (2, 2), device)?; - - let c = a.matmul(&b)?; - assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); - - let data = vec![1.0f32, 2.0]; - let a = Tensor::from_slice(&data, (2, 1), device)?; - let data = vec![3.0f32, 4.0]; - let b = Tensor::from_slice(&data, (1, 2), device)?; - let c = a.matmul(&b)?; - assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); - - let data: Vec<_> = (0..6).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 3), device)?; - let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (3, 2), device)?; - let c = a.matmul(&b)?; - assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); - - let data: Vec<_> = (0..12).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 2, 3), device)?; - let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (2, 3, 2), device)?; - let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]]; - - let c = a.matmul(&b)?; - assert_eq!(c.to_vec3::()?, &expected); - - // Also perform the matmul on contiguous transposed versions. - let a_tt = a.t()?.contiguous()?.t()?; - assert!(!a_tt.is_contiguous()); - assert_eq!(a.dims(), a_tt.dims()); - assert_eq!(a_tt.stride(), &[6, 1, 2]); - - let b_tt = b.t()?.contiguous()?.t()?; - assert!(!b_tt.is_contiguous()); - assert_eq!(b.dims(), b_tt.dims()); - assert_eq!(b_tt.stride(), &[6, 1, 3]); - - assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected); - assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected); - assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::()?, &expected); - Ok(()) -} - -fn broadcast_matmul(device: &Device) -> Result<()> { - let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?; - let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?; - let out = lhs.broadcast_matmul(&rhs)?; - assert_eq!(out.dims(), &[3, 6, 4, 2]); - for idx1 in 0..3 { - for idx2 in 0..6 { - let out = out.i((idx1, idx2))?; - let lhs = lhs.i((idx1, 0))?; - let rhs = rhs.i(idx2)?; - let out2 = lhs.matmul(&rhs); - let sum_diff2 = (out - out2)?.sqr()?.sum_all()?; - // With cuda, we see errors of up to ~1e-12. - assert!(sum_diff2.to_vec0::()? 
< 1e-6) - } - } - Ok(()) -} - fn broadcasting(device: &Device) -> Result<()> { let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?; let t2 = Tensor::new(&[100f32, 200f32], device)?; @@ -1080,8 +1053,33 @@ fn broadcasting(device: &Device) -> Result<()> { fn randn(device: &Device) -> Result<()> { let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; assert_eq!(tensor.dims(), [5, 3]); + // Check that the seed gets updated by checking that + // a new series of numbers is generated each time + let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?; + assert_ne!(tensor.to_vec2::()?, tensor2.to_vec2::()?); let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; assert_eq!(tensor.dims(), [5, 3]); + // Check that the seed gets updated by checking that + // a new series of numbers is generated each time + let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?; + assert_ne!(tensor.to_vec2::()?, tensor2.to_vec2::()?); + // We do not expect deterministic elements at any index. + // There once was a bug that had a deterministic zero element in evenly sized tensors. + const N: usize = 2; + let v = (0..100) + .map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::())) + .collect::>>()?; + assert!( + (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])), + "There are deterministic values in the randn tensors" + ); + let v = (0..100) + .map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::())) + .collect::>>()?; + assert!( + (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])), + "There are deterministic values in the rand tensors" + ); Ok(()) } @@ -1104,13 +1102,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal); test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal); test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal); test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal); -test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); -test_device!( - broadcast_matmul, - broadcast_matmul_cpu, - broadcast_matmul_gpu, - broadcast_matmul_metal -); test_device!( broadcasting, broadcasting_cpu, @@ -1267,8 +1258,8 @@ fn pow() -> Result<()> { let rhs = (&lhs - 2.)?; let res = lhs.pow(&rhs)?; assert_eq!( - test_utils::to_vec2_round(&res, 4)?, - [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + test_utils::to_vec2_round(&res, 3)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]] ); Ok(()) } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index cb704f0c..5b90f140 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -25,8 +25,9 @@ hf-hub = { workspace = true, features = ["tokio"] } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } -pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } +rubato = { version = "0.15.0", optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -41,7 +42,7 @@ clap = { workspace = true } imageproc = { workspace = true } memmap2 = { workspace = true } rand = { workspace = true } -rusttype = { workspace = true } +ab_glyph = { workspace = true } tracing = { workspace = true } tracing-chrome = { workspace = true } tracing-subscriber = { workspace = true } @@ -63,6 +64,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"] onnx = ["candle-onnx"] metal = ["candle/metal", 
"candle-nn/metal"] microphone = ["cpal"] +encodec = ["cpal", "symphonia", "rubato"] [[example]] name = "llama_multiprocess" @@ -98,6 +100,4 @@ required-features = ["candle-datasets"] [[example]] name = "encodec" -required-features = ["symphonia"] - - +required-features = ["encodec"] diff --git a/candle-examples/examples/clip/README.md b/candle-examples/examples/clip/README.md new file mode 100644 index 00000000..f0ee3b2c --- /dev/null +++ b/candle-examples/examples/clip/README.md @@ -0,0 +1,46 @@ +Contrastive Language-Image Pre-Training + +Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +pairs of images with related texts. + +https://github.com/openai/CLIP + +https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip + +## Running on an example on cpu + +``` +$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle" + + +Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg + +INFO clip: Probability: 0.0000% Text: a cycling race +INFO clip: Probability: 0.0000% Text: a photo of two cats +INFO clip: Probability: 100.0000% Text: a robot holding a candle + +Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg + +INFO clip: Probability: 99.9999% Text: a cycling race +INFO clip: Probability: 0.0001% Text: a photo of two cats +INFO clip: Probability: 0.0000% Text: a robot holding a candle +``` + +## Running on an example with metal feature (mac) + +``` +$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle" + + +Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg + +INFO clip: Probability: 0.0000% Text: a cycling race +INFO clip: Probability: 0.0000% Text: a photo of two cats +INFO clip: Probability: 100.0000% Text: a robot holding a candle + +Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg + +INFO clip: Probability: 99.9999% Text: a cycling race +INFO clip: Probability: 0.0001% Text: a photo of two cats +INFO clip: Probability: 0.0000% Text: a robot holding a candle +``` diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs new file mode 100644 index 00000000..f301d211 --- /dev/null +++ b/candle-examples/examples/clip/main.rs @@ -0,0 +1,202 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Error as E; +use clap::Parser; + +use candle::{DType, Device, Tensor}; +use candle_nn::{ops::softmax, VarBuilder}; +use candle_transformers::models::clip; + +use tokenizers::Tokenizer; +use tracing::info; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + tokenizer: Option, + + #[arg(long, use_value_delimiter = true)] + images: Option>, + + #[arg(long)] + cpu: bool, + + #[arg(long, use_value_delimiter = true)] + sequences: Option>, +} + +fn load_image>(path: T, image_size: usize) -> anyhow::Result { + let img = image::io::Reader::open(path)?.decode()?; + let (height, width) = (image_size, image_size); + let img = img.resize_to_fill( + width 
as u32, + height as u32, + image::imageops::FilterType::Triangle, + ); + + let img = img.to_rgb8(); + + let img = img.into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)?; + // .unsqueeze(0)?; + Ok(img) +} + +fn load_images>( + paths: &Vec, + image_size: usize, +) -> anyhow::Result { + let mut images = vec![]; + + for path in paths { + let tensor = load_image(path, image_size)?; + images.push(tensor); + } + + let images = Tensor::stack(&images, 0)?; + + Ok(images) +} + +pub fn main() -> anyhow::Result<()> { + // std::env::set_var("RUST_BACKTRACE", "full"); + + let args = Args::parse(); + + tracing_subscriber::fmt::init(); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + + let api = api.repo(hf_hub::Repo::with_revision( + "openai/clip-vit-base-patch32".to_string(), + hf_hub::RepoType::Model, + "refs/pr/15".to_string(), + )); + + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let tokenizer = get_tokenizer(args.tokenizer)?; + + let config = clip::ClipConfig::vit_base_patch32(); + + let device = candle_examples::device(args.cpu)?; + + let vec_imgs = match args.images { + Some(imgs) => imgs, + None => vec![ + "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(), + "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(), + ], + }; + + // let image = load_image(args.image, config.image_size)?.to_device(&device)?; + let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?; + + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }; + + let model = clip::ClipModel::new(vb, &config)?; + + let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?; + + let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; + + let softmax_image = softmax(&logits_per_image, 1)?; + + let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; + + info!("softmax_image_vec: {:?}", softmax_image_vec); + + let probability_vec = softmax_image_vec + .iter() + .map(|v| v * 100.0) + .collect::>(); + + let probability_per_image = probability_vec.len() / vec_imgs.len(); + + for (i, img) in vec_imgs.iter().enumerate() { + let start = i * probability_per_image; + let end = start + probability_per_image; + let prob = &probability_vec[start..end]; + info!("\n\nResults for image: {}\n", img); + + for (i, p) in prob.iter().enumerate() { + info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]); + } + } + + Ok(()) +} + +pub fn get_tokenizer(tokenizer: Option) -> anyhow::Result { + let tokenizer = match tokenizer { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.repo(hf_hub::Repo::with_revision( + "openai/clip-vit-base-patch32".to_string(), + hf_hub::RepoType::Model, + "refs/pr/15".to_string(), + )); + api.get("tokenizer.json")? 
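A note on the preprocessing in `load_image` above: `affine(2. / 255., -1.)` computes `x * 2/255 - 1`, mapping the 0–255 pixel range onto roughly [-1, 1], which matches what this example feeds to the vision encoder. A tiny standalone check (assuming `candle_core` as the crate name):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let px = Tensor::new(&[0f32, 127.5, 255.0], &Device::Cpu)?;
    // x * (2/255) - 1: 0 -> -1, 127.5 -> ~0, 255 -> ~1
    let scaled = px.affine(2. / 255., -1.)?;
    println!("{:?}", scaled.to_vec1::<f32>()?);
    Ok(())
}
```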
+ } + Some(file) => file.into(), + }; + + Tokenizer::from_file(tokenizer).map_err(E::msg) +} + +pub fn tokenize_sequences( + sequences: Option>, + tokenizer: &Tokenizer, + device: &Device, +) -> anyhow::Result<(Tensor, Vec)> { + let pad_id = *tokenizer + .get_vocab(true) + .get("<|endoftext|>") + .ok_or(E::msg("No pad token"))?; + + let vec_seq = match sequences { + Some(seq) => seq, + None => vec![ + "a cycling race".to_string(), + "a photo of two cats".to_string(), + "a robot holding a candle".to_string(), + ], + }; + + let mut tokens = vec![]; + + for seq in vec_seq.clone() { + let encoding = tokenizer.encode(seq, true).map_err(E::msg)?; + tokens.push(encoding.get_ids().to_vec()); + } + + let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0); + + // Pad the sequences to have the same length + for token_vec in tokens.iter_mut() { + let len_diff = max_len - token_vec.len(); + if len_diff > 0 { + token_vec.extend(vec![pad_id; len_diff]); + } + } + + let input_ids = Tensor::new(tokens, device)?; + + Ok((input_ids, vec_seq)) +} diff --git a/candle-examples/examples/convmixer/main.rs b/candle-examples/examples/convmixer/main.rs index feae536f..d8c2e619 100644 --- a/candle-examples/examples/convmixer/main.rs +++ b/candle-examples/examples/convmixer/main.rs @@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/convnext/main.rs b/candle-examples/examples/convnext/main.rs index 8fc72e16..e5b235fa 100644 --- a/candle-examples/examples/convnext/main.rs +++ b/candle-examples/examples/convnext/main.rs @@ -93,7 +93,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index 6b3edeb4..d718ee6f 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs index 0e4a2864..a8f17cca 100644 --- a/candle-examples/examples/efficientnet/main.rs +++ b/candle-examples/examples/efficientnet/main.rs @@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/efficientvit/main.rs b/candle-examples/examples/efficientvit/main.rs index 1eb80a2d..efbf813c 100644 --- a/candle-examples/examples/efficientvit/main.rs +++ b/candle-examples/examples/efficientvit/main.rs 
@@ -66,7 +66,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/encodec/README.md b/candle-examples/examples/encodec/README.md index 3028fb80..9de0d4ad 100644 --- a/candle-examples/examples/encodec/README.md +++ b/candle-examples/examples/encodec/README.md @@ -13,8 +13,13 @@ cargo run --example encodec --features symphonia --release -- code-to-audio \ ``` This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates -an output wav file containing the audio data. Instead of `code-to-audio` one -can use: +an output wav file containing the audio data. + +Instead of `code-to-audio` one can use: - `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file. - `audio-to-code in.mp3 out.safetensors`: generates a safetensors file containing EnCodec tokens for the input audio file. + +If the audio output file name is set to `-`, the audio content directly gets +played on default audio output device. If the audio input file is set to `-`, the audio +gets recorded from the default audio input. diff --git a/candle-examples/examples/encodec/audio_io.rs b/candle-examples/examples/encodec/audio_io.rs new file mode 100644 index 00000000..2103dd4a --- /dev/null +++ b/candle-examples/examples/encodec/audio_io.rs @@ -0,0 +1,275 @@ +#![allow(unused)] +use anyhow::{Context, Result}; +use std::sync::{Arc, Mutex}; + +pub const SAMPLE_RATE: usize = 24_000; + +pub(crate) struct AudioOutputData_ { + resampled_data: std::collections::VecDeque, + resampler: rubato::FastFixedIn, + output_buffer: Vec, + input_buffer: Vec, + input_len: usize, +} + +impl AudioOutputData_ { + pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result { + use rubato::Resampler; + + let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10); + let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64; + let resampler = rubato::FastFixedIn::new( + resample_ratio, + f64::max(resample_ratio, 1.0), + rubato::PolynomialDegree::Septic, + 1024, + 1, + )?; + let input_buffer = resampler.input_buffer_allocate(true).remove(0); + let output_buffer = resampler.output_buffer_allocate(true).remove(0); + Ok(Self { + resampled_data, + resampler, + input_buffer, + output_buffer, + input_len: 0, + }) + } + + pub fn reset(&mut self) { + use rubato::Resampler; + self.output_buffer.fill(0.); + self.input_buffer.fill(0.); + self.resampler.reset(); + self.resampled_data.clear(); + } + + pub(crate) fn take_all(&mut self) -> Vec { + let mut data = Vec::with_capacity(self.resampled_data.len()); + while let Some(elem) = self.resampled_data.pop_back() { + data.push(elem); + } + data + } + + pub(crate) fn is_empty(&self) -> bool { + self.resampled_data.is_empty() + } + + // Assumes that the input buffer is large enough. 
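The `AudioOutputData_` helper above exists so that audio captured or played at whatever rate the sound card offers can be converted to EnCodec's fixed 24kHz: incoming samples are buffered, pushed through a `rubato` polynomial resampler in 1024-frame chunks, and queued for the consumer. A rough one-shot sketch of the same conversion, assuming mono 44.1kHz input and a hypothetical `to_24k` helper (the real code streams through `push_samples` and reuses its buffers instead):

```rust
use anyhow::Result;
use rubato::{FastFixedIn, PolynomialDegree, Resampler};

fn to_24k(pcm_in: &[f32]) -> Result<Vec<f32>> {
    let ratio = 24_000f64 / 44_100f64;
    let mut resampler = FastFixedIn::<f32>::new(ratio, 1.0, PolynomialDegree::Septic, 1024, 1)?;
    let mut pcm_out = Vec::new();
    for chunk in pcm_in.chunks_exact(1024) {
        // `process` allocates a fresh output buffer; the streaming helper reuses one.
        let resampled = resampler.process(&[chunk], None)?;
        pcm_out.extend_from_slice(&resampled[0]);
    }
    // A full implementation would also flush the trailing partial chunk, as the
    // `resample` function further down does with `process_partial_into_buffer`.
    Ok(pcm_out)
}

fn main() -> Result<()> {
    let input: Vec<f32> = (0..44_100).map(|i| (i as f32 * 0.01).sin()).collect();
    let out = to_24k(&input)?;
    println!("resampled {} -> {} samples", input.len(), out.len());
    Ok(())
}
```

With the `-` convention described in the README above, a command along the lines of `cargo run --example encodec --features encodec --release -- audio-to-audio in.mp3 -` (paths hypothetical) should decode the input file and play the result directly on the default output device.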
+ fn push_input_buffer(&mut self, samples: &[f32]) { + self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples); + self.input_len += samples.len() + } + + pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> { + use rubato::Resampler; + + let mut pos_in = 0; + loop { + let rem = self.input_buffer.len() - self.input_len; + let pos_end = usize::min(pos_in + rem, samples.len()); + self.push_input_buffer(&samples[pos_in..pos_end]); + pos_in = pos_end; + if self.input_len < self.input_buffer.len() { + break; + } + let (_, out_len) = self.resampler.process_into_buffer( + &[&self.input_buffer], + &mut [&mut self.output_buffer], + None, + )?; + for &elem in self.output_buffer[..out_len].iter() { + self.resampled_data.push_front(elem) + } + self.input_len = 0; + } + Ok(()) + } +} + +type AudioOutputData = Arc>; + +pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> { + use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; + + println!("Setup audio output stream!"); + let host = cpal::default_host(); + let device = host + .default_output_device() + .context("no output device available")?; + let mut supported_configs_range = device.supported_output_configs()?; + let config_range = match supported_configs_range.find(|c| c.channels() == 1) { + // On macOS, it's commonly the case that there are only stereo outputs. + None => device + .supported_output_configs()? + .next() + .context("no audio output available")?, + Some(config_range) => config_range, + }; + let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp( + config_range.min_sample_rate(), + config_range.max_sample_rate(), + ); + let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into(); + let channels = config.channels as usize; + println!( + "cpal device: {} {} {config:?}", + device.name().unwrap_or_else(|_| "unk".to_string()), + config.sample_rate.0 + ); + let audio_data = Arc::new(Mutex::new(AudioOutputData_::new( + SAMPLE_RATE, + config.sample_rate.0 as usize, + )?)); + let ad = audio_data.clone(); + let stream = device.build_output_stream( + &config, + move |data: &mut [f32], _: &cpal::OutputCallbackInfo| { + data.fill(0.); + let mut ad = ad.lock().unwrap(); + let mut last_elem = 0f32; + for (idx, elem) in data.iter_mut().enumerate() { + if idx % channels == 0 { + match ad.resampled_data.pop_back() { + None => break, + Some(v) => { + last_elem = v; + *elem = v + } + } + } else { + *elem = last_elem + } + } + }, + move |err| eprintln!("cpal error: {err}"), + None, // None=blocking, Some(Duration)=timeout + )?; + stream.play()?; + Ok((stream, audio_data)) +} + +pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> { + use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; + + println!("Setup audio input stream!"); + let host = cpal::default_host(); + let device = host + .default_input_device() + .context("no input device available")?; + let mut supported_configs_range = device.supported_input_configs()?; + let config_range = supported_configs_range + .find(|c| c.channels() == 1) + .context("no audio input available")?; + let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp( + config_range.min_sample_rate(), + config_range.max_sample_rate(), + ); + let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into(); + println!( + "cpal device: {} {} {config:?}", + device.name().unwrap_or_else(|_| "unk".to_string()), + config.sample_rate.0 + ); + let audio_data = 
Arc::new(Mutex::new(AudioOutputData_::new( + config.sample_rate.0 as usize, + SAMPLE_RATE, + )?)); + let ad = audio_data.clone(); + let stream = device.build_input_stream( + &config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let mut ad = ad.lock().unwrap(); + if let Err(err) = ad.push_samples(data) { + eprintln!("error processing audio input {err:?}") + } + }, + move |err| eprintln!("cpal error: {err}"), + None, // None=blocking, Some(Duration)=timeout + )?; + stream.play()?; + Ok((stream, audio_data)) +} + +fn conv(samples: &mut Vec, data: std::borrow::Cow>) +where + T: symphonia::core::sample::Sample, + f32: symphonia::core::conv::FromSample, +{ + use symphonia::core::audio::Signal; + use symphonia::core::conv::FromSample; + samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) +} + +pub(crate) fn pcm_decode>(path: P) -> Result<(Vec, u32)> { + use symphonia::core::audio::{AudioBufferRef, Signal}; + + let src = std::fs::File::open(path)?; + let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); + let hint = symphonia::core::probe::Hint::new(); + let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); + let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); + let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .expect("no supported audio tracks"); + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &Default::default()) + .expect("unsupported codec"); + let track_id = track.id; + let sample_rate = track.codec_params.sample_rate.unwrap_or(0); + let mut pcm_data = Vec::new(); + while let Ok(packet) = format.next_packet() { + while !format.metadata().is_latest() { + format.metadata().pop(); + } + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet)? 
{ + AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), + AudioBufferRef::U8(data) => conv(&mut pcm_data, data), + AudioBufferRef::U16(data) => conv(&mut pcm_data, data), + AudioBufferRef::U24(data) => conv(&mut pcm_data, data), + AudioBufferRef::U32(data) => conv(&mut pcm_data, data), + AudioBufferRef::S8(data) => conv(&mut pcm_data, data), + AudioBufferRef::S16(data) => conv(&mut pcm_data, data), + AudioBufferRef::S24(data) => conv(&mut pcm_data, data), + AudioBufferRef::S32(data) => conv(&mut pcm_data, data), + AudioBufferRef::F64(data) => conv(&mut pcm_data, data), + } + } + Ok((pcm_data, sample_rate)) +} + +pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result> { + use rubato::Resampler; + + let mut pcm_out = + Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024); + + let mut resampler = rubato::FftFixedInOut::::new(sr_in, sr_out, 1024, 1)?; + let mut output_buffer = resampler.output_buffer_allocate(true); + let mut pos_in = 0; + while pos_in + resampler.input_frames_next() < pcm_in.len() { + let (in_len, out_len) = + resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?; + pos_in += in_len; + pcm_out.extend_from_slice(&output_buffer[0][..out_len]); + } + + if pos_in < pcm_in.len() { + let (_in_len, out_len) = resampler.process_partial_into_buffer( + Some(&[&pcm_in[pos_in..]]), + &mut output_buffer, + None, + )?; + pcm_out.extend_from_slice(&output_buffer[0][..out_len]); + } + + Ok(pcm_out) +} diff --git a/candle-examples/examples/encodec/main.rs b/candle-examples/examples/encodec/main.rs index f1c4a0ee..e77f98e7 100644 --- a/candle-examples/examples/encodec/main.rs +++ b/candle-examples/examples/encodec/main.rs @@ -11,59 +11,7 @@ use candle_transformers::models::encodec::{Config, Model}; use clap::{Parser, ValueEnum}; use hf_hub::api::sync::Api; -fn conv(samples: &mut Vec, data: std::borrow::Cow>) -where - T: symphonia::core::sample::Sample, - f32: symphonia::core::conv::FromSample, -{ - use symphonia::core::audio::Signal; - use symphonia::core::conv::FromSample; - samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) -} - -fn pcm_decode>(path: P) -> anyhow::Result<(Vec, u32)> { - use symphonia::core::audio::{AudioBufferRef, Signal}; - - let src = std::fs::File::open(path)?; - let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); - let hint = symphonia::core::probe::Hint::new(); - let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); - let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); - let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; - let mut format = probed.format; - let track = format - .tracks() - .iter() - .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) - .expect("no supported audio tracks"); - let mut decoder = symphonia::default::get_codecs() - .make(&track.codec_params, &Default::default()) - .expect("unsupported codec"); - let track_id = track.id; - let sample_rate = track.codec_params.sample_rate.unwrap_or(0); - let mut pcm_data = Vec::new(); - while let Ok(packet) = format.next_packet() { - while !format.metadata().is_latest() { - format.metadata().pop(); - } - if packet.track_id() != track_id { - continue; - } - match decoder.decode(&packet)? 
{ - AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), - AudioBufferRef::U8(data) => conv(&mut pcm_data, data), - AudioBufferRef::U16(data) => conv(&mut pcm_data, data), - AudioBufferRef::U24(data) => conv(&mut pcm_data, data), - AudioBufferRef::U32(data) => conv(&mut pcm_data, data), - AudioBufferRef::S8(data) => conv(&mut pcm_data, data), - AudioBufferRef::S16(data) => conv(&mut pcm_data, data), - AudioBufferRef::S24(data) => conv(&mut pcm_data, data), - AudioBufferRef::S32(data) => conv(&mut pcm_data, data), - AudioBufferRef::F64(data) => conv(&mut pcm_data, data), - } - } - Ok((pcm_data, sample_rate)) -} +mod audio_io; #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] enum Action { @@ -109,14 +57,36 @@ fn main() -> Result<()> { let codes = match args.action { Action::CodeToAudio => { let codes = candle::safetensors::load(args.in_file, &device)?; - let codes = codes.get("codes").expect("no codes in input file").i(0)?; - codes + codes.get("codes").expect("no codes in input file").clone() } Action::AudioToCode | Action::AudioToAudio => { - let (pcm, sample_rate) = pcm_decode(args.in_file)?; - if sample_rate != 24_000 { - println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}") - } + let pcm = if args.in_file == "-" { + println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<"); + let (stream, input_audio) = audio_io::setup_input_stream()?; + let mut pcms = vec![]; + let stdin = std::thread::spawn(|| { + let mut s = String::new(); + std::io::stdin().read_line(&mut s) + }); + while !stdin.is_finished() { + let input = input_audio.lock().unwrap().take_all(); + if input.is_empty() { + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + pcms.push(input) + } + drop(stream); + pcms.concat() + } else { + let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?; + if sample_rate != 24_000 { + println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}, resampling..."); + audio_io::resample(&pcm, sample_rate as usize, 24_000)? + } else { + pcm + } + }; let pcm_len = pcm.len(); let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?; println!("input pcm shape: {:?}", pcm.shape()); @@ -135,8 +105,26 @@ fn main() -> Result<()> { let pcm = pcm.i(0)?.i(0)?; let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?; let pcm = pcm.to_vec1::()?; - let mut output = std::fs::File::create(&args.out_file)?; - candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; + if args.out_file == "-" { + let (stream, ad) = audio_io::setup_output_stream()?; + { + let mut ad = ad.lock().unwrap(); + ad.push_samples(&pcm)?; + } + loop { + let ad = ad.lock().unwrap(); + if ad.is_empty() { + break; + } + // That's very weird, calling thread::sleep here triggers the stream to stop + // playing (the callback doesn't seem to be called anymore). 
+ // std::thread::sleep(std::time::Duration::from_millis(100)); + } + drop(stream) + } else { + let mut output = std::fs::File::create(&args.out_file)?; + candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; + } } } Ok(()) diff --git a/candle-examples/examples/gemma/README.md b/candle-examples/examples/gemma/README.md index 8319cf44..5d77c7a4 100644 --- a/candle-examples/examples/gemma/README.md +++ b/candle-examples/examples/gemma/README.md @@ -1,4 +1,4 @@ -# candle-mistral: 2b and 7b LLMs from Google DeepMind +# candle-gemma: 2b and 7b LLMs from Google DeepMind [Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open models published by Google Deepmind with a 2b and a 7b variant. diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs index e1df8790..a5f7d591 100644 --- a/candle-examples/examples/gemma/main.rs +++ b/candle-examples/examples/gemma/main.rs @@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "2b")] + Base2B, + #[value(name = "7b")] + Base7B, + #[value(name = "2b-it")] + Instruct2B, + #[value(name = "7b-it")] + Instruct7B, + #[value(name = "1.1-2b-it")] + InstructV1_1_2B, + #[value(name = "1.1-7b-it")] + InstructV1_1_7B, + #[value(name = "code-2b")] + CodeBase2B, + #[value(name = "code-7b")] + CodeBase7B, + #[value(name = "code-2b-it")] + CodeInstruct2B, + #[value(name = "code-7b-it")] + CodeInstruct7B, +} + struct TextGeneration { model: Model, device: Device, @@ -165,6 +189,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// The model to use. 
+ #[arg(long, default_value = "2b")] + which: Which, } fn main() -> Result<()> { @@ -196,14 +224,19 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let api = Api::new()?; let model_id = match &args.model_id { - Some(model_id) => match model_id.as_str() { - "7b-it" => "google/gemma-7b-it".to_string(), - "7b" => "google/gemma-7b".to_string(), - "2b-it" => "google/gemma-2b-it".to_string(), - "2b" => "google/gemma-2b".to_string(), - _ => model_id.to_string(), + Some(model_id) => model_id.to_string(), + None => match args.which { + Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(), + Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(), + Which::Base2B => "google/gemma-2b".to_string(), + Which::Base7B => "google/gemma-7b".to_string(), + Which::Instruct2B => "google/gemma-2b-it".to_string(), + Which::Instruct7B => "google/gemma-7b-it".to_string(), + Which::CodeBase2B => "google/codegemma-2b".to_string(), + Which::CodeBase7B => "google/codegemma-7b".to_string(), + Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(), + Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(), }, - None => "google/gemma-2b".to_string(), }; let repo = api.repo(Repo::with_revision( model_id, diff --git a/candle-examples/examples/mamba/main.rs b/candle-examples/examples/mamba/main.rs index 4802f960..b8c8bb70 100644 --- a/candle-examples/examples/mamba/main.rs +++ b/candle-examples/examples/mamba/main.rs @@ -54,6 +54,7 @@ impl TextGeneration { fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { use std::io::Write; self.tokenizer.clear(); + let dtype = self.model.dtype(); let mut tokens = self .tokenizer .tokenizer() @@ -66,7 +67,7 @@ impl TextGeneration { Some(token) => token, None => anyhow::bail!("cannot find the token"), }; - let mut state = State::new(1, &self.config, &self.device)?; + let mut state = State::new(1, &self.config, dtype, &self.device)?; let mut next_logits = None; for &t in tokens.iter() { let input = Tensor::new(&[t], &self.device)?; @@ -84,7 +85,7 @@ impl TextGeneration { Some(logits) => logits, None => anyhow::bail!("cannot work on an empty prompt"), }; - let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = logits.squeeze(0)?.to_dtype(dtype)?; let logits = if self.repeat_penalty == 1. { logits } else { @@ -210,6 +211,9 @@ struct Args { #[arg(long)] config_file: Option, + #[arg(long, default_value = "f32")] + dtype: String, + /// Penalty to be applied for repeating tokens, 1. means no penalty. #[arg(long, default_value_t = 1.1)] repeat_penalty: f32, @@ -220,6 +224,7 @@ struct Args { } fn main() -> Result<()> { + use std::str::FromStr; use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; @@ -279,7 +284,8 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; let device = candle_examples::device(args.cpu)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let dtype = DType::from_str(&args.dtype)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? 
}; let model = Model::new(&config, vb.pp("backbone"))?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/metavoice/main.rs b/candle-examples/examples/metavoice/main.rs index ae571929..7a7ec3e4 100644 --- a/candle-examples/examples/metavoice/main.rs +++ b/candle-examples/examples/metavoice/main.rs @@ -11,6 +11,7 @@ use std::io::Write; use candle_transformers::generation::LogitsProcessor; use candle_transformers::models::encodec; use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer}; +use candle_transformers::models::quantized_metavoice::transformer as qtransformer; use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; @@ -26,6 +27,11 @@ enum ArgDType { Bf16, } +enum Transformer { + Normal(transformer::Model), + Quantized(qtransformer::Model), +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -40,6 +46,10 @@ struct Args { #[arg(long)] prompt: String, + /// Use the quantized version of the model. + #[arg(long)] + quantized: bool, + /// The guidance scale. #[arg(long, default_value_t = 3.0)] guidance_scale: f64, @@ -116,11 +126,7 @@ fn main() -> Result<()> { }; let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?; - let first_stage_weights = match &args.first_stage_weights { - Some(w) => std::path::PathBuf::from(w), - None => repo.get("first_stage.safetensors")?, - }; - let second_stage_weights = match &args.first_stage_weights { + let second_stage_weights = match &args.second_stage_weights { Some(w) => std::path::PathBuf::from(w), None => repo.get("second_stage.safetensors")?, }; @@ -135,10 +141,27 @@ fn main() -> Result<()> { ArgDType::F16 => DType::F16, ArgDType::Bf16 => DType::BF16, }; - let first_stage_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? }; + let first_stage_config = transformer::Config::cfg1b_v0_1(); - let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?; + let mut first_stage_model = if args.quantized { + let filename = match &args.first_stage_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("first_stage_q4k.gguf")?, + }; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; + let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?; + Transformer::Quantized(first_stage_model) + } else { + let first_stage_weights = match &args.first_stage_weights { + Some(w) => std::path::PathBuf::from(w), + None => repo.get("first_stage.safetensors")?, + }; + let first_stage_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? }; + let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?; + Transformer::Normal(first_stage_model) + }; let second_stage_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? }; @@ -178,7 +201,12 @@ fn main() -> Result<()> { let ctxt = &tokens[start_pos..]; let input = Tensor::new(ctxt, &device)?; let input = Tensor::stack(&[&input, &input], 0)?; - let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?; + let logits = match &mut first_stage_model { + Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?, + Transformer::Quantized(m) => { + m.forward(&input, &spk_emb, tokens.len() - context_size)? 
+ } + }; let logits0 = logits.i((0, 0))?; let logits1 = logits.i((1, 0))?; let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?; diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 1cf4107c..6aa3f51e 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; -use candle_transformers::generation::LogitsProcessor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; @@ -39,11 +39,26 @@ impl TextGeneration { seed: u64, temp: Option, top_p: Option, + top_k: Option, repeat_penalty: f32, repeat_last_n: usize, device: &Device, ) -> Self { - let logits_processor = LogitsProcessor::new(seed, temp, top_p); + let logits_processor = { + let temperature = temp.unwrap_or(0.); + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (top_k, top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(seed, sampling) + }; + Self { model, tokenizer: TokenOutputStream::new(tokenizer), @@ -122,6 +137,18 @@ impl TextGeneration { } } +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "7b-v0.1")] + Mistral7bV01, + #[value(name = "7b-v0.2")] + Mistral7bV02, + #[value(name = "7b-instruct-v0.1")] + Mistral7bInstructV01, + #[value(name = "7b-instruct-v0.2")] + Mistral7bInstructV02, +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -147,6 +174,10 @@ struct Args { #[arg(long)] top_p: Option, + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -155,6 +186,10 @@ struct Args { #[arg(long, short = 'n', default_value_t = 10000)] sample_len: usize, + /// The model size to use. + #[arg(long, default_value = "7b-v0.1")] + which: Which, + #[arg(long)] model_id: Option, @@ -164,6 +199,9 @@ struct Args { #[arg(long)] tokenizer_file: Option, + #[arg(long)] + config_file: Option, + #[arg(long)] weight_files: Option, @@ -177,6 +215,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// Use the slower dmmv cuda kernel. 
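The sampling setup introduced in this file (and mirrored in the quantized example further down) deserves a quick summary: a temperature of zero falls back to greedy argmax, otherwise the optional `--top-k` / `--top-p` flags select one of the `Sampling` variants, and as the name suggests `TopKThenTopP` applies the top-k cut first and nucleus filtering on what remains. A condensed sketch of that mapping, assuming `top_k: Option<usize>` and `top_p: Option<f64>` as in the CLI args, wrapped in a hypothetical `build_processor` helper:

```rust
use candle_transformers::generation::{LogitsProcessor, Sampling};

fn build_processor(
    seed: u64,
    temperature: f64,
    top_k: Option<usize>,
    top_p: Option<f64>,
) -> LogitsProcessor {
    let sampling = if temperature <= 0. {
        // Greedy decoding: always pick the most likely token.
        Sampling::ArgMax
    } else {
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(seed, sampling)
}

fn main() {
    let _processor = build_processor(299792458, 0.8, Some(40), Some(0.9));
}
```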
+ #[arg(long)] + force_dmmv: bool, } fn main() -> Result<()> { @@ -184,6 +226,9 @@ fn main() -> Result<()> { use tracing_subscriber::prelude::*; let args = Args::parse(); + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(args.force_dmmv); + let _guard = if args.tracing { let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); @@ -211,9 +256,17 @@ fn main() -> Result<()> { Some(model_id) => model_id, None => { if args.quantized { + if args.which != Which::Mistral7bV01 { + anyhow::bail!("only 7b-v0.1 is available as a quantized model for now") + } "lmz/candle-mistral".to_string() } else { - "mistralai/Mistral-7B-v0.1".to_string() + match args.which { + Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(), + Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(), + Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(), + Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(), + } } } }; @@ -243,7 +296,17 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let config = Config::config_7b_v0_1(args.use_flash_attn); + let config = match args.config_file { + Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?, + None => { + if args.quantized { + Config::config_7b_v0_1(args.use_flash_attn) + } else { + let config_file = repo.get("config.json")?; + serde_json::from_slice(&std::fs::read(config_file)?)? + } + } + }; let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized { let filename = &filenames[0]; @@ -270,6 +333,7 @@ fn main() -> Result<()> { args.seed, args.temperature, args.top_p, + args.top_k, args.repeat_penalty, args.repeat_last_n, &device, diff --git a/candle-examples/examples/mobileone/main.rs b/candle-examples/examples/mobileone/main.rs index 4cd55001..76533fe3 100644 --- a/candle-examples/examples/mobileone/main.rs +++ b/candle-examples/examples/mobileone/main.rs @@ -63,7 +63,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/moondream/README.md b/candle-examples/examples/moondream/README.md new file mode 100644 index 00000000..e202de7c --- /dev/null +++ b/candle-examples/examples/moondream/README.md @@ -0,0 +1,26 @@ +# candle-moondream + +[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model can answer real-world questions about images. It's tiny by today's models, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices. + +## Running some examples +First download an example image +```bash +$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg +``` + + + +Now you can run Moondream from the `candle-examples` crate: +```bash +$ cargo run --example moondream --release -- --prompt "What is the girl eating?" 
--image "./demo-1.jpg" + +avavx: false, neon: true, simd128: false, f16c: false +temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64 +retrieved the files in 3.395583ms +Running on CPU, to run on GPU(metal), build this example with `--features metal` +loaded the model in 5.485493792s +loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s +starting the inference loop + The girl is eating a hamburger.< +9 tokens generated (0.68 token/s) +``` \ No newline at end of file diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs new file mode 100644 index 00000000..646ef258 --- /dev/null +++ b/candle-examples/examples/moondream/main.rs @@ -0,0 +1,343 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::{ + generation::LogitsProcessor, + models::{moondream, quantized_moondream}, +}; +use tokenizers::Tokenizer; + +enum Model { + Moondream(moondream::Model), + Quantized(quantized_moondream::Model), +} + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: Tokenizer, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + top_p: Option, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer, + logits_processor, + repeat_penalty, + repeat_last_n, + verbose_prompt, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> { + use std::io::Write; + println!("starting the inference loop"); + let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?; + if tokens.is_empty() { + anyhow::bail!("Empty prompts are not supported in the Moondream model.") + } + if self.verbose_prompt { + for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { + let token = token.replace('▁', " ").replace("<0x0A>", "\n"); + println!("{id:7} -> '{token}'"); + } + } + + let mut tokens = tokens.get_ids().to_vec(); + let mut generated_tokens = 0usize; + + // Moondream tokenizer bos_token and eos_token is "<|endoftext|>" + // https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json + let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { + Some(token) => *token, + None => anyhow::bail!("cannot find the special token"), + }; + let (bos_token, eos_token) = (special_token, special_token); + + let start_gen = std::time::Instant::now(); + let mut load_t = std::time::Duration::from_secs_f64(0f64); + for index in 0..sample_len { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = if index > 0 { + match self.model { + Model::Moondream(ref mut model) => model.text_model.forward(&input)?, + Model::Quantized(ref mut model) => model.text_model.forward(&input)?, + } + } else { + let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?; + let logits = match self.model { + Model::Moondream(ref mut model) => { + 
model + .text_model + .forward_with_img(&bos_token, &input, image_embeds)? + } + Model::Quantized(ref mut model) => { + model + .text_model + .forward_with_img(&bos_token, &input, image_embeds)? + } + }; + load_t = start_gen.elapsed(); + println!("load_t: {:?}", load_t); + logits + }; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* */) { + break; + } + let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?; + print!("{token}"); + std::io::stdout().flush()?; + } + + let dt = start_gen.elapsed() - load_t; + println!( + "\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)", + dt.as_secs_f64(), + (generated_tokens - 1) as f64 / dt.as_secs_f64() + ); + + Ok(()) + } +} + +#[derive(Parser)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Display the token for the specified prompt. + #[arg(long)] + verbose_prompt: bool, + + #[arg(long)] + prompt: String, + + #[arg(long)] + image: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 0)] + seed: u64, + + #[arg(long, default_value_t = 5000)] + sample_len: usize, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.0)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + #[arg(long)] + model_id: Option, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + quantized: bool, + + /// Use f16 precision for all the computations rather than f32. + #[arg(long)] + f16: bool, + + #[arg(long)] + model_file: Option, + + #[arg(long)] + tokenizer_file: Option, +} + +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 378, 378). +pub fn load_image>(p: P) -> candle::Result { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)? + .resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378 + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? 
+ .broadcast_div(&std) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = hf_hub::api::tokio::Api::new()?; + let model_id = match args.model_id { + Some(model_id) => model_id.to_string(), + None => { + if args.quantized { + "santiagomed/candle-moondream".to_string() + } else { + "vikhyatk/moondream2".to_string() + } + } + }; + let repo = api.repo(hf_hub::Repo::with_revision( + model_id, + hf_hub::RepoType::Model, + args.revision, + )); + let model_file = match args.model_file { + Some(m) => m.into(), + None => { + if args.quantized { + repo.get("model-q4_0.gguf").await? + } else { + repo.get("model.safetensors").await? + } + } + }; + let tokenizer = match args.tokenizer_file { + Some(m) => m.into(), + None => repo.get("tokenizer.json").await?, + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + let config = moondream::Config::v2(); + let dtype = if args.quantized { + if args.f16 { + anyhow::bail!("Quantized model does not support f16"); + } + DType::F32 + } else if device.is_cuda() || args.f16 { + DType::F16 + } else { + DType::F32 + }; + let model = if args.quantized { + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &model_file, + &device, + )?; + let model = quantized_moondream::Model::new(&config, vb)?; + Model::Quantized(model) + } else { + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; + let model = moondream::Model::new(&config, vb)?; + Model::Moondream(model) + }; + println!("loaded the model in {:?}", start.elapsed()); + + let start = std::time::Instant::now(); + let image = load_image(args.image)? + .to_device(&device)? 
+ .to_dtype(dtype)?; + let image_embeds = image.unsqueeze(0)?; + let image_embeds = match model { + Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?, + Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?, + }; + println!( + "loaded and encoded the image {image:?} in {:?}", + start.elapsed() + ); + + let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt); + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + args.verbose_prompt, + &device, + ); + pipeline.run(&prompt, &image_embeds, args.sample_len)?; + + Ok(()) +} diff --git a/candle-examples/examples/quantized-t5/README.md b/candle-examples/examples/quantized-t5/README.md index 8b8179eb..c86e746d 100644 --- a/candle-examples/examples/quantized-t5/README.md +++ b/candle-examples/examples/quantized-t5/README.md @@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the `tensor-tools` command line utility via: ```bash -$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf +$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf ``` ## Using custom models diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 96344a49..ea7f70eb 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -10,7 +10,7 @@ use tokenizers::Tokenizer; use candle::quantized::{ggml_file, gguf_file}; use candle::Tensor; -use candle_transformers::generation::LogitsProcessor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_examples::token_output_stream::TokenOutputStream; use candle_transformers::models::quantized_llama as model; @@ -200,6 +200,10 @@ struct Args { #[arg(long)] top_p: Option, + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -235,6 +239,10 @@ struct Args { /// Group-Query Attention, use 8 for the 70B version of LLaMAv2. #[arg(long)] gqa: Option, + + /// Use the slower dmmv cuda kernel. + #[arg(long)] + force_dmmv: bool, } impl Args { @@ -341,11 +349,10 @@ fn main() -> anyhow::Result<()> { use tracing_subscriber::prelude::*; let args = Args::parse(); - let temperature = if args.temperature == 0. { - None - } else { - Some(args.temperature) - }; + + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(args.force_dmmv); + let _guard = if args.tracing { let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); @@ -492,7 +499,20 @@ fn main() -> anyhow::Result<()> { prompt_tokens }; let mut all_tokens = vec![]; - let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p); + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. 
{ + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; let start_prompt_processing = std::time::Instant::now(); let mut next_token = if !args.split_prompt { diff --git a/candle-examples/examples/qwen/README.md b/candle-examples/examples/qwen/README.md new file mode 100644 index 00000000..cb785f21 --- /dev/null +++ b/candle-examples/examples/qwen/README.md @@ -0,0 +1,27 @@ +# candle-qwen: large language model series from Alibaba Cloud + +Qwen 1.5 is a series of large language models that provide strong performances +on English and Chinese. + +- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5. +- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub. +- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the + mixture-of-experts (MoE) variant. + +## Running the example + +```bash +$ cargo run --example qwen --release -- --prompt "Hello there " +``` + +Various model sizes are available via the `--model` argument, including the MoE +variant. + +```bash +$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): ' +def print_prime(n: int): # n is the number of primes to be printed + for i in range(2, n + 1): + if all(i % j != 0 for j in range(2, i)): + print(i) +``` + diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index d040d4b0..a203ad8e 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -7,7 +7,8 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -use candle_transformers::models::qwen2::{Config, Model}; +use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase}; +use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; +enum Model { + Base(ModelBase), + Moe(ModelMoe), +} + +impl Model { + fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result { + match self { + Self::Moe(ref mut m) => m.forward(xs, s), + Self::Base(ref mut m) => m.forward(xs, s), + } + } +} + struct TextGeneration { model: Model, device: Device, @@ -127,6 +142,8 @@ enum WhichModel { W14b, #[value(name = "72b")] W72b, + #[value(name = "moe-a2.7b")] + MoeA27b, } #[derive(Parser, Debug)] @@ -224,6 +241,7 @@ fn main() -> Result<()> { WhichModel::W7b => "7B", WhichModel::W14b => "14B", WhichModel::W72b => "72B", + WhichModel::MoeA27b => "MoE-A2.7B", }; format!("Qwen/Qwen1.5-{size}") } @@ -244,7 +262,11 @@ fn main() -> Result<()> { .collect::>(), None => match args.model { WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?], - WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => { + WhichModel::W4b + | WhichModel::W7b + | WhichModel::W14b + | WhichModel::W72b + | WhichModel::MoeA27b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? 
} }, @@ -254,7 +276,6 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config_file = repo.get("config.json")?; - let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?; let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 @@ -262,7 +283,16 @@ fn main() -> Result<()> { DType::F32 }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; - let model = Model::new(&config, vb)?; + let model = match args.model { + WhichModel::MoeA27b => { + let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Moe(ModelMoe::new(&config, vb)?) + } + _ => { + let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Base(ModelBase::new(&config, vb)?) + } + }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/reinforcement-learning/dqn.rs b/candle-examples/examples/reinforcement-learning/dqn.rs new file mode 100644 index 00000000..83457810 --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/dqn.rs @@ -0,0 +1,118 @@ +use std::collections::VecDeque; + +use rand::distributions::Uniform; +use rand::{thread_rng, Rng}; + +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::loss::mse; +use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap}; + +use crate::gym_env::GymEnv; + +const DEVICE: Device = Device::Cpu; +const EPISODES: usize = 200; +const BATCH_SIZE: usize = 64; +const GAMMA: f64 = 0.99; +const LEARNING_RATE: f64 = 0.01; + +pub fn run() -> Result<()> { + let env = GymEnv::new("CartPole-v1")?; + + // Build the model that predicts the estimated rewards given a specific state. + let var_map = VarMap::new(); + let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE); + let observation_space = *env.observation_space().first().unwrap(); + + let model = seq() + .add(linear(observation_space, 64, vb.pp("linear_in"))?) + .add(Activation::Relu) + .add(linear(64, env.action_space(), vb.pp("linear_out"))?); + + let mut optimizer = AdamW::new_lr(var_map.all_vars(), LEARNING_RATE)?; + + // Initialize the model's memory. + let mut memory = VecDeque::with_capacity(10000); + + // Start the training loop. + let mut state = env.reset(0)?; + let mut episode = 0; + let mut accumulate_rewards = 0.0; + while episode < EPISODES { + // Given the current state, predict the estimated rewards, and take the + // action that is expected to return the most rewards. + let estimated_rewards = model.forward(&state.unsqueeze(0)?)?; + let action: u32 = estimated_rewards.squeeze(0)?.argmax(0)?.to_scalar()?; + + // Take that action in the environment, and memorize the outcome: + // - the state for which the action was taken + // - the action taken + // - the new state resulting of taking that action + // - the actual rewards of taking that action + // - whether the environment reached a terminal state or not (e.g. game over) + let step = env.step(action)?; + accumulate_rewards += step.reward; + memory.push_back(( + state, + action, + step.state.clone(), + step.reward, + step.terminated || step.truncated, + )); + state = step.state; + + // If there's enough entries in the memory, perform a learning step, where + // BATCH_SIZE transitions will be sampled from the memory and will be + // fed to the model so that it performs a backward pass. + if memory.len() > BATCH_SIZE { + // Sample randomly from the memory. 
+            let batch = thread_rng()
+                .sample_iter(Uniform::from(0..memory.len()))
+                .take(BATCH_SIZE)
+                .map(|i| memory.get(i).unwrap().clone())
+                .collect::<Vec<_>>();
+
+            // Group all the samples together into tensors with the appropriate shape.
+            let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect();
+            let states = Tensor::stack(&states, 0)?;
+
+            let actions = batch.iter().map(|e| e.1);
+            let actions = Tensor::from_iter(actions, &DEVICE)?.unsqueeze(1)?;
+
+            let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect();
+            let next_states = Tensor::stack(&next_states, 0)?;
+
+            let rewards = batch.iter().map(|e| e.3 as f32);
+            let rewards = Tensor::from_iter(rewards, &DEVICE)?.unsqueeze(1)?;
+
+            let non_final_mask = batch.iter().map(|e| !e.4 as u8 as f32);
+            let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)?.unsqueeze(1)?;
+
+            // Get the estimated rewards for the actions that were taken at each step.
+            let estimated_rewards = model.forward(&states)?;
+            let x = estimated_rewards.gather(&actions, 1)?;
+
+            // Get the maximum expected rewards for the next state, apply the discount rate
+            // GAMMA to them, and add them to the rewards that were actually gathered in the
+            // current state. If the next state is a terminal state, just omit the maximum
+            // estimated rewards for that state.
+            let expected_rewards = model.forward(&next_states)?.detach();
+            let y = expected_rewards.max_keepdim(1)?;
+            let y = (y * GAMMA * non_final_mask + rewards)?;
+
+            // Compare the estimated rewards with the maximum expected rewards and
+            // perform the backward step.
+            let loss = mse(&x, &y)?;
+            optimizer.backward_step(&loss)?;
+        }
+
+        // If we are in a terminal state, reset the environment and log how it went.
+        if step.terminated || step.truncated {
+            episode += 1;
+            println!("Episode {episode} | Rewards {}", accumulate_rewards as i64);
+            state = env.reset(0)?;
+            accumulate_rewards = 0.0;
+        }
+    }
+
+    Ok(())
+}
diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs
index 8868c188..a2b6652f 100644
--- a/candle-examples/examples/reinforcement-learning/gym_env.rs
+++ b/candle-examples/examples/reinforcement-learning/gym_env.rs
@@ -42,7 +42,7 @@ impl GymEnv {
     /// Creates a new session of the specified OpenAI Gym environment.
     pub fn new(name: &str) -> Result<GymEnv> {
         Python::with_gil(|py| {
-            let gym = py.import("gymnasium")?;
+            let gym = py.import_bound("gymnasium")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name,))?;
             let action_space = env.getattr("action_space")?;
@@ -66,10 +66,10 @@ impl GymEnv {
     /// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result { let state: Vec = Python::with_gil(|py| { - let kwargs = PyDict::new(py); + let kwargs = PyDict::new_bound(py); kwargs.set_item("seed", seed)?; - let state = self.env.call_method(py, "reset", (), Some(kwargs))?; - state.as_ref(py).get_item(0)?.extract() + let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?; + state.bind(py).get_item(0)?.extract() }) .map_err(w)?; Tensor::new(state, &Device::Cpu) @@ -81,8 +81,10 @@ impl GymEnv { action: A, ) -> Result> { let (state, reward, terminated, truncated) = Python::with_gil(|py| { - let step = self.env.call_method(py, "step", (action.clone(),), None)?; - let step = step.as_ref(py); + let step = self + .env + .call_method_bound(py, "step", (action.clone(),), None)?; + let step = step.bind(py); let state: Vec = step.get_item(0)?.extract()?; let reward: f64 = step.get_item(1)?.extract()?; let terminated: bool = step.get_item(2)?.extract()?; diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs index e87afae2..1a25cd93 100644 --- a/candle-examples/examples/reinforcement-learning/main.rs +++ b/candle-examples/examples/reinforcement-learning/main.rs @@ -13,6 +13,7 @@ mod gym_env; mod vec_gym_env; mod ddpg; +mod dqn; mod policy_gradient; #[derive(Parser)] @@ -25,6 +26,7 @@ struct Args { enum Command { Pg, Ddpg, + Dqn, } fn main() -> Result<()> { @@ -32,6 +34,7 @@ fn main() -> Result<()> { match args.command { Command::Pg => policy_gradient::run()?, Command::Ddpg => ddpg::run()?, + Command::Dqn => dqn::run()?, } Ok(()) } diff --git a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs index 8f8f30bd..e382ad76 100644 --- a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs +++ b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs @@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error { impl VecGymEnv { pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result { Python::with_gil(|py| { - let sys = py.import("sys")?; + let sys = py.import_bound("sys")?; let path = sys.getattr("path")?; let _ = path.call_method1( "append", ("candle-examples/examples/reinforcement-learning",), )?; - let gym = py.import("atari_wrappers")?; + let gym = py.import_bound("atari_wrappers")?; let make = gym.getattr("make")?; let env = make.call1((name, img_dir, nprocesses))?; let action_space = env.getattr("action_space")?; @@ -60,10 +60,10 @@ impl VecGymEnv { pub fn step(&self, action: Vec) -> Result { let (obs, reward, is_done) = Python::with_gil(|py| { - let step = self.env.call_method(py, "step", (action,), None)?; - let step = step.as_ref(py); + let step = self.env.call_method_bound(py, "step", (action,), None)?; + let step = step.bind(py); let obs = step.get_item(0)?.call_method("flatten", (), None)?; - let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?; + let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?; let obs: Vec = obs_buffer.to_vec(py)?; let reward: Vec = step.get_item(1)?.extract()?; let is_done: Vec = step.get_item(2)?.extract()?; diff --git a/candle-examples/examples/repvgg/main.rs b/candle-examples/examples/repvgg/main.rs index 0864c559..7cc90ba1 100644 --- a/candle-examples/examples/repvgg/main.rs +++ b/candle-examples/examples/repvgg/main.rs @@ -78,7 +78,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + 
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/resnet/main.rs b/candle-examples/examples/resnet/main.rs index 4a4592ad..bdf02fb1 100644 --- a/candle-examples/examples/resnet/main.rs +++ b/candle-examples/examples/resnet/main.rs @@ -45,7 +45,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs index e971a1cc..8fb2c0d4 100644 --- a/candle-examples/examples/rwkv/main.rs +++ b/candle-examples/examples/rwkv/main.rs @@ -141,7 +141,7 @@ impl std::fmt::Display for Which { impl Which { fn model_id(&self) -> &'static str { match self { - Self::Eagle7b => "RWKV/HF_v5-Eagle-7B", + Self::Eagle7b => "RWKV/v5-Eagle-7B-HF", Self::World1b5 => "RWKV/rwkv-5-world-1b5", Self::World3b => "RWKV/rwkv-5-world-3b", Self::World6_1b6 => "paperfun/rwkv", diff --git a/candle-examples/examples/segformer/main.rs b/candle-examples/examples/segformer/main.rs index 76c9f30e..16db62fc 100644 --- a/candle-examples/examples/segformer/main.rs +++ b/candle-examples/examples/segformer/main.rs @@ -5,7 +5,7 @@ use candle_transformers::models::segformer::{ Config, ImageClassificationModel, SemanticSegmentationModel, }; use clap::{Args, Parser, Subcommand}; -use image::Rgb; +use imageproc::image::Rgb; use imageproc::integral_image::ArrayData; use std::collections::HashMap; use std::path::PathBuf; diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 8c3ca2ee..0e39902b 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -96,6 +96,10 @@ struct Args { /// information. #[arg(long, default_value_t = 0.8)] img2img_strength: f64, + + /// The seed to use when generating random samples. + #[arg(long)] + seed: Option, } #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] @@ -288,6 +292,13 @@ fn text_embeddings( .map_err(E::msg)? .get_ids() .to_vec(); + if tokens.len() > sd_config.clip.max_position_embeddings { + anyhow::bail!( + "the prompt is too long, {} > max-tokens ({})", + tokens.len(), + sd_config.clip.max_position_embeddings + ) + } while tokens.len() < sd_config.clip.max_position_embeddings { tokens.push(pad_id) } @@ -315,6 +326,13 @@ fn text_embeddings( .map_err(E::msg)? .get_ids() .to_vec(); + if uncond_tokens.len() > sd_config.clip.max_position_embeddings { + anyhow::bail!( + "the negative prompt is too long, {} > max-tokens ({})", + uncond_tokens.len(), + sd_config.clip.max_position_embeddings + ) + } while uncond_tokens.len() < sd_config.clip.max_position_embeddings { uncond_tokens.push(pad_id) } @@ -374,6 +392,7 @@ fn run(args: Args) -> Result<()> { use_flash_attn, img2img, img2img_strength, + seed, .. 
} = args; @@ -427,6 +446,9 @@ fn run(args: Args) -> Result<()> { let scheduler = sd_config.build_scheduler(n_steps)?; let device = candle_examples::device(cpu)?; + if let Some(seed) = seed { + device.set_seed(seed)?; + } let use_guide_scale = guidance_scale > 1.0; let which = match sd_version { diff --git a/candle-examples/examples/stable-lm/README.md b/candle-examples/examples/stable-lm/README.md index 546124a2..6f5e7597 100644 --- a/candle-examples/examples/stable-lm/README.md +++ b/candle-examples/examples/stable-lm/README.md @@ -10,11 +10,6 @@ order to be able to use it. Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants. -StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by -Candle, so to run it you can download a somewhat compatible -[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true) -and pass it via the --tokenizer-file argument. - ## Running some example ```bash diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs index abe7020c..f0707010 100644 --- a/candle-examples/examples/stable-lm/main.rs +++ b/candle-examples/examples/stable-lm/main.rs @@ -239,14 +239,7 @@ fn main() -> Result<()> { )); let tokenizer_filename = match args.tokenizer_file { Some(file) => std::path::PathBuf::from(file), - None => match args.which { - Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => { - repo.get("tokenizer.json")? - } - Which::V2 | Which::V2Zephyr => api - .model("lmz/candle-stablelm".to_string()) - .get("tokenizer-gpt4.json")?, - }, + None => repo.get("tokenizer.json")?, }; let filenames = match args.weight_files { Some(files) => files @@ -295,12 +288,12 @@ fn main() -> Result<()> { }; let device = candle_examples::device(args.cpu)?; - let (model, device) = if args.quantized { + let model = if args.quantized { let filename = &filenames[0]; let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; let model = QStableLM::new(&config, vb)?; - (Model::Quantized(model), Device::Cpu) + Model::Quantized(model) } else { let dtype = if device.is_cuda() { DType::BF16 @@ -309,7 +302,7 @@ fn main() -> Result<()> { }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let model = StableLM::new(&config, vb)?; - (Model::StableLM(model), device) + Model::StableLM(model) }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 8ef108b6..902282c1 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -12,12 +12,23 @@ use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; -use clap::Parser; +use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; const DTYPE: DType = DType::F32; +#[derive(Clone, Debug, Copy, ValueEnum)] +enum Which { + T5Base, + T5Small, + T5Large, + T5_3B, + Mt5Base, + Mt5Small, + Mt5Large, +} + #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] struct Args { @@ -36,6 +47,15 @@ struct Args { #[arg(long)] revision: Option, + #[arg(long)] + model_file: Option, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + config_file: Option, + /// Enable decoding. 
#[arg(long)] decode: bool, @@ -71,6 +91,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// The model to be used. + #[arg(long, default_value = "t5-small")] + which: Which, } struct T5ModelBuilder { @@ -82,8 +106,17 @@ struct T5ModelBuilder { impl T5ModelBuilder { pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { let device = candle_examples::device(args.cpu)?; - let default_model = "t5-small".to_string(); - let default_revision = "refs/pr/15".to_string(); + let (default_model, default_revision) = match args.which { + Which::T5Base => ("t5-base", "main"), + Which::T5Small => ("t5-small", "refs/pr/15"), + Which::T5Large => ("t5-large", "main"), + Which::T5_3B => ("t5-3b", "main"), + Which::Mt5Base => ("google/mt5-base", "refs/pr/5"), + Which::Mt5Small => ("google/mt5-small", "refs/pr/6"), + Which::Mt5Large => ("google/mt5-large", "refs/pr/2"), + }; + let default_model = default_model.to_string(); + let default_revision = default_revision.to_string(); let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), @@ -93,14 +126,35 @@ impl T5ModelBuilder { let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision); let api = Api::new()?; - let api = api.repo(repo); - let config_filename = api.get("config.json")?; - let tokenizer_filename = api.get("tokenizer.json")?; - let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" - { - candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? - } else { - vec![api.get("model.safetensors")?] + let repo = api.repo(repo); + let config_filename = match &args.config_file { + None => repo.get("config.json")?, + Some(f) => f.into(), + }; + let tokenizer_filename = match &args.tokenizer_file { + None => match args.which { + Which::Mt5Base => api + .model("lmz/mt5-tokenizers".into()) + .get("mt5-base.tokenizer.json")?, + Which::Mt5Small => api + .model("lmz/mt5-tokenizers".into()) + .get("mt5-small.tokenizer.json")?, + Which::Mt5Large => api + .model("lmz/mt5-tokenizers".into()) + .get("mt5-large.tokenizer.json")?, + _ => repo.get("tokenizer.json")?, + }, + Some(f) => f.into(), + }; + let weights_filename = match &args.model_file { + Some(f) => f.split(',').map(|v| v.into()).collect::>(), + None => { + if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" { + candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? + } else { + vec![repo.get("model.safetensors")?] 
+ } + } }; let config = std::fs::read_to_string(config_filename)?; let mut config: t5::Config = serde_json::from_str(&config)?; diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs index 27e141cb..e7bfe7d2 100644 --- a/candle-examples/examples/vgg/main.rs +++ b/candle-examples/examples/vgg/main.rs @@ -33,7 +33,7 @@ struct Args { pub fn main() -> anyhow::Result<()> { let args = Args::parse(); let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); diff --git a/candle-examples/examples/vit/main.rs b/candle-examples/examples/vit/main.rs index 168caf9e..b38bae15 100644 --- a/candle-examples/examples/vit/main.rs +++ b/candle-examples/examples/vit/main.rs @@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); let model_file = match args.model { diff --git a/candle-examples/examples/whisper/README.md b/candle-examples/examples/whisper/README.md index 124cd182..a7dd4081 100644 --- a/candle-examples/examples/whisper/README.md +++ b/candle-examples/examples/whisper/README.md @@ -34,6 +34,7 @@ from the hub. - `--timestamps`: enable the timestamp mode where some timestamps are reported for each recognized audio extracts. - `--model`: the model to be used. Models that do not end with `-en` are - multilingual models, other ones are English only models. The supported models - are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`, - `medium.en`, `large`, and `large-v2`. + multilingual models, other ones are English only models. The supported OpenAI + Whisper models are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, + `medium`, `medium.en`, `large`, `large-v2` and `large-v3`. The supported + Distil-Whisper models are `distil-medium.en`, `distil-large-v2` and `distil-large-v3`. 
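Aside: several of the vision examples touched in this patch (repvgg, resnet, vgg, vit) now call `.to_device(&device)?` on the decoded image, since `load_image224` produces a CPU tensor that has to live on the compute device before the forward pass. Below is a minimal sketch of that pattern, assuming the current `candle_examples::imagenet::load_image224` signature; the `load_image_on_device` helper name is ours, not part of the patch.

```rust
use candle::{Device, Result, Tensor};

// Sketch only: decode the 224x224 image on the CPU, then move it to the
// selected device (CPU, CUDA or Metal) so inference runs there.
fn load_image_on_device(path: &str, device: &Device) -> Result<Tensor> {
    let image = candle_examples::imagenet::load_image224(path)?;
    image.to_device(device)
}
```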
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index da8c73ae..ecd5ff84 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -374,6 +374,8 @@ enum WhichModel { DistilMediumEn, #[value(name = "distil-large-v2")] DistilLargeV2, + #[value(name = "distil-large-v3")] + DistilLargeV3, } impl WhichModel { @@ -386,7 +388,8 @@ impl WhichModel { | Self::Large | Self::LargeV2 | Self::LargeV3 - | Self::DistilLargeV2 => true, + | Self::DistilLargeV2 + | Self::DistilLargeV3 => true, Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => { false } @@ -408,6 +411,7 @@ impl WhichModel { Self::LargeV3 => ("openai/whisper-large-v3", "main"), Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"), Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"), + Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"), } } } diff --git a/candle-examples/examples/yolo-v8/assets/bike.pp.jpg b/candle-examples/examples/yolo-v8/assets/bike.pp.jpg new file mode 100644 index 00000000..a46b8e84 Binary files /dev/null and b/candle-examples/examples/yolo-v8/assets/bike.pp.jpg differ diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index c65a5ca1..eb338647 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -99,7 +99,7 @@ pub fn report_detect( let h_ratio = initial_h as f32 / h as f32; let mut img = img.to_rgb8(); let font = Vec::from(include_bytes!("roboto-mono-stripped.ttf") as &[u8]); - let font = rusttype::Font::try_from_vec(font); + let font = ab_glyph::FontRef::try_from_slice(&font).map_err(candle::Error::wrap)?; for (class_index, bboxes_for_class) in bboxes.iter().enumerate() { for b in bboxes_for_class.iter() { println!( @@ -119,27 +119,28 @@ pub fn report_detect( ); } if legend_size > 0 { - if let Some(font) = font.as_ref() { - imageproc::drawing::draw_filled_rect_mut( - &mut img, - imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size), - image::Rgb([170, 0, 0]), - ); - let legend = format!( - "{} {:.0}%", - candle_examples::coco_classes::NAMES[class_index], - 100. * b.confidence - ); - imageproc::drawing::draw_text_mut( - &mut img, - image::Rgb([255, 255, 255]), - xmin, - ymin, - rusttype::Scale::uniform(legend_size as f32 - 1.), - font, - &legend, - ) - } + imageproc::drawing::draw_filled_rect_mut( + &mut img, + imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size), + image::Rgb([170, 0, 0]), + ); + let legend = format!( + "{} {:.0}%", + candle_examples::coco_classes::NAMES[class_index], + 100. * b.confidence + ); + imageproc::drawing::draw_text_mut( + &mut img, + image::Rgb([255, 255, 255]), + xmin, + ymin, + ab_glyph::PxScale { + x: legend_size as f32 - 1., + y: legend_size as f32 - 1., + }, + &font, + &legend, + ) } } } diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 29d72cd7..827cf970 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.4.1" +version = "0.5.0" edition = "2021" description = "Flash attention layer for the candle ML framework." 
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" } +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.0" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index ab059f89..5cedb7d3 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-kernels" -version = "0.4.1" +version = "0.5.0" edition = "2021" description = "CUDA kernels for Candle" diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 63d744ca..c28abd97 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -1,5 +1,8 @@ fn main() { println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=src/compatibility.cuh"); + println!("cargo:rerun-if-changed=src/cuda_utils.cuh"); + println!("cargo:rerun-if-changed=src/binary_op_macros.cuh"); let builder = bindgen_cuda::Builder::default(); println!("cargo:info={builder:?}"); diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 152b9463..540d0819 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \ ) { \ const size_t *dims = info; \ const size_t *strides = info + num_dims; \ - if (is_contiguous(num_dims, dims, strides)) { \ + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? inp[i] : out[i]; \ out[i] = x * mul + add; \ diff --git a/candle-kernels/src/binary_op_macros.cuh b/candle-kernels/src/binary_op_macros.cuh index 05d0c3df..9cb00874 100644 --- a/candle-kernels/src/binary_op_macros.cuh +++ b/candle-kernels/src/binary_op_macros.cuh @@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \ const size_t *dims = dims_and_strides; \ const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \ const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \ - bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \ - bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \ + bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \ + bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \ if (lhs_cont && rhs_cont) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = lhs[i]; \ diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 024642c6..90f5e7ba 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -11,7 +11,7 @@ __device__ void cast_( ) { const size_t *dims = info; const size_t *strides = info + num_dims; - if (is_contiguous(num_dims, dims, strides)) { + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { out[i] = inp[i]; } @@ -34,7 +34,7 @@ __device__ void cast_through( ) { const size_t *dims = info; const size_t *strides = info + num_dims; - if (is_contiguous(num_dims, dims, strides)) { + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { out[i] = 
static_cast(static_cast(inp[i])); } @@ -83,6 +83,18 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +#else +#include +#if CUDA_VERSION >= 11000 +CAST_OP(__nv_bfloat16, float, cast_bf16_f32) +CAST_OP(float, __nv_bfloat16, cast_f32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) +CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) +CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) +CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) +CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +#endif #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index b0a85249..2673b8aa 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -14,7 +14,7 @@ __device__ bool is_contiguous( size_t acc = 1; for (unsigned int d = 0; d < num_dims; d++) { unsigned int dim_idx = num_dims - 1 - d; - if (acc != strides[dim_idx]) { + if (dims[dim_idx] > 1 && acc != strides[dim_idx]) { return false; } acc *= dims[dim_idx]; diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index 883ca072..ca448d98 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -10,11 +10,39 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } -extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } +template +__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= d1 * d2) { + return; + } + uint32_t idx1 = idx / d2; + uint32_t idx2 = idx - d2 * idx1; + dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2]; +} + +#define COPY2D_OP(TYPENAME, FNNAME) \ +extern "C" __global__ \ +void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \ + copy2d(src, dst, d1, d2, src_s, dst_s); \ +} \ + +COPY2D_OP(float, copy2d_f32) +COPY2D_OP(double, copy2d_f64) +COPY2D_OP(uint8_t, copy2d_u8) +COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int64_t, copy2d_i64) + +#if __CUDA_ARCH__ >= 530 +extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__half, copy2d_f16) +#endif + #if __CUDA_ARCH__ >= 800 #include extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_bfloat16, copy2d_bf16) #endif diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8fc69363..8af2954d 100644 --- a/candle-kernels/src/indexing.cu +++ 
b/candle-kernels/src/indexing.cu @@ -168,8 +168,10 @@ IS_OP(__half, uint8_t, is_u8_f16) GATHER_OP(__half, int64_t, gather_i64_f16) GATHER_OP(__half, uint32_t, gather_u32_f16) GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int64_t, ia_i64_f16) IA_OP(__half, uint32_t, ia_u32_f16) IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int64_t, sa_i64_f16) SA_OP(__half, uint32_t, sa_u32_f16) SA_OP(__half, uint8_t, sa_u8_f16) #endif diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index f8becbbc..f91dbb32 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -23,6 +23,22 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +static __device__ __forceinline__ float warp_reduce_max(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +} + static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment @@ -233,57 +249,6 @@ typedef struct { static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); -// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called -// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q - -#define VDR_Q4_0_Q8_1_MMVQ 2 -#define VDR_Q4_0_Q8_1_MMQ 4 - -template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( - const int * v, const int * u, const float & d4, const half2 & ds8) { - - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; - const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; - - // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); - } - - const float2 ds8f = __half22float2(ds8); - - // second part effectively subtracts 8 from each quant value - const float res = d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); - printf("%f %f %f %f %f %f\n", res, d4, sumi, ds8f.x, vdr/QI4_0, ds8f.y); - return res; -} - - -static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - (void)x_qh; (void)x_sc; - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; - - int u[2*VDR_Q4_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; - } - - return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - template static __device__ __forceinline__ void mul_mat_q( @@ -447,30 +412,6 @@ template static __device__ __forceinline__ void allocate_tiles_q4_0( *x_dm = (half2 *) tile_x_d; } -extern "C" __global__ void 
mul_mat_q4_0_check( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - const int mmq_x = MMQ_X_Q4_0_AMPERE; - const int mmq_y = MMQ_Y_Q4_0_AMPERE; - const int nwarps = NWARPS_Q4_0_AMPERE; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -} - -extern "C" __global__ void mul_mat_q4_0_no_check( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - const int mmq_x = MMQ_X_Q4_0_AMPERE; - const int mmq_y = MMQ_Y_Q4_0_AMPERE; - const int nwarps = NWARPS_Q4_0_AMPERE; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -} - static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; @@ -1595,3 +1536,1001 @@ extern "C" __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ dst[row] = tmp; } } + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; +#else + const float2 dm4f = __half22float2(dm4); + const float2 ds8f = __half22float2(ds8); + const float d4d8 = dm4f.x * ds8f.x; + const float m4s8 = dm4f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] 
<< 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; +#else + const float2 dm5f = __half22float2(dm5); + const float2 ds8f = __half22float2(ds8); + const float d5d8 = dm5f.x * ds8f.x; + const float m5s8 = dm5f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const float & d8_0, const float & d8_1) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * sumi; +} + +template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; +#else + const float2 dm8f = __half22float2(dm8); + const float2 ds8f = __half22float2(ds8); + const float d8d8 = dm8f.x * ds8f.x; + const float m8s8 = dm8f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int & v, const int * __restrict__ u, const uint8_t * 
__restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float & d8) { + + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const float2 dm2f = __half22float2(dm2); + + return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { + + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + 
+ const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + const float2 dm5f = __half22float2(dm5); + + return dm5f.x*sumf_d - dm5f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = 
(vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); +} + + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = 
get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. 
bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + +#else + + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; + + const float d8_1 = __low2float(bq8_1[0].ds); + const float d8_2 = __low2float(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * q4 = (const int *)bq4_K->qs + (iqs/2); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; +#endif +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; 
+ +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); + +#else + + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = __low2half(bq8_1[0].ds); + const float d8_2 = __low2half(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * ql = (const int *)bq5_K->qs + (iqs/2); + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * (iqs/2); // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; +#endif +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + +// https://github.com/ggerganov/llama.cpp/blob/c50a82ce0f71558cbb8e555146ba124251504b38/ggml-cuda/mmvq.cu#L4 +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); + +template +static __device__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int nwarps = 1; + constexpr int rows_per_cuda_block = 1; +#else + constexpr int nwarps = ncols_y <= 4 ? 4 : 2; + constexpr int rows_per_cuda_block = ncols_y == 1 ? 
1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; + +// partial sum for each thread + float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx + + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda( + &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs); + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE]; + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + // sum up partial sums and write back result +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + } + tmp[j][i] = warp_reduce_sum(tmp[j][i]); + } + + if (threadIdx.x < rows_per_cuda_block) { + dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + } + } +} + +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + 
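+// NOTE (explanatory aside; not part of the original change): each of these wrappers
+// instantiates mul_mat_vec_q with ncols_y = 1 (a plain matrix-vector product), the block
+// geometry of the source quantization (qk quants and qi ints per block), the block struct,
+// the per-iteration dot width vdr, and the matching vec_dot_*_q8_1 routine. With
+// ncols_y = 1 the non-HIP path picks nwarps = 4 and rows_per_cuda_block = 1, so the host
+// side is expected to launch roughly one (WARP_SIZE x 4) block per output row, e.g.
+// (illustrative sketch only; the actual Rust-side launch code is not shown in this diff):
+//
+//     dim3 block_dims(WARP_SIZE, 4, 1);   // 32 lanes x 4 warps
+//     dim3 grid_dims(nrows_x, 1, 1);      // one block per row of the quantized matrix
+//     mul_mat_vec_q4_0_q8_1_cuda<<<grid_dims, block_dims>>>(
+//         vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);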
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; + + if (ix >= kx_padded) { + return; + } + + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + + const int i_padded = iy*kx_padded + ix; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index + + const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index fca6865e..4dbd8dcc 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -2,6 +2,7 @@ #include #include +#define WARP_SIZE 32 const int BLOCK_SIZE = 1024; // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 @@ -49,6 +50,59 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block, dst[dst_id] = shr[0]; } +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +// RmsNorm implementation adapted from ggml, accumulation is made using f32. 
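+// Concretely, each row computes dst[i] = alpha[i] * x[i] / sqrt(mean(x^2) + eps)
+// (alpha == nullptr skips the scaling); the sum of squares is reduced within each warp via
+// __shfl_xor_sync and then across warps through the small s_sum shared buffer.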
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523 +template +__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + const int block_size = blockDim.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = static_cast(x[row*ncols + col]); + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + if (alpha == nullptr) { + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = static_cast(scale * static_cast(x[row*ncols + col])); + } + } + else { + for (int col = tid; col < ncols; col += block_size) { + float a = static_cast(alpha[col]); + dst[row*ncols + col] = static_cast(scale * static_cast(x[row*ncols + col]) * a); + } + } +} + // Softmax implementation adapted from ggml. // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159 template @@ -93,6 +147,65 @@ __device__ void softmax(const T * x, T * dst, const int ncols) { } } +template +__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx >= bh * td) return; + + uint32_t rope_idx = idx % (td / 2); + T c = cos[rope_idx]; + T s = sin[rope_idx]; + + dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s; + dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c; +} + +template +__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx >= bh * td) return; + + uint32_t i_bh = idx / (td / 2); + uint32_t i_td = idx - (td / 2) * i_bh; + uint32_t i_t = i_td / (d / 2); + uint32_t i_d = i_td - (d / 2) * i_t; + uint32_t i1 = i_bh * td + i_t * d + i_d; + uint32_t i2 = i1 + d / 2; + uint32_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +template +__device__ void rope_thd( + const T * src, + const T * cos, + const T * sin, + T * dst, + const uint32_t b, + const uint32_t t, + const uint32_t h, + const uint32_t d +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx >= b * t * h * d) return; + + uint32_t i_bth = idx / (d / 2); + uint32_t i_d = idx - (d / 2) * i_bth; + uint32_t i_t = (i_bth / h) % t; + uint32_t i1 = i_bth * d + i_d; + uint32_t i2 = i1 + d / 2; + uint32_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + template __device__ void fast_max(const size_t src_numel, const size_t el_to_sum_per_block, @@ -341,14 +454,57 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, softmax(src, dst, n_cols); \ } \ +#define RMSNORM_OP(TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, TYPENAME *dst, 
const TYPENAME *alpha, \ + const int n_cols, const float eps) { \ + rmsnorm(src, dst, alpha, n_cols, eps); \ + } \ + +#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \ + extern "C" __global__ void FN_NAME_I( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t bh, \ + const uint32_t td) { \ + ropei(src, cos, sin, dst, bh, td); \ + } \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t bh, \ + const uint32_t td, \ + const uint32_t d) { \ + rope(src, cos, sin, dst, bh, td, d); \ + } \ + extern "C" __global__ void FN_NAME_THD( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t b, \ + const uint32_t t, \ + const uint32_t h, \ + const uint32_t d) { \ + rope_thd(src, cos, sin, dst, b, t, h, d); \ + } \ + #if __CUDA_ARCH__ >= 800 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16) +RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16) +ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) SUM_OP(__nv_bfloat16, sum_bf16) FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) #endif #if __CUDA_ARCH__ >= 530 SOFTMAX_OP(__half, float, softmax_f16) +RMSNORM_OP(__half, rmsnorm_f16) +ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16) SUM_OP(__half, sum_f16) FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16) #endif @@ -358,6 +514,10 @@ SUM_OP(double, sum_f64) SUM_OP(uint32_t, sum_u32) SOFTMAX_OP(float, float, softmax_f32) SOFTMAX_OP(double, double, softmax_f64) +RMSNORM_OP(float, rmsnorm_f32) +RMSNORM_OP(double, rmsnorm_f64) +ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32) +ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index 74ba1fac..a234304a 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \ ) { \ const size_t *dims = info; \ const size_t *strides = info + num_dims; \ - if (is_contiguous(num_dims, dims, strides)) { \ + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? inp[i] : out[i]; \ out[i] = FUNC; \ @@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \ ) { \ const size_t *dims = info; \ const size_t *strides = info + num_dims; \ - if (is_contiguous(num_dims, dims, strides)) { \ + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? 
inp[i] : out[i]; \ out[i] = FUNC; \ @@ -86,6 +86,11 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +template +__device__ T sign_(T t) { + return static_cast(t > static_cast(0)) - static_cast(t < static_cast(0)); +} + #if __CUDA_ARCH__ >= 800 UNARY_OP(__nv_bfloat16, ucopy_bf16, x) @@ -110,6 +115,7 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x)) UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param)) UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) +UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) #endif #if __CUDA_ARCH__ >= 530 @@ -135,6 +141,7 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x)) UNARY_OP1(__half, uelu_f16, elu_fwd(x, param)) UNARY_OP(__half, usilu_f16, silu_fwd(x)) UNARY_OP1(__half, upowf_f16, powg(x, param)) +UNARY_OP(__half, usign_f16, sign_(x)) #endif UNARY_OP(uint8_t, ucopy_u8, x) @@ -184,3 +191,5 @@ UNARY_OP(float, usilu_f32, silu_fwd(x)) UNARY_OP(double, usilu_f64, silu_fwd(x)) UNARY_OP1(float, upowf_f32, powg(x, param)) UNARY_OP1(double, upowf_f64, powg(x, param)) +UNARY_OP(float, usign_f32, sign_(x)) +UNARY_OP(double, usign_f64, sign_(x)) diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index a2837ddb..65e00bbc 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.4.1" +version = "0.5.0" edition = "2021" description = "Metal kernels for Candle" diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index a4484998..76c0365a 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -89,7 +89,7 @@ kernel void FN_NAME( \ return; \ } \ const TYPENAME x = input[id]; \ - output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ + output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index ae11286a..e83498e4 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); -#define INT64_BINARY_OP(NAME, FN) \ -BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); - -#define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); - #define BINARY_OP_OUT(NAME, FN) \ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +#define INT64_BINARY_OP(NAME, FN) \ +BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); + #define INT64_BINARY_OP_OUT(NAME, FN) \ BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided); +#define BFLOAT_BINARY_OP(FN, NAME) \ +BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); + +#define BFLOAT_BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided); + BINARY_OP(x + y, add) BINARY_OP(x - y, sub) BINARY_OP(x * y, mul) @@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul) BFLOAT_BINARY_OP(x / y, div) BFLOAT_BINARY_OP(MIN(x, y), min) BFLOAT_BINARY_OP(MAX(x, y), max) + +BFLOAT_BINARY_OP_OUT(eq, x == y) +BFLOAT_BINARY_OP_OUT(ne, x != y) 
+BFLOAT_BINARY_OP_OUT(le, x <= y) +BFLOAT_BINARY_OP_OUT(lt, x < y) +BFLOAT_BINARY_OP_OUT(ge, x >= y) +BFLOAT_BINARY_OP_OUT(gt, x > y) #endif diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 9aead139..2af3fdce 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \ output[tid] = static_cast(static_cast(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ +// u32 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) -CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) -CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) -CAST(cast_f16_f32, cast_f16_f32_strided, half, float) -CAST(cast_f32_f16, cast_f32_f16_strided, float, half) - +CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) #if __METAL_VERSION__ >= 220 -CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) -CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) #endif +// u8 +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) +CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) +#if __METAL_VERSION__ >= 220 +CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +#endif + +// f16 +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) +CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) +#endif + +// i64 +CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) +CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) +#endif + +// f32 +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) +CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) +CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) +#endif + +// bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) -CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) -CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) -CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) - CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) -CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) #endif \ No newline at end of file diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index dca53161..8fdd0e5f 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -1,3 +1,9 @@ +#include + +using namespace metal; + +#define MAX(x, 
y) ((x) > (y) ? (x) : (y)) + template METAL_FUNC void im2col( constant size_t &dst_numel, @@ -200,14 +206,331 @@ kernel void FN_NAME( \ upsample_nearest2d(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \ } \ +template +METAL_FUNC void avg_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + d += static_cast(src[src_idx]); + } + } + dst[tid] = static_cast(d / (w_k * h_k)); +} + +#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + avg_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + +template +METAL_FUNC void max_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + T d = 0; + bool set = false; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + if (set) { + d = MAX(d, src[src_idx]); + } + else { + d = src[src_idx]; + set = true; + } + } + } + dst[tid] = d; +} + +#define MAXPOOL2D_OP(TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + 
device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + max_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + + +// Naive implementation of conv_transpose1d. +template +METAL_FUNC void conv_transpose1d( + constant size_t &l_out, + constant size_t &stride, + constant size_t &padding, + constant size_t &out_padding, + constant size_t &dilation, + constant size_t *src_dims, + constant size_t *src_strides, + constant size_t *k_dims, + constant size_t *k_strides, + device const T *src, + device const T *k, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + // src: (b_size, c_in, l_in) + // kernel: (c_in, c_out, l_k) + const size_t l_k = k_dims[2]; + const size_t c_out = k_dims[1]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + if (tid >= src_dims[0] * c_out * l_out) { + return; + } + + const size_t b_idx = tid / (l_out * c_out); + const size_t dst_c_idx = (tid / l_out) % c_out; + const size_t out_x = tid % l_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (int k_x = 0; k_x < (int)l_k; ++k_x) { + // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding; + int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + int inp_x = inp_x_stride / stride; + if (inp_x >= l_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2]; + const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2]; + d += static_cast(src[src_idx]) * static_cast(k[k_idx]); + } + } + dst[tid] = static_cast(d); +} + +#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &l_out, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &out_padding, \ + constant size_t &dilation, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ + constant size_t *k_dims, \ + constant size_t *k_strides, \ + device const TYPENAME *src, \ + device const TYPENAME *k, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + conv_transpose1d(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \ +} \ + +template +METAL_FUNC void conv_transpose2d( + constant size_t &w_out, + constant size_t &h_out, + constant size_t &stride, + constant size_t &padding, + constant size_t &out_padding, + constant size_t &dilation, + constant size_t *input_dims, + constant size_t *input_stride, + constant size_t *k_dims, + constant size_t *k_stride, + device const T *src, + device const T *k, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; + const size_t c_out = k_dims[1]; + const size_t c_in = input_dims[1]; + const size_t h_in = input_dims[2]; + const size_t w_in = input_dims[3]; + + if (tid >= input_dims[0] * c_out * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c_out); + const size_t dst_c_idx = (tid / (w_out * h_out)) % c_out; + const size_t out_y = (tid / w_out) % h_out; + const size_t out_x = tid % w_out; + + const size_t src_idx0 = b_idx * input_stride[0]; + + A d = 0; + for (int k_x = 0; k_x < (int)w_k; ++k_x) { + const int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + const int inp_x = inp_x_stride / stride; + 
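+        // NOTE (explanatory aside; not in the upstream kernel): the transposed convolution
+        // is computed as a gather. A forward conv maps inp_x -> out_x = inp_x * stride +
+        // k_x * dilation - padding, so each output column only receives contributions from
+        // the inp_x for which (out_x + padding - k_x * dilation) is a non-negative multiple
+        // of the stride; the same test is applied to inp_y below. out_padding only enters
+        // through the w_out/h_out sizes computed on the host, which is presumably why it is
+        // unused in the loop body.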
if (inp_x >= w_in) continue;
+        for (int k_y = 0; k_y < (int)h_k; ++k_y) {
+            const int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
+            if (inp_y_stride < 0 || inp_y_stride % stride) {
+                continue;
+            }
+            const int inp_y = inp_y_stride / stride;
+            if (inp_y >= h_in) continue;
+            for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+                const size_t src_idx = src_idx0 + src_c_idx * input_stride[1] + inp_y * input_stride[2] + inp_x * input_stride[3];
+                const size_t k_idx = src_c_idx * k_stride[0] + dst_c_idx * k_stride[1] + k_y * k_stride[2] + k_x * k_stride[3];
+                d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
+            }
+        }
+    }
+    dst[tid] = static_cast<T>(d);
+}
+
+#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &w_out, \
+    constant size_t &h_out, \
+    constant size_t &stride, \
+    constant size_t &padding, \
+    constant size_t &out_padding, \
+    constant size_t &dilation, \
+    constant size_t *input_dims, \
+    constant size_t *input_stride, \
+    constant size_t *k_dims, \
+    constant size_t *k_stride, \
+    device const TYPENAME *src, \
+    device const TYPENAME *k, \
+    device TYPENAME *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, input_dims, input_stride, k_dims, k_stride, src, k, dst, tid); \
+} \
+
 IM2COL_OP(float, im2col_f32)
+IM2COL_OP(half, im2col_f16)
 IM2COL_OP(uint8_t, im2col_u8)
 IM2COL_OP(uint32_t, im2col_u32)
+#if defined(__HAVE_BFLOAT__)
+IM2COL_OP(bfloat, im2col_bf16)
+#endif
 IM2COL1D_OP(float, im2col1d_f32)
 IM2COL1D_OP(uint8_t, im2col1d_u8)
 IM2COL1D_OP(uint32_t, im2col1d_u32)
 UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
+UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
 UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
 UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
+#if defined(__HAVE_BFLOAT__)
+UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16)
+#endif
+
+MAXPOOL2D_OP(float, max_pool2d_f32)
+MAXPOOL2D_OP(half, max_pool2d_f16)
+MAXPOOL2D_OP(uint32_t, max_pool2d_u32)
+MAXPOOL2D_OP(uint8_t, max_pool2d_u8)
+#if defined(__HAVE_BFLOAT__)
+MAXPOOL2D_OP(bfloat, max_pool2d_bf16)
+#endif
+
+AVGPOOL2D_OP(float, float, avg_pool2d_f32)
+AVGPOOL2D_OP(half, float, avg_pool2d_f16)
+AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
+AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
+#if defined(__HAVE_BFLOAT__)
+AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
+#endif
+
+CONVT1D_OP(float, float, conv_transpose1d_f32)
+CONVT1D_OP(half, float, conv_transpose1d_f16)
+CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
+CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
+#if defined(__HAVE_BFLOAT__)
+CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
+#endif
+
+CONVT2D_OP(float, float, conv_transpose2d_f32)
+CONVT2D_OP(half, float, conv_transpose2d_f16)
+#if defined(__HAVE_BFLOAT__)
+CONVT2D_OP(bfloat, float, conv_transpose2d_bf16)
+#endif
\ No newline at end of file
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 2a57bdbb..9eee97ca 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -1,20 +1,38 @@
 #include <metal_stdlib>
 using namespace metal;
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return
strided_i; +} + template METAL_FUNC void index( constant size_t &dst_size, constant size_t &left_size, constant size_t &src_dim_size, constant size_t &right_size, - constant size_t &ids_size, - const device TYPENAME *input, + constant size_t &ids_size, + constant bool &contiguous, + constant size_t *src_dims, + constant size_t *src_strides, + const device TYPENAME *input, const device INDEX_TYPENAME *input_ids, device TYPENAME *output, uint tid [[ thread_position_in_grid ]] ) { if (tid >= dst_size) { - return; + return; } const size_t id_i = (tid / right_size) % ids_size; const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); @@ -26,7 +44,8 @@ METAL_FUNC void index( // No need to check for zero we're only allowing unsized. */ const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; - output[tid] = input[src_i]; + const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); + output[tid] = input[strided_src_i]; } # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ @@ -36,12 +55,15 @@ kernel void NAME( \ constant size_t &src_dim_size, \ constant size_t &right_size, \ constant size_t &ids_size, \ + constant bool &contiguous, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ const device TYPENAME *input, \ const device INDEX_TYPENAME *input_ids, \ device TYPENAME *output, \ uint tid [[ thread_position_in_grid ]] \ ) { \ - index(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ + index(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \ } @@ -165,37 +187,68 @@ kernel void NAME( \ } -INDEX_OP(is_u32_f32, uint, float) -INDEX_OP(is_u32_f16, uint, half) -GATHER_OP(gather_u32_f32, uint, float) -GATHER_OP(gather_u32_f16, uint, half) -SCATTER_ADD_OP(sa_u32_f32, uint, float) -SCATTER_ADD_OP(sa_u32_f16, uint, half) - - +INDEX_OP(is_i64_f32, int64_t, float) +INDEX_OP(is_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) -INDEX_OP(is_u32_bf16, uint32_t, bfloat) -INDEX_OP(is_u8_bf16, uint8_t, bfloat) - -INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) -INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) -INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) +INDEX_OP(is_i64_bf16, int64_t, bfloat) #endif -INDEX_ADD_OP(ia_u32_f16, uint32_t, half) -INDEX_ADD_OP(ia_u8_f16, uint8_t, half) +INDEX_OP(is_u32_f32, uint32_t, float) +INDEX_OP(is_u32_f16, uint32_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +#endif +INDEX_OP(is_u8_f32, uint8_t, float) +INDEX_OP(is_u8_f16, uint8_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) +#endif + +GATHER_OP(gather_u32_f32, uint, float) +GATHER_OP(gather_u32_f16, uint, half) +#if defined(__HAVE_BFLOAT__) +GATHER_OP(gather_u32_bf16, uint, bfloat) +#endif + +SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) +SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i64_f32, int64_t, float) +SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) +SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i64_f16, int64_t, half) +#if defined(__HAVE_BFLOAT__) +SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) +SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat) +SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) +#endif + +// i64 +INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) -INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, 
int64_t, uint32_t) +INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) +#endif +// u32 +INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) -INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) +INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) +#endif +// u8 +INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) -INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) -INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) +INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) +INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) +#endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f76af4cb..ab74daa1 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,11 +1,15 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, + Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function, + FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; +mod utils; +pub use utils::BufferOffset; +use utils::{get_block_dims, linear_split}; + const AFFINE: &str = include_str!("affine.metal"); const INDEXING: &str = include_str!("indexing.metal"); const UNARY: &str = include_str!("unary.metal"); @@ -18,100 +22,6 @@ const RANDOM: &str = include_str!("random.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); -/// Most kernels apply similarly across the tensors -/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the -/// actual total buffer length). -/// Then kernels can just do their op on their single point in the buffer. -fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { - let size = length as u64; - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); - let count = (size + width - 1) / width; - let thread_group_count = MTLSize { - width: count, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - (thread_group_count, thread_group_size) -} - -fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { -
<P as EncoderParam>
::set_param(encoder, position, data) -} - -/// Helper functions to create the various objects on the compute command encoder -/// on a single line. -/// Prevents getting wrong some arguments number and mixing length and size in bytes. -trait EncoderParam { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); -} -macro_rules! primitive { - ($type:ty) => { - impl EncoderParam for $type { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_bytes( - position, - core::mem::size_of::<$type>() as u64, - &data as *const $type as *const c_void, - ); - } - } - }; -} -primitive!(bool); -primitive!(usize); -primitive!(i32); -primitive!(i64); -primitive!(u32); -primitive!(u64); -primitive!(f32); - -impl EncoderParam for &[T] { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_bytes( - position, - core::mem::size_of_val(data) as u64, - data.as_ptr() as *const c_void, - ); - } -} - -impl EncoderParam for &Buffer { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data), 0); - } -} -impl EncoderParam for (&Buffer, usize) { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data.0), data.1 as u64); - } -} -impl EncoderParam for &mut Buffer { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data), 0); - } -} -impl EncoderParam for (&mut Buffer, usize) { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data.0), data.1 as u64); - } -} - -macro_rules! set_params { - ($encoder:ident, ($($param:expr),+)) => ( - let mut _index = 0; - $( - set_param($encoder, _index, $param); - _index += 1; - )* - ); -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -127,6 +37,16 @@ pub enum Source { Quantized, } +pub mod copy2d { + pub struct Kernel(pub &'static str); + pub const FLOAT: Kernel = Kernel("copy2d_f32"); + pub const HALF: Kernel = Kernel("copy2d_f16"); + pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); + pub const I64: Kernel = Kernel("copy2d_i64"); + pub const U32: Kernel = Kernel("copy2d_u32"); + pub const U8: Kernel = Kernel("copy2d_u8"); +} + macro_rules! ops{ ($($name:ident),+) => { @@ -183,7 +103,7 @@ macro_rules! 
ops{ pub mod unary { ops!( cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, - tanh, recip, silu + tanh, recip, silu, sign ); } pub mod binary { @@ -225,6 +145,12 @@ pub struct Kernels { pipelines: RwLock, } +impl Default for Kernels { + fn default() -> Self { + Self::new() + } +} + impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); @@ -348,23 +274,66 @@ pub fn call_unary_contiguous( kernels: &Kernels, kernel_name: unary::contiguous::Kernel, length: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_copy2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: copy2d::Kernel, + input: &Buffer, + output: &Buffer, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o_in_bytes: usize, + dst_o_in_bytes: usize, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + d1 as i64, + d2 as i64, + src_s as i64, + dst_s as i64, + (input, src_o_in_bytes), + (output, dst_o_in_bytes) + ) + ); + + let grid_dims = MTLSize { + width: d1 as u64, + height: d2 as u64, + depth: 1, + }; + let group_dims = get_block_dims(d1 as u64, d2 as u64, 1); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + encoder.end_encoding(); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, @@ -372,11 +341,9 @@ pub fn call_unary_strided( kernels: &Kernels, name: unary::strided::Kernel, shape: &[usize], - input: &Buffer, + input: BufferOffset, strides: &[usize], - offset: usize, - output: &Buffer, - output_offset: usize, + output: BufferOffset, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; @@ -385,23 +352,13 @@ pub fn call_unary_strided( encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); - set_params!( - encoder, - ( - length, - num_dims, - shape, - strides, - (input, offset), - (output, output_offset) - ) - ); + set_params!(encoder, (length, num_dims, shape, strides, &input, &output)); let width: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, 
thread_group_size); encoder.end_encoding(); Ok(()) @@ -414,8 +371,8 @@ pub fn call_binary_contiguous( kernels: &Kernels, kernel_name: binary::contiguous::Kernel, length: usize, - left: &Buffer, - right: &Buffer, + left: BufferOffset, + right: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; @@ -423,12 +380,12 @@ pub fn call_binary_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, left, right, output)); + set_params!(encoder, (length, &left, &right, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(left, metal::MTLResourceUsage::Read); - encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -442,12 +399,10 @@ pub fn call_binary_strided( kernels: &Kernels, name: binary::strided::Kernel, shape: &[usize], - left_input: &Buffer, + left_input: BufferOffset, left_strides: &[usize], - left_offset: usize, - right_input: &Buffer, + right_input: BufferOffset, right_strides: &[usize], - right_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; @@ -467,16 +422,16 @@ pub fn call_binary_strided( shape, left_strides, right_strides, - (left_input, left_offset), - (right_input, right_offset), + &left_input, + &right_input, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(left_input, metal::MTLResourceUsage::Read); - encoder.use_resource(right_input, metal::MTLResourceUsage::Read); + encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -490,8 +445,7 @@ pub fn call_cast_contiguous( kernels: &Kernels, kernel_name: &'static str, length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; @@ -499,10 +453,10 @@ pub fn call_cast_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, (input, input_offset), output)); + set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -516,9 +470,8 @@ pub fn call_cast_strided( kernels: &Kernels, kernel_name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_strides: &[usize], - input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, 
kernel_name)?; @@ -530,25 +483,19 @@ pub fn call_cast_strided( set_params!( encoder, - ( - length, - shape.len(), - shape, - input_strides, - (input, input_offset), - output - ) + (length, shape.len(), shape, input_strides, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -556,8 +503,7 @@ pub fn call_reduce_contiguous( kernel_name: &'static str, length: usize, out_length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -566,10 +512,7 @@ pub fn call_reduce_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!( - encoder, - (length, elements_to_sum, (input, input_offset), output) - ); + set_params!(encoder, (length, elements_to_sum, &input, output)); let thread_group_count = MTLSize { width: out_length as u64, @@ -589,13 +532,14 @@ pub fn call_reduce_contiguous( depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_reduce_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -604,8 +548,7 @@ pub fn call_reduce_strided( shape: &[usize], strides: &[usize], out_length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let length: usize = shape.iter().product(); @@ -617,14 +560,7 @@ pub fn call_reduce_strided( set_params!( encoder, - ( - shape.len(), - shape, - strides, - elements_to_sum, - (input, input_offset), - output - ) + (shape.len(), shape, strides, elements_to_sum, &input, output) ); let thread_group_count = MTLSize { @@ -645,7 +581,7 @@ pub fn call_reduce_strided( depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -700,6 +636,193 @@ pub fn call_last_softmax( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_rms_norm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + eps: f32, + input: &Buffer, + input_offset: usize, + alpha: &Buffer, + alpha_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + length, + elements_to_sum, + (input, input_offset), + output, + (alpha, alpha_offset), + eps + ) + ); + + let out_length = 
length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rope_i( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + bh: usize, + td: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + bh, + td, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rope_thd( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + b: usize, + t: usize, + h: usize, + d: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + b, + t, + h, + d, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rope( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + bh: usize, + td: usize, + d: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + bh, + td, + d, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let 
(thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, @@ -707,7 +830,7 @@ pub fn call_affine( kernels: &Kernels, name: &'static str, size: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, mul: f32, add: f32, @@ -717,10 +840,10 @@ pub fn call_affine( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, add, input, output)); + set_params!(encoder, (size, mul, add, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -734,9 +857,8 @@ pub fn call_affine_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_stride: &[usize], - input_offset: usize, output: &Buffer, mul: f32, add: f32, @@ -756,13 +878,13 @@ pub fn call_affine_strided( input_stride, mul, add, - (input, input_offset), + &input, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -776,7 +898,7 @@ pub fn call_powf( kernels: &Kernels, name: &'static str, size: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -785,10 +907,10 @@ pub fn call_powf( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, input, output)); + set_params!(encoder, (size, mul, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -802,9 +924,8 @@ pub fn call_powf_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_stride: &[usize], - input_offset: usize, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -816,19 +937,11 @@ pub fn call_powf_strided( set_params!( encoder, - ( - size, - shape.len(), - shape, - input_stride, - mul, - (input, input_offset), - output - ) + (size, shape.len(), shape, input_stride, mul, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); 
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -842,7 +955,7 @@ pub fn call_elu( kernels: &Kernels, name: &'static str, size: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -851,10 +964,10 @@ pub fn call_elu( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, input, output)); + set_params!(encoder, (size, mul, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -868,9 +981,8 @@ pub fn call_elu_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_stride: &[usize], - input_offset: usize, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -882,37 +994,30 @@ pub fn call_elu_strided( set_params!( encoder, - ( - size, - shape.len(), - shape, - input_stride, - mul, - (input, input_offset), - output - ) + (size, shape.len(), shape, input_stride, mul, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, shape: &[usize], - cond: &Buffer, - (cond_stride, cond_offset): (&[usize], usize), - left: &Buffer, - (left_stride, left_offset): (&[usize], usize), - right: &Buffer, - (right_stride, right_offset): (&[usize], usize), + cond: BufferOffset, + cond_stride: &[usize], + left: BufferOffset, + left_stride: &[usize], + right: BufferOffset, + right_stride: &[usize], output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; @@ -932,18 +1037,18 @@ pub fn call_where_cond_strided( cond_stride, left_stride, right_stride, - (cond, cond_offset), - (left, left_offset), - (right, right_offset), + &cond, + &left, + &right, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(cond, metal::MTLResourceUsage::Read); - encoder.use_resource(left, metal::MTLResourceUsage::Read); - encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -959,8 +1064,11 @@ pub fn call_index_select( shape: &[usize], ids_size: usize, dim: usize, - input: &Buffer, - ids: &Buffer, + contiguous: bool, + src_dims: &[usize], + src_strides: &[usize], + input: BufferOffset, + ids: BufferOffset, output: &Buffer, 
) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -982,16 +1090,19 @@ pub fn call_index_select( src_dim_size, right_size, ids_size, - input, - ids, + contiguous, + src_dims, + src_strides, + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1007,10 +1118,8 @@ pub fn call_gather( shape: &[usize], ids_size: usize, dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1032,22 +1141,23 @@ pub fn call_gather( src_dim_size, right_size, ids_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_scatter_add( device: &Device, command_buffer: &CommandBufferRef, @@ -1056,10 +1166,8 @@ pub fn call_scatter_add( src_shape: &[usize], dst_shape: &[usize], dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); @@ -1082,22 +1190,23 @@ pub fn call_scatter_add( src_dim_size, right_size, dst_dim_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_index_add( device: &Device, command_buffer: &CommandBufferRef, @@ -1107,10 +1216,8 @@ pub fn call_index_add( dst_shape: &[usize], ids_shape: &[usize], dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); @@ -1134,16 +1241,16 @@ pub fn call_index_add( right_size, dst_dim_size, ids_dim_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - 
encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1251,9 +1358,12 @@ pub fn call_gemm( let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - let a_trans = if lhs_m1 == 1 && lhs_m2 == k { + // lhs has shape b, m, k + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { false - } else if lhs_m1 == m && lhs_m2 == 1 { + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { true } else { return Err(MetalKernelError::MatMulNonContiguous { @@ -1262,9 +1372,10 @@ pub fn call_gemm( mnk: (m, n, k), })?; }; - let b_trans = if rhs_m1 == 1 && rhs_m2 == n { + // rhs has shape b, k, n + let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { false - } else if rhs_m1 == k && rhs_m2 == 1 { + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { true } else { return Err(MetalKernelError::MatMulNonContiguous { @@ -1406,8 +1517,7 @@ pub fn call_im2col1d_strided( shape: &[usize], strides: &[usize], (k_size, stride, padding, dilation): (usize, usize, usize, usize), - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1419,20 +1529,9 @@ pub fn call_im2col1d_strided( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - ( - dst_el, - l_out, - k_size, - stride, - padding, - dilation, - shape, - strides, - (input, input_offset), - output - ) + (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1449,8 +1548,7 @@ pub fn call_im2col_strided( shape: &[usize], strides: &[usize], (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize), - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1468,21 +1566,11 @@ pub fn call_im2col_strided( set_params!( encoder, ( - dst_el, - h_out, - w_out, - h_k, - w_k, - stride, - padding, - dilation, - shape, - strides, - (input, input_offset), + dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input, output ) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1500,8 +1588,7 @@ pub fn call_upsample_nearest_2d( strides: &[usize], out_w: usize, out_h: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline 
= kernels.load_pipeline(device, Source::Conv, name)?; @@ -1513,18 +1600,9 @@ pub fn call_upsample_nearest_2d( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - ( - out_w, - out_h, - scale_w, - scale_h, - shape, - strides, - (input, input_offset), - output - ) + (out_w, out_h, scale_w, scale_h, shape, strides, &input, output) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1559,8 +1637,10 @@ pub fn call_random_uniform( set_params!(encoder, (length, min, max, seed, buffer)); - encoder.use_resource(seed, metal::MTLResourceUsage::Read); - encoder.use_resource(seed, metal::MTLResourceUsage::Write); + encoder.use_resource( + seed, + metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, + ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1590,8 +1670,10 @@ pub fn call_random_normal( set_params!(encoder, (length, mean, stddev, seed, buffer)); - encoder.use_resource(seed, metal::MTLResourceUsage::Read); - encoder.use_resource(seed, metal::MTLResourceUsage::Write); + encoder.use_resource( + seed, + metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, + ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1617,6 +1699,7 @@ pub enum GgmlDType { F32, } +#[allow(clippy::too_many_arguments)] pub fn call_quantized_matmul_t( device: &Device, command_buffer: &CommandBufferRef, @@ -1632,16 +1715,16 @@ pub fn call_quantized_matmul_t( let ne00 = k as i64; let ne01 = n as i64; let ne02 = b as i64; - let ne03 = 1 as i64; + let ne03 = 1i64; let nb00 = 0i64; - let nb01 = 0 as i64; - let nb02 = 0 as i64; + let nb01 = 0i64; + let nb02 = 0i64; let ne10 = k as i64; let ne11 = m as i64; let ne12 = b as i64; - let ne13 = 1 as i64; + let ne13 = 1i64; let nb10 = 0i64; let nb11 = 0i64; @@ -1773,5 +1856,150 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } +#[allow(clippy::too_many_arguments)] +pub fn call_pool2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + input: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = out_w * out_h * shape[0] * shape[1]; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (w_k, h_k, w_stride, h_stride, shape, strides, input, output) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_conv_transpose1d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + dilation: usize, + stride: usize, + padding: 
usize, + out_padding: usize, + c_out: usize, + l_out: usize, + b_size: usize, + src_shape: &[usize], + src_strides: &[usize], + kernel_shape: &[usize], + kernel_strides: &[usize], + input: &Buffer, + input_offset: usize, + kernel: &Buffer, + kernel_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = c_out * l_out * b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + l_out, + stride, + padding, + out_padding, + dilation, + src_shape, + src_strides, + kernel_shape, + kernel_strides, + (input, input_offset), + (kernel, kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub struct CallConvTranspose2dCfg<'a> { + pub dilation: usize, + pub stride: usize, + pub padding: usize, + pub output_padding: usize, + pub c_out: usize, + pub out_w: usize, + pub out_h: usize, + pub b_size: usize, + pub input_dims: &'a [usize], + pub input_stride: &'a [usize], + pub kernel_dims: &'a [usize], + pub kernel_stride: &'a [usize], + pub input_offset: usize, + pub kernel_offset: usize, +} + +#[allow(clippy::too_many_arguments)] +pub fn call_conv_transpose2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + cfg: CallConvTranspose2dCfg, + input: &Buffer, + kernel: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + cfg.out_w, + cfg.out_h, + cfg.stride, + cfg.padding, + cfg.output_padding, + cfg.dilation, + cfg.input_dims, + cfg.input_stride, + cfg.kernel_dims, + cfg.kernel_stride, + (input, cfg.input_offset), + (kernel, cfg.kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index a7e48393..c1a94199 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -123,16 +123,20 @@ template METAL_FUNC void rand_uniform( return; } + // Evenly sized vectors need an offset when writing the mirror element. 
+    uint off = 1 - size % 2;
     float diff = abs(min - max);
-    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
+    uint s = atomic_load_explicit(seed, memory_order_relaxed);
+    HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
     out[tid] = static_cast<T>(rng.rand() * diff + min);
     if (tid == 0) {
         atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
-        // Return early if tid == 0, otherwise we will write to out[size].
-        return;
+        // Return early if tid == 0 && off == 0, otherwise we will write to out[size].
+        if (off == 0)
+            return;
     }
     // Use symmetry to fill the other half of the array.
-    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
+    out[size - off - tid] = static_cast<T>(rng.rand() * diff + min);
 }
 
 // Create Gaussian normal distribution using Box-Muller transform:
@@ -148,7 +152,10 @@ template<typename T> METAL_FUNC void normal(
     if (tid >= size) {
         return;
     }
-    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
+    // Evenly sized vectors need an offset when writing the mirror element.
+    uint off = 1 - size % 2;
+    uint s = atomic_load_explicit(seed, memory_order_relaxed);
+    HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
     float u1 = rng.rand();
     float u2 = rng.rand();
 
@@ -162,11 +169,12 @@ template<typename T> METAL_FUNC void normal(
     if (tid == 0) {
         atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
-        // Return early if tid == 0, otherwise we will write to out[size].
-        return;
+        // Return early if tid == 0 && off == 0, otherwise we will write to out[size].
+        if (off == 0)
+            return;
     }
     // Use symmetry to fill the other half of the array.
-    out[size - tid] = static_cast<T>(z1);
+    out[size - off - tid] = static_cast<T>(z1);
 }
 
 #define UNIFORM_OP(NAME, T) \
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 93dac662..14bfb297 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -21,6 +21,52 @@ METAL_FUNC uint get_strided_index(
 
 constant int THREADGROUP_SIZE = 2048;
 
+template <typename T>
+METAL_FUNC void argmin(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device uint *dst,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup T *shared_memory,
+    threadgroup uint *shared_indices
+) {
+    bool notset = true;
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        if (notset || src[strided_i] < shared_memory[tid]) {
+            shared_memory[tid] = src[strided_i];
+            /* Assume that the reduction takes place over the last dimension which is contiguous.
*/ + shared_indices[tid] = idx % dims[num_dims - 1]; + notset = false; + } + idx += block_dim; + } + + threadgroup_barrier(mem_flags::mem_none); + // reduction in shared memory + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { + shared_indices[tid] = shared_indices[tid + s]; + shared_memory[tid] = shared_memory[tid + s]; + } \ + threadgroup_barrier(mem_flags::mem_none); + } + if (tid == 0) { + dst[dst_id] = shared_indices[0]; + } +} #define ARGMIN(NAME, T, MAXVALUE) \ kernel void NAME( \ @@ -35,53 +81,63 @@ kernel void NAME( \ uint dst_id [[ threadgroup_position_in_grid ]], \ uint block_dim [[ threads_per_threadgroup ]] \ ) { \ - \ - threadgroup T shared_memory[THREADGROUP_SIZE]; \ - threadgroup uint shared_indices[THREADGROUP_SIZE]; \ - \ - shared_memory[tid] = MAXVALUE; \ - shared_indices[tid] = 0xFFFFFFFF; \ - bool notset = true; \ - /* \ - // Elements summed in this block range from dst_id * el_to_sum_per_block \ - // to (dst_id + 1) * el_to_sum_per_block. \ - */ \ - size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = start_idx + el_to_sum_per_block; \ - size_t idx = start_idx + tid; \ - while (idx < stop_idx) { \ - /* \ - // TODO: Fast version for the contiguous case. \ - */ \ - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ - if (notset || src[strided_i] < shared_memory[tid]) { \ - shared_memory[tid] = src[strided_i]; \ - /* Assume that the reduction takes place over the last dimension which is contiguous. */ \ - shared_indices[tid] = idx % dims[num_dims - 1]; \ - notset = false; \ - } \ - idx += block_dim; \ - } \ - \ - threadgroup_barrier(mem_flags::mem_none); \ - \ - /* \ - // reduction in shared memory \ - */ \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \ - shared_indices[tid] = shared_indices[tid + s]; \ - shared_memory[tid] = shared_memory[tid + s]; \ - } \ - threadgroup_barrier(mem_flags::mem_none); \ - } \ - \ - if (tid == 0){ \ - dst[dst_id] = shared_indices[0]; \ - } \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + shared_memory[tid] = MAXVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + argmin(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \ } \ +template +METAL_FUNC void argmax( + constant size_t & num_dims, + constant size_t * dims, + constant size_t * strides, + constant size_t & el_to_sum_per_block, + device const T * src, + device uint * dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup T * shared_memory, + threadgroup uint * shared_indices + ) { + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + bool notset = true; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. 
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (notset || shared_memory[tid] < src[strided_i]) { + shared_memory[tid] = src[strided_i]; + shared_indices[tid] = idx % dims[num_dims - 1]; + notset = false; + } + idx += block_dim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { + shared_indices[tid] = shared_indices[tid + s]; + shared_memory[tid] = shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + // Thread 0 writes the result of the reduction + if (tid == 0) { + dst[dst_id] = shared_indices[0]; + } + } + #define ARGMAX(NAME, T, MINVALUE) \ kernel void NAME( \ constant size_t &num_dims, \ @@ -95,170 +151,337 @@ kernel void NAME( \ uint dst_id [[ threadgroup_position_in_grid ]], \ uint block_dim [[ threads_per_threadgroup ]] \ ) { \ - \ threadgroup T shared_memory[THREADGROUP_SIZE]; \ threadgroup uint shared_indices[THREADGROUP_SIZE]; \ - \ shared_memory[tid] = MINVALUE; \ shared_indices[tid] = 0xFFFFFFFF; \ - /* \ - // Elements summed in this block range from dst_id * el_to_sum_per_block \ - // to (dst_id + 1) * el_to_sum_per_block. \ - */ \ - size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = start_idx + el_to_sum_per_block; \ - size_t idx = start_idx + tid; \ - bool notset = true; \ - while (idx < stop_idx) { \ - /* \ - // TODO: Fast version for the contiguous case. \ - */ \ - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ - if (notset || shared_memory[tid] < src[strided_i]) { \ - shared_memory[tid] = src[strided_i]; \ - shared_indices[tid] = idx % dims[num_dims - 1]; \ - notset = false; \ - } \ - idx += block_dim; \ - } \ - \ - threadgroup_barrier(mem_flags::mem_none); \ - \ - /* \ - // reduction in shared memory \ - */ \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \ - shared_indices[tid] = shared_indices[tid + s]; \ - shared_memory[tid] = shared_memory[tid + s]; \ - } \ - threadgroup_barrier(mem_flags::mem_none); \ - } \ - \ - if (tid == 0){ \ - dst[dst_id] = shared_indices[0]; \ - } \ + argmax(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \ } \ +template +METAL_FUNC void reduce( + constant size_t & num_dims, + constant size_t * dims, + constant size_t * strides, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup T * shared_memory, + T (*fn)(T, T) +) { + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. 
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + T x = shared_memory[tid]; + T y = src[strided_i]; + shared_memory[tid] = fn(x, y); + idx += block_dim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + T x = shared_memory[tid]; + T y = shared_memory[tid + s]; + shared_memory[tid] = fn(x, y); + } + threadgroup_barrier(mem_flags::mem_none); + } + + if (tid == 0) { + dst[dst_id] = shared_memory[0]; + } +} + #define REDUCE(FN, NAME, T, START) \ +METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \ kernel void NAME( \ constant size_t &num_dims, \ constant size_t *dims, \ constant size_t *strides, \ constant size_t &el_to_sum_per_block, \ - device const T *src, \ + device const T *src, \ device T *dst, \ uint id [[ thread_position_in_grid ]], \ uint tid [[ thread_index_in_threadgroup ]], \ uint dst_id [[ threadgroup_position_in_grid ]], \ uint block_dim [[ threads_per_threadgroup ]] \ ) { \ - \ - threadgroup T shared_memory[THREADGROUP_SIZE]; \ - \ - shared_memory[tid] = START; \ - /* \ - // Elements summed in this block range from dst_id * el_to_sum_per_block \ - // to (dst_id + 1) * el_to_sum_per_block. \ - */ \ - size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = start_idx + el_to_sum_per_block; \ - size_t idx = start_idx + tid; \ - while (idx < stop_idx) { \ - /* \ - // TODO: Fast version for the contiguous case. \ - */ \ - size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ - T x = shared_memory[tid]; \ - T y = src[strided_i]; \ - shared_memory[tid] = FN; \ - idx += block_dim; \ - } \ - \ - threadgroup_barrier(mem_flags::mem_none); \ - \ - /* \ - // reduction in shared memory \ - */ \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s) { \ - T x = shared_memory[tid]; \ - T y = shared_memory[tid + s]; \ - shared_memory[tid] = FN; \ - } \ - threadgroup_barrier(mem_flags::mem_none); \ - } \ - \ - dst[dst_id] = shared_memory[0]; \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = START; \ + reduce(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \ } \ +template +METAL_FUNC void softmax( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; -#define SOFTMAX(NAME, T) \ -kernel void NAME( \ - constant size_t &src_numel, \ - constant size_t &el_to_sum_per_block, \ - device const T *src, \ - device T *dst, \ - \ - uint id [[ thread_position_in_grid ]], \ - uint tid [[ thread_index_in_threadgroup ]], \ - uint dst_id [[ threadgroup_position_in_grid ]], \ - uint block_dim [[ threads_per_threadgroup ]] \ -) { \ - threadgroup float shared_memory[THREADGROUP_SIZE]; \ - shared_memory[tid] = -INFINITY; \ - size_t start_idx = dst_id * el_to_sum_per_block; \ - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ - size_t idx = start_idx + tid; \ - \ - \ - float tmp = -INFINITY; \ - while (idx < stop_idx) { \ - tmp = MAX(tmp, float(src[idx])); \ - idx += block_dim; \ - } \ - shared_memory[tid] = tmp; \ - \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ - \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s) { \ - 
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ - } \ - \ - /* wait for shared_memory[0] to be filled */ \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ - \ - float _max = shared_memory[0]; \ - \ - /* prevent tid=0 from overwriting _max before other threads have written it */ \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ - shared_memory[tid] = 0; \ - \ - idx = start_idx + tid; \ - while (idx < stop_idx) { \ - const float val = exp(float(src[idx]) - _max); \ - dst[idx] = T(val); \ - shared_memory[tid] += val; \ - idx += block_dim; \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ - for (uint s = block_dim / 2; s > 0; s >>= 1) { \ - if (tid < s) { \ - shared_memory[tid] += shared_memory[tid + s]; \ - } \ - threadgroup_barrier(mem_flags::mem_threadgroup); \ - } \ - \ - const T inv_acc = T(1.0/shared_memory[0]); \ - idx = start_idx + tid; \ - while (idx < stop_idx) { \ - dst[idx] *= inv_acc; \ - idx += block_dim; \ - } \ -} \ + float tmp = -INFINITY; + while (idx < stop_idx) { + tmp = MAX(tmp, float(src[idx])); + idx += block_dim; + } + shared_memory[tid] = tmp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\ + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float _max = shared_memory[0]; + + /* prevent tid=0 from overwriting _max before other threads have written it */ + threadgroup_barrier(mem_flags::mem_threadgroup); + shared_memory[tid] = 0; + + idx = start_idx + tid; + while (idx < stop_idx) { + const float val = exp(float(src[idx]) - _max); + dst[idx] = T(val); + shared_memory[tid] += val; + idx += block_dim; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const T inv_acc = T(1.0 / shared_memory[0]); + idx = start_idx + tid; + while (idx < stop_idx) { + dst[idx] *= inv_acc; + idx += block_dim; + } +} + +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + softmax(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \ +} \ + +template +METAL_FUNC void rmsnorm( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + device const T * alpha, + constant float & eps, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + float tmp = 0; + while (idx < stop_idx) { + tmp = tmp + float(src[idx]) * float(src[idx]); + idx += block_dim; + } + shared_memory[tid] = tmp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if 
(tid < s) { + shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); + float inv_norm = 1.0f / norm; + idx = start_idx + tid; + while (idx < stop_idx) { + float val = float(src[idx]) * inv_norm; + if (alpha != nullptr) { + val *= float(alpha[idx - start_idx]); + } + dst[idx] = T(val); + idx += block_dim; + } +} + +#define RMSNORM(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + device const T *alpha, \ + constant float &eps, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = 0; \ + rmsnorm(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \ +} \ + +template +METAL_FUNC void ropei( + constant size_t &bh, + constant size_t &td, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint tid +) { + if (2 * tid >= bh * td) { + return; + } + size_t rope_idx = tid % (td / 2); + T c = cos[rope_idx]; + T s = sin[rope_idx]; + dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; + dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; +} + +template +METAL_FUNC void rope( + constant size_t &bh, + constant size_t &td, + constant size_t &d, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint idx +) { + if (2 * idx >= bh * td) { + return; + } + size_t i_bh = idx / (td / 2); + size_t i_td = idx - (td / 2) * i_bh; + size_t i_t = i_td / (d / 2); + size_t i_d = i_td - (d / 2) * i_t; + size_t i1 = i_bh * td + i_t * d + i_d; + size_t i2 = i1 + d / 2; + size_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +template +METAL_FUNC void rope_thd( + constant size_t &b, + constant size_t &t, + constant size_t &h, + constant size_t &d, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint idx +) { + if (2 * idx >= b * t * h * d) { + return; + } + const size_t i_bth = idx / (d / 2); + const size_t i_d = idx - (d / 2) * i_bth; + const size_t i_t = (i_bth / h) % t; + const size_t i1 = i_bth * d + i_d; + const size_t i2 = i1 + d / 2; + const size_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \ +kernel void FN_NAME_I( \ + constant size_t &bh, \ + constant size_t &td, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + ropei(bh, td, src, cos, sin, dst, tid); \ +}\ +kernel void FN_NAME( \ + constant size_t &bh, \ + constant size_t &td, \ + constant size_t &d, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint idx [[ thread_position_in_grid ]] \ +) { \ + rope(bh, td, d, src, cos, sin, dst, idx); \ +}\ +kernel void FN_NAME_THD( \ + 
constant size_t &b, \ + constant size_t &t, \ + constant size_t &h, \ + constant size_t &d, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint idx [[ thread_position_in_grid ]] \ +) { \ + rope_thd(b, t, h, d, src, cos, sin, dst, idx); \ +}\ REDUCE(x + y, fast_sum_f32_strided, float, 0) REDUCE(x + y, fast_sum_u32_strided, uint, 0) @@ -286,6 +509,10 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0) SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) +RMSNORM(rmsnorm_f32, float) +RMSNORM(rmsnorm_f16, half) +ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) +ROPE(rope_f16, rope_i_f16, rope_thd_f16, half) #if __METAL_VERSION__ >= 220 REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) @@ -297,10 +524,16 @@ ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) +REDUCE(x + y, fast_sum_bf16_strided, half, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) +REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) +REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF) REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF) +REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF) ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) +RMSNORM(rmsnorm_bf16, bfloat) +ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) #endif diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 0da8619c..77ae8d82 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Buffer, Device, MTLResourceOptions}; +use metal::MTLResourceOptions; fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; @@ -12,7 +12,7 @@ fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { fn new_buffer(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const c_void; - let size = (data.len() * std::mem::size_of::()) as u64; + let size = std::mem::size_of_val(data) as u64; device.new_buffer_with_data(ptr, size, options) } @@ -41,6 +41,10 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: 0, + }; let output = new_buffer(&device, v); call_unary_contiguous( &device, @@ -48,7 +52,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { &kernels, name, v.len(), - &input, + input, &output, ) .unwrap(); @@ -72,8 +76,8 @@ fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V &kernels, name, x.len(), - &left, - &right, + BufferOffset::zero_offset(&left), + BufferOffset::zero_offset(&right), &output, ) .unwrap(); @@ -93,7 +97,15 @@ fn run_strided( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let output = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: offset, + }; + let output_b = new_buffer(&device, v); + let output = BufferOffset { + buffer: &output_b, + offset_in_bytes: 0, + }; let kernels = Kernels::new(); call_unary_strided( &device, @@ -101,16 +113,14 @@ fn run_strided( &kernels, kernel, shape, - 
&input, + input, strides, - offset, - &output, - 0, + output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - read_to_vec(&output, v.len()) + read_to_vec(&output_b, v.len()) } #[test] @@ -292,7 +302,7 @@ fn binary_ops_bf16() { binary_op!(max, |x: bf16, y| x.max(y)); } -fn cast(v: &[T], name: &'static str) -> Vec { +fn run_cast(v: &[T], name: &'static str) -> Vec { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); @@ -308,8 +318,7 @@ fn cast(v: &[T], name: &'static str) -> Vec { &kernels, name, v.len(), - &input, - 0, + BufferOffset::zero_offset(&input), &output, ) .unwrap(); @@ -319,107 +328,189 @@ fn cast(v: &[T], name: &'static str) -> Vec { } #[test] -fn cast_u32_f32() { - let v = vec![1u32, 2, 3]; - let results = cast(&v, "cast_u32_f32"); - let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); - assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); - assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); +fn cast_f32() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let v = vec![1.0f32, 2.0, 3.0]; - let input: Vec = v.iter().map(|v| f16::from_f32(*v)).collect(); - let results: Vec = cast(&input, "cast_f16_f32"); - assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + // f32 -> f16 + let results: Vec = run_cast(&v_f32, "cast_f32_f16"); + assert_eq!(results, v_f16); - let v = vec![1.0f32; 10_000]; - let input: Vec = v.iter().map(|v| f16::from_f32(*v)).collect(); - let results: Vec = cast(&input, "cast_f16_f32"); - assert_eq!(results.len(), 10_000); - assert_eq!(&results[..10], vec![1.0f32; 10]); - assert_eq!(results, vec![1.0f32; 10_000]); + // f32 -> bf16 + let results: Vec = run_cast(&v_f32, "cast_f32_bf16"); + assert_eq!(results, v_bf16); + + // f32 -> u32 + let results: Vec = run_cast(&v_f32, "cast_f32_u32"); + assert_eq!(results, v_u32); + + // f32 -> u8 + let results: Vec = run_cast(&v_f32, "cast_f32_u8"); + assert_eq!(results, v_u8); + + // f32 -> i64 + let results: Vec = run_cast(&v_f32, "cast_f32_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_bf16_u32() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); +fn cast_f16() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_bf16_u32"); - let expected: Vec = (1..=3).map(|v| v as u32).collect(); + // f16 -> f32 + let results: Vec = run_cast(&v_f16, "cast_f16_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // f16 -> bf16 + let results: Vec = run_cast(&v_f16, "cast_f16_bf16"); + assert_eq!(results, v_bf16); + + // f16 -> u32 + let results: Vec = run_cast(&v_f16, "cast_f16_u32"); + assert_eq!(results, v_u32); + + // f16 -> u8 + let results: Vec = run_cast(&v_f16, "cast_f16_u8"); + 
assert_eq!(results, v_u8); + + // f16 -> i64 + let results: Vec = run_cast(&v_f16, "cast_f16_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_bf16_f32() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); +fn cast_bf16() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_bf16_f32"); - let expected: Vec = (1..=3).map(|v| v as f32).collect(); + // bf16 -> f32 + let results: Vec = run_cast(&v_bf16, "cast_bf16_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // bf16 -> f16 + let results: Vec = run_cast(&v_bf16, "cast_bf16_f16"); + assert_eq!(results, v_f16); + + // bf16 -> u32 + let results: Vec = run_cast(&v_bf16, "cast_bf16_u32"); + assert_eq!(results, v_u32); + + // bf16 -> u8 + let results: Vec = run_cast(&v_bf16, "cast_bf16_u8"); + assert_eq!(results, v_u8); + + // bf16 -> i64 + let results: Vec = run_cast(&v_bf16, "cast_bf16_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_u8_bf16() { - let input: Vec = (1..=3).map(|v| v as u8).collect(); +fn cast_u32() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_u8_bf16"); - let expected: Vec = input - .iter() - .map(|v| bf16::from_f32(*v as f32)) - .collect::>(); + // u32 -> f32 + let results: Vec = run_cast(&v_u32, "cast_u32_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // u32 -> f16 + let results: Vec = run_cast(&v_u32, "cast_u32_f16"); + assert_eq!(results, v_f16); + + // u32 -> bf16 + let results: Vec = run_cast(&v_u32, "cast_u32_bf16"); + assert_eq!(results, v_bf16); + + // u32 -> u8 + let results: Vec = run_cast(&v_u32, "cast_u32_u8"); + assert_eq!(results, v_u8); + + // u32 -> i64 + let results: Vec = run_cast(&v_u32, "cast_u32_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_u32_bf16() { - let input: Vec = (1..=3).map(|v| v as u32).collect(); +fn cast_u8() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_u32_bf16"); - let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + // u8 -> f32 + let results: Vec = run_cast(&v_u8, "cast_u8_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); + // u8 -> f16 + let results: Vec = run_cast(&v_u8, "cast_u8_f16"); + assert_eq!(results, v_f16); + + // u8 -> bf16 + let results: Vec = run_cast(&v_u8, 
"cast_u8_bf16"); + assert_eq!(results, v_bf16); + + // u8 -> u32 + let results: Vec = run_cast(&v_u8, "cast_u8_u32"); + assert_eq!(results, v_u32); + + // u8 -> i64 + let results: Vec = run_cast(&v_u8, "cast_u8_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_f32_bf16() { - let input: Vec = (1..=3).map(|v| v as f32).collect(); +fn cast_i64() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); - let output: Vec = cast(&input, "cast_f32_bf16"); - let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + // i64 -> f32 + let results: Vec = run_cast(&v_i64, "cast_i64_f32"); + assert_eq!(results, v_f32); - assert_eq!(output, expected); -} + // i64 -> f16 + let results: Vec = run_cast(&v_i64, "cast_i64_f16"); + assert_eq!(results, v_f16); -#[test] -fn it_cast_bf16_u8() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + // i64 -> bf16 + let results: Vec = run_cast(&v_i64, "cast_i64_bf16"); + assert_eq!(results, v_bf16); - let output: Vec = cast(&input, "cast_bf16_u8"); - let expected: Vec = input.iter().map(|v| v.to_f32() as u8).collect(); + // i64 -> u32 + let results: Vec = run_cast(&v_i64, "cast_i64_u32"); + assert_eq!(results, v_u32); - assert_eq!(output, expected); -} - -#[test] -fn it_cast_bf16_f16() { - let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); - - let output: Vec = cast(&input, "cast_bf16_f16"); - let expected: Vec = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); - - assert_eq!(output, expected); -} - -#[test] -fn it_cast_f16_bf16() { - let input: Vec = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); - - let output: Vec = cast(&input, "cast_f16_bf16"); - let expected: Vec = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); - - assert_eq!(output, expected); + // i64 -> u8 + let results: Vec = run_cast(&v_i64, "cast_i64_u8"); + assert_eq!(results, v_u8); } fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { @@ -439,7 +530,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { &kernels, "affine_f32", size, - &input, + BufferOffset::zero_offset(&input), &output, mul as f32, add as f32, @@ -472,9 +563,8 @@ fn run_affine_strided( &kernels, "affine_f32_strided", shape, - &input, + BufferOffset::zero_offset(&input), strides, - 0, &output, mul as f32, add as f32, @@ -518,32 +608,46 @@ fn affine_strided() { fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; + let stride = [1, 2]; let ids = [0u32, 1, 0]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 
4.0, 5.0] ); } +#[test] +fn index_select_strided() { + let embedding = (0..16).map(|x| x as f32).collect::>(); + let shape = [2, 2]; + let stride = [2, 4]; + let ids = [0u32]; + let dim = 0; + let result = run_index_select_strided(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); + assert_eq!(result, vec![0.0, 4.0]); +} + #[test] fn index_select_f16() { let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] .into_iter() - .map(|x| f16::from_f32(x)) + .map(f16::from_f32) .collect(); let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -554,9 +658,10 @@ fn index_select_f16() { fn index_select_is_u32_bf16() { let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_bf16"); assert_eq!( approx_bf16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -567,9 +672,10 @@ fn index_select_is_u32_bf16() { fn index_select_is_u8_bf16() { let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u8, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u8_bf16"); assert_eq!( approx_bf16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -580,9 +686,10 @@ fn index_select_is_u8_bf16() { fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -592,6 +699,7 @@ fn index_select_dim1() { fn run_index_select( embeddings: &[T], shape: &[usize], + stride: &[usize], ids: &[I], dim: usize, name: &'static str, @@ -600,8 +708,8 @@ fn run_index_select( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); @@ -611,14 +719,61 @@ fn run_index_select( let kernels = Kernels::new(); call_index_select( &device, - &command_buffer, + command_buffer, &kernels, name, shape, ids.len(), dim, - &embeddings_buffer, - &ids_buffer, + true, + shape, + stride, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), + &dst_buffer, + ) + .unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&dst_buffer, dst_el) +} + +fn run_index_select_strided( + embeddings: &[T], + shape: &[usize], + stride: &[usize], + ids: &[I], + dim: usize, + name: &'static str, +) -> Vec { + let 
device = Device::system_default().expect("no device found"); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let kernels = Kernels::new(); + call_index_select( + &device, + command_buffer, + &kernels, + name, + shape, + ids.len(), + dim, + false, + shape, + stride, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -660,8 +815,7 @@ fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec( shape: &[usize], cond: &[I], @@ -814,18 +969,30 @@ fn run_where_cond( ); let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + let cond = BufferOffset { + buffer: &cond, + offset_in_bytes: cond_offset, + }; + let left = BufferOffset { + buffer: &left, + offset_in_bytes: left_offset, + }; + let right = BufferOffset { + buffer: &right, + offset_in_bytes: cond_offset, + }; call_where_cond_strided( &device, command_buffer, &kernels, name, shape, - &cond, - (&cond_stride, cond_offset), - &left, - (&left_stride, left_offset), - &right, - (&cond_stride, cond_offset), + cond, + &cond_stride, + left, + &left_stride, + right, + &cond_stride, &output, ) .unwrap(); @@ -1011,7 +1178,7 @@ fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: #[test] fn random() { fn calc_mean(data: &[f32]) -> f32 { - let sum = data.iter().sum::() as f32; + let sum = data.iter().sum::(); let count = data.len(); assert!(count > 0); sum / count as f32 @@ -1025,7 +1192,7 @@ fn random() { let variance = data .iter() .map(|value| { - let diff = mean - (*value as f32); + let diff = mean - *value; diff * diff }) .sum::() @@ -1080,3 +1247,785 @@ fn random() { validate_random!(f16); validate_random!(bf16); } + +fn run_scatter_add( + input: &[T], + ids: &[I], + shape: &[usize], + dim: usize, + name: &'static str, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input_buffer = new_buffer(&device, input); + let ids_buffer = new_buffer(&device, ids); + let output = device.new_buffer(std::mem::size_of_val(input) as u64, options); + call_scatter_add( + &device, + command_buffer, + &kernels, + name, + shape, + shape, + dim, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&ids_buffer), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, input.len()) +} + +#[test] +fn scatter_add() { + let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3]; + let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3]; + let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3]; + + let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0]; + let input_f16 = input_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let input_bf16 = input_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + + let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0]; + let output_dim1_f16 = output_dim1_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let output_dim1_bf16 = output_dim1_f32 + 
.iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + + let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0]; + let output_dim2_f16 = output_dim2_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let output_dim2_bf16 = output_dim2_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + + for (shape, output_f32, output_f16, output_bf16) in [ + (vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16), + ( + vec![4, 2], + output_dim2_f32, + output_dim2_f16, + output_dim2_bf16, + ), + ] { + for results in [ + run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"), + run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"), + run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"), + ] { + assert_eq!(results, output_f32); + } + for results in [ + run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"), + run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"), + run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"), + ] { + assert_eq!(results, output_f16); + } + for results in [ + run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"), + run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"), + run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"), + ] { + assert_eq!(results, output_bf16); + } + } +} + +fn run_index_add( + left: &[T], + right: &[T], + indices: &[I], + shape: &[usize], + dim: usize, + name: &'static str, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input_buffer = new_buffer(&device, right); + let output = new_buffer(&device, left); + let indices_buffer = new_buffer(&device, indices); + call_index_add( + &device, + command_buffer, + &kernels, + name, + shape, + shape, + shape, + dim, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&indices_buffer), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, left.len()) +} + +#[test] +fn index_add() { + let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0]; + let indices = vec![0u32, 1, 0, 1, 0, 1]; + let shape = vec![6]; + + // u32, f32 + { + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u32, f16 + { + let left = left.iter().map(|v| f16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u32, bf16 + { + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, f32 + { + let indices = indices.iter().map(|v| *v as u8).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, f16 + { + let indices = indices.iter().map(|v| *v as u8).collect::>(); + let left = left.iter().map(|v| f16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::>(); + let results = 
run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, bf16 + { + let indices = indices.iter().map(|v| *v as u8).collect::>(); + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, f32 + { + let indices = indices.iter().map(|v| *v as i64).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, f16 + { + let indices = indices.iter().map(|v| *v as i64).collect::>(); + let left = left.iter().map(|v| f16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, bf16 + { + let indices = indices.iter().map(|v| *v as i64).collect::>(); + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } +} + +fn run_pool2d( + v: &[T], + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + shape: &[usize], + strides: &[usize], + name: &'static str, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let out_w = (shape[2] - w_k) / w_stride + 1; + let out_h = (shape[3] - h_k) / h_stride + 1; + let dst_el = out_w * out_h * shape[0] * shape[1]; + let input = new_buffer(&device, v); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + call_pool2d( + &device, + command_buffer, + &kernels, + name, + shape, + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &input, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn max_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + 
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f32", + ); + 
let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f16", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_bf16", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u8", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u32", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} + +#[allow(clippy::too_many_arguments)] +fn run_conv_transpose1d( + input: &[T], + input_shape: &[usize], + input_stride: &[usize], + kernel: &[T], + kernel_shape: &[usize], + kernel_stride: &[usize], + dilation: usize, + stride: usize, + padding: usize, + out_padding: usize, + name: &'static str, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let c_out = kernel_shape[1]; + let k_size = kernel_shape[2]; + let b_size = input_shape[0]; + let l_in = input_shape[2]; + let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1; + let dst_el = c_out * l_out * b_size; + + let input = new_buffer(&device, input); + let kernel = new_buffer(&device, kernel); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + + call_conv_transpose1d( + &device, + command_buffer, + &kernels, + name, + dilation, + stride, + padding, + out_padding, + c_out, + l_out, + b_size, + input_shape, + input_stride, + kernel_shape, + kernel_stride, + &input, + 0, + &kernel, + 0, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn conv_transpose1d_f32() { + let input = vec![1.0f32, 2.0, 3.0, 4.0]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel = vec![1.0f32, 2.0, 3.0, 4.0]; + let kernel_shape = 
&[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_f32", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.]; + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_f16() { + let input: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_f16", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_bf16() { + let input: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect(); + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect(); + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_bf16", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_u8() { + let input: Vec = vec![1, 2, 3, 4]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1, 2, 3, 4]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_u8", + ); + + let expected = vec![1, 4, 10, 20, 25, 24, 16]; + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_u32() { + let input: Vec = vec![1, 2, 3, 4]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1, 2, 3, 4]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_u32", + ); + + let expected = vec![1, 4, 10, 20, 25, 24, 16]; + assert_eq!(results, expected); +} diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 1e0d5526..4b6363ed 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -102,6 +102,26 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_UNARY_OP(NAME) \ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); +#define COPY2D(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant int64_t &d1, \ + constant int64_t &d2, \ + constant int64_t &src_s, \ + constant int64_t &dst_s, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint2 idx [[thread_position_in_grid]] \ +) { \ + if (idx.x >= d1 || idx.y >= d2) return; \ + int64_t src_idx = idx.x * src_s + idx.y; \ + int64_t dst_idx = idx.x * dst_s + idx.y; \ + output[dst_idx] = 
input[src_idx]; \ +} + +COPY2D(copy2d_f32, float) +COPY2D(copy2d_f16, half) +COPY2D(copy2d_u8, uint8_t) +COPY2D(copy2d_u32, uint32_t) UNARY_OP(cos) UNARY_OP(sin) @@ -121,6 +141,7 @@ UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) UNARY_OP(relu) +UNARY_OP(sign) UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) @@ -128,6 +149,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided) #if __METAL_VERSION__ >= 220 UNARY(id, int64_t, copy_i64, copy_i64_strided) +COPY2D(copy2d_i64, int64_t) #endif #if defined(__HAVE_BFLOAT__) @@ -149,6 +171,9 @@ BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) BFLOAT_UNARY_OP(relu) +BFLOAT_UNARY_OP(sign) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) + +COPY2D(copy2d_bf64, bfloat) #endif diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs new file mode 100644 index 00000000..194cddf4 --- /dev/null +++ b/candle-metal-kernels/src/utils.rs @@ -0,0 +1,162 @@ +use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize}; +use std::ffi::c_void; + +/// Most kernels apply similarly across the tensors +/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the +/// actual total buffer length). +/// Then kernels can just do their op on their single point in the buffer. +pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { + let size = length as u64; + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); + let count = (size + width - 1) / width; + let thread_group_count = MTLSize { + width: count, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + (thread_group_count, thread_group_size) +} + +// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 +pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { + let mut pows0 = 0u64; + let mut pows1 = 0u64; + let mut pows2 = 0u64; + let mut sum = 0u64; + loop { + let presum = sum; + // Check all the pows + if dim0 >= (1 << (pows0 + 1)) { + pows0 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim1 >= (1 << (pows1 + 1)) { + pows1 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim2 >= (1 << (pows2 + 1)) { + pows2 += 1; + sum += 1; + } + if sum == presum || sum == 10 { + break; + } + } + MTLSize { + width: 1 << pows0, + height: 1 << pows1, + depth: 1 << pows2, + } +} + +pub(crate) fn set_param( + encoder: &ComputeCommandEncoderRef, + position: u64, + data: P, +) { +
<P as EncoderParam>
::set_param(encoder, position, data) +} + +/// Helper functions to create the various objects on the compute command encoder +/// on a single line. +/// Prevents getting wrong some arguments number and mixing length and size in bytes. +pub(crate) trait EncoderParam { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); +} +macro_rules! primitive { + ($type:ty) => { + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } + }; +} +primitive!(bool); +primitive!(usize); +primitive!(i32); +primitive!(i64); +primitive!(u32); +primitive!(u64); +primitive!(f32); + +pub struct BufferOffset<'a> { + pub buffer: &'a Buffer, + pub offset_in_bytes: usize, +} + +impl<'a> BufferOffset<'a> { + pub fn zero_offset(buffer: &'a Buffer) -> Self { + Self { + buffer, + offset_in_bytes: 0, + } + } +} + +impl EncoderParam for &[T] { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, + ); + } +} + +impl EncoderParam for &Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} + +impl EncoderParam for (&Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +impl<'a> EncoderParam for &BufferOffset<'a> { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64); + } +} + +impl EncoderParam for &mut Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} + +impl EncoderParam for (&mut Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +#[macro_export] +macro_rules! 
set_params { + ($encoder:ident, ($($param:expr),+)) => ( + let mut _index = 0; + $( + $crate::utils::set_param($encoder, _index, $param); + _index += 1; + )* + ); +} diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 214e8a59..9f0d56bd 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -25,6 +25,8 @@ candle-metal-kernels = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } clap = { workspace = true } +rand = { workspace = true } +criterion = { workspace = true } [features] default = [] @@ -32,3 +34,7 @@ accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] + +[[bench]] +name = "bench_main" +harness = false \ No newline at end of file diff --git a/candle-nn/benches/bench_main.rs b/candle-nn/benches/bench_main.rs new file mode 100644 index 00000000..4db1d35c --- /dev/null +++ b/candle-nn/benches/bench_main.rs @@ -0,0 +1,4 @@ +mod benchmarks; + +use criterion::criterion_main; +criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches); diff --git a/candle-nn/benches/benchmarks/conv.rs b/candle-nn/benches/benchmarks/conv.rs new file mode 100644 index 00000000..eb80645b --- /dev/null +++ b/candle-nn/benches/benchmarks/conv.rs @@ -0,0 +1,54 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle::{DType, Device, Module, Tensor}; +use candle_nn::{Conv2d, Conv2dConfig}; +use criterion::{black_box, criterion_group, Criterion}; +use std::time::Instant; + +const B: usize = 1; +const C: usize = 1; +const M: usize = 128; +const K: usize = 128; +const K_SIZE: usize = 3; + +fn run(input: Tensor, weight: Tensor, bias: Tensor, config: Conv2dConfig) { + Conv2d::new(weight, Some(bias), config) + .forward(&input) + .unwrap(); +} + +fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let weight = Tensor::ones((1, 1, K_SIZE, K_SIZE), dtype, device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + let bias = Tensor::zeros(K, dtype, device).unwrap(); + let input = Tensor::ones((B, C, M, K), dtype, device).unwrap(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(input.clone()), + black_box(weight.clone()), + black_box(bias.clone()), + Default::default(), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32"); + run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-nn/benches/benchmarks/layer_norm.rs b/candle-nn/benches/benchmarks/layer_norm.rs new file mode 100644 index 00000000..0be5c450 --- /dev/null +++ b/candle-nn/benches/benchmarks/layer_norm.rs @@ -0,0 +1,48 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle::{DType, Device, Module, Tensor}; +use candle_nn::LayerNorm; +use criterion::{black_box, criterion_group, Criterion}; +use std::time::Instant; + +fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) { + let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input); +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; + +fn 
run_layer_norm_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let elements = B * M * K; + + let weight = Tensor::arange(0.0, elements as f32, device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + let bias = weight.ones_like().unwrap(); + let input = weight.ones_like().unwrap(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&input), black_box(&weight), black_box(&bias)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_layer_norm_benchmark(c, &d, DType::F32, "layer_norm_f32"); + run_layer_norm_benchmark(c, &d, DType::BF16, "layer_norm_bf16"); + run_layer_norm_benchmark(c, &d, DType::F16, "layer_norm_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs new file mode 100644 index 00000000..30a6ab6a --- /dev/null +++ b/candle-nn/benches/benchmarks/mod.rs @@ -0,0 +1,64 @@ +pub(crate) mod conv; +pub(crate) mod layer_norm; + +use candle::{Device, Result}; + +pub(crate) trait BenchDevice { + fn sync(&self) -> Result<()>; + + fn bench_name>(&self, name: S) -> String; +} + +impl BenchDevice for Device { + fn sync(&self) -> Result<()> { + match self { + Device::Cpu => Ok(()), + Device::Cuda(device) => { + #[cfg(feature = "cuda")] + return Ok(device.synchronize()?); + #[cfg(not(feature = "cuda"))] + panic!("Cuda device without cuda feature enabled: {:?}", device) + } + Device::Metal(device) => { + #[cfg(feature = "metal")] + return Ok(device.wait_until_completed()?); + #[cfg(not(feature = "metal"))] + panic!("Metal device without metal feature enabled: {:?}", device) + } + } + } + + fn bench_name>(&self, name: S) -> String { + match self { + Device::Cpu => { + let cpu_type = if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + }; + format!("{}_{}", cpu_type, name.into()) + } + Device::Cuda(_) => format!("cuda_{}", name.into()), + Device::Metal(_) => format!("metal_{}", name.into()), + } + } +} + +struct BenchDeviceHandler { + devices: Vec, +} + +impl BenchDeviceHandler { + pub fn new() -> Result { + let mut devices = Vec::new(); + if cfg!(feature = "metal") { + devices.push(Device::new_metal(0)?); + } else if cfg!(feature = "cuda") { + devices.push(Device::new_cuda(0)?); + } + devices.push(Device::Cpu); + Ok(Self { devices }) + } +} diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 001be116..430316b8 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -238,6 +238,23 @@ impl Benchmark for QMatMul { const ITERS: usize = 100; } +struct Cat; +impl Benchmark for Cat { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result { + let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?; + let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?; + Ok((lhs, rhs)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + Tensor::cat(&[&d.0, &d.1], 2) + } + + const ITERS: usize = 1000; +} + struct Softmax; impl Benchmark for Softmax { type PreProcessData = Tensor; @@ -295,6 +312,7 @@ enum Task { Qmatmul, Softmax, SoftmaxLastDim, + Cat, } #[derive(Parser, Debug)] @@ 
-319,6 +337,7 @@ fn main() -> Result<()> { Task::Softmax => run::(args.iters)?, Task::SoftmaxLastDim => run::(args.iters)?, Task::Qmatmul => run::(args.iters)?, + Task::Cat => run::(args.iters)?, } Ok(()) } diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 1bcb78d9..5c0fbb37 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -12,6 +12,7 @@ pub mod loss; pub mod ops; pub mod optim; pub mod rnn; +pub mod rotary_emb; pub mod sequential; pub mod var_builder; pub mod var_map; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index fdd67142..1dac8c3b 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,4 @@ -use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on @@ -74,7 +74,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result { xs * mask } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Dropout { drop_p: f32, } @@ -180,11 +180,10 @@ impl candle::CustomOp1 for SoftmaxLastDim { block_dim: (1, 32, 1), shared_mem_bytes: 0, }; - let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("softmax"), kernels::REDUCE)?; // SAFETY: Set later by running the kernel. let dst = unsafe { dev.alloc::(el) }.w()?; - let params = (src, &dst, n_cols as i32); + let params = (&src, &dst, n_cols as i32); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(dst) @@ -207,7 +206,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { storage: &candle::MetalStorage, layout: &Layout, ) -> Result<(candle::MetalStorage, Shape)> { - use candle::{backend::BackendStorage, DType}; + use candle::backend::BackendStorage; let device = storage.device(); let command_buffer = device.command_buffer()?; let kernels = device.kernels(); @@ -237,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { layout.start_offset() * storage.dtype().size_in_bytes(), &output, ) - .unwrap(); + .map_err(candle::Error::wrap)?; let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype()); Ok((newstorage, layout.shape().clone())) @@ -248,6 +247,215 @@ pub fn softmax_last_dim(xs: &Tensor) -> Result { xs.apply_op1_no_bwd(&SoftmaxLastDim) } +#[derive(Debug, Clone)] +struct RmsNorm { + eps: f32, +} + +impl candle::CustomOp2 for RmsNorm { + fn name(&self) -> &'static str { + "rms-norm" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape)> { + use candle::backend::BackendStorage; + + let eps = self.eps; + fn inner< + T: candle::WithDType + + num_traits::Float + + num_traits::AsPrimitive + + num_traits::FromPrimitive, + >( + src: &[T], + layout: &Layout, + alpha: &[T], + alpha_layout: &Layout, + eps: f32, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => candle::bail!("alpha has to be contiguous"), + Some((o1, o2)) => &alpha[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let sum2 = src + .iter() + .map(|&v| { + let v = v.as_(); + v * v + }) + .sum::(); + let m = 
(sum2 / dim_m1 as f32 + eps).sqrt(); + let m = T::from_f32(m).unwrap_or_else(T::nan); + for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) { + *d = *s / m * *alpha + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + use CpuStorage as C; + match (s1, s2) { + (C::BF16(s1), C::BF16(s2)) => inner::(s1, l1, s2, l2, eps), + (C::F16(s1), C::F16(s2)) => inner::(s1, l1, s2, l2, eps), + (C::F32(s1), C::F32(s2)) => inner::(s1, l1, s2, l2, eps), + _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr}; + use candle::{CudaDevice, WithDType}; + + struct S { + eps: f32, + } + impl Map2 for S { + fn f( + &self, + src: &CudaSlice, + layout: &Layout, + alpha: &CudaSlice, + alpha_layout: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => candle::bail!("alpha has to be contiguous"), + Some((o1, o2)) => alpha.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (1024, 1, 1), + shared_mem_bytes: 0, + }; + let func = dev.get_or_load_func(&kernel_name::("rmsnorm"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = (&src, &dst, &alpha, n_cols as i32, self.eps); + // SAFETY: ffi. 
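+                // Descriptive note (assumption based on the launch config and the CPU path above):
+                // one 1024-thread block per row (grid_dim = n_rows), each block expected to compute
+                // m = sqrt(mean(x_i^2) + eps) over its n_cols elements and write out_i = x_i / m * alpha_i.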
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use candle::backend::BackendStorage; + let dev = s1.device(); + let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + s1: &candle::MetalStorage, + l1: &Layout, + s2: &candle::MetalStorage, + l2: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + let device = s1.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + let name = match (s1.dtype(), s2.dtype()) { + (DType::F32, DType::F32) => "rmsnorm_f32", + (DType::F16, DType::F16) => "rmsnorm_f16", + (DType::BF16, DType::BF16) => "rmsnorm_bf16", + (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"), + }; + + if !(l1.is_contiguous() && l2.is_contiguous()) { + candle::bail!("Non contiguous rmsnorm is not implemented"); + } + + let last_dim = l1.dims()[l1.shape().rank() - 1]; + let elem_count = l1.shape().elem_count(); + let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?; + candle_metal_kernels::call_rms_norm( + device.metal_device(), + &command_buffer, + kernels, + name, + elem_count, + last_dim, + self.eps, + s1.buffer(), + l1.start_offset() * s1.dtype().size_in_bytes(), + s2.buffer(), + l2.start_offset() * s2.dtype().size_in_bytes(), + &output, + ) + .map_err(candle::Error::wrap)?; + let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype()); + Ok((newstorage, l1.shape().clone())) + } +} + +pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(candle::D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; + x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha) +} + +pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result { + let hidden_size_xs = xs.dim(candle::D::Minus1)?; + let hidden_size_alpha = alpha.dims1()?; + if hidden_size_xs != hidden_size_alpha { + candle::bail!( + "shape mismatch in rms-norm {:?} {:?}", + xs.shape(), + alpha.shape() + ) + } + xs.apply_op2_no_bwd(alpha, &RmsNorm { eps }) +} + // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result { let (b_size, c, h, w) = xs.dims4()?; diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 07795eda..dbfa639b 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -31,7 +31,7 @@ pub trait RNN { let (_b_size, seq_len, _features) = input.dims3()?; let mut output = Vec::with_capacity(seq_len); for seq_index in 0..seq_len { - let input = input.i((.., seq_index, ..))?; + let input = input.i((.., seq_index, ..))?.contiguous()?; let state = if seq_index == 0 { self.step(&input, init_state)? } else { diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs new file mode 100644 index 00000000..1084cfb5 --- /dev/null +++ b/candle-nn/src/rotary_emb.rs @@ -0,0 +1,730 @@ +use candle::{CpuStorage, Layout, Result, Shape, Tensor, D}; +use rayon::prelude::*; + +/// Interleaved variant of rotary embeddings. 
+/// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. +/// The resulting y0 and y1 are also interleaved with: +/// y0 = x0*cos - x1*sin +/// y1 = x0*sin + x1*cos +#[derive(Debug, Clone)] +struct RotaryEmbI; + +impl candle::CustomOp3 for RotaryEmbI { + fn name(&self) -> &'static str { + "rotary-emb-int" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * d) + .zip(dst.par_chunks_mut(t * d)) + .for_each(|(src, dst)| { + for i_over_2 in 0..t * d / 2 { + let i = 2 * i_over_2; + dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2]; + dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2]; + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, h, t, d).into())) + } + + use candle::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + s3: &candle::CudaStorage, + l3: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, WrapErr}; + use candle::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("sin input has to be contiguous"), + Some((o1, o2)) => sin.slice(o1..o2), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope_i"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32); + // SAFETY: ffi. 
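+            // Descriptive note (assumption, mirroring the CPU path above): el / 2 launch elements,
+            // one thread per interleaved (x[2i], x[2i+1]) pair, computing
+            // dst[2i] = src[2i]*cos[i] - src[2i+1]*sin[i] and dst[2i+1] = src[2i]*sin[i] + src[2i+1]*cos[i].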
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use candle::backend::BackendStorage; + use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &candle::MetalStorage, + l_src: &Layout, + cos: &candle::MetalStorage, + l_cos: &Layout, + sin: &candle::MetalStorage, + l_sin: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + candle::bail!( + "dtype mismatch in rope-i {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + candle::DType::F32 => "rope_i_f32", + candle::DType::F16 => "rope_i_f16", + candle::DType::BF16 => "rope_i_bf16", + dtype => candle::bail!("rope-i is not implemented for {dtype:?}"), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-i")?; + candle_metal_kernels::call_rope_i( + device.metal_device(), + &command_buffer, + kernels, + name, + b * h, + t * d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(candle::Error::wrap)?; + let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = cos.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + candle::bail!( + "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + candle::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + candle::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + candle::bail!("sin has to be contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI) +} + +pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let cos = cos + .narrow(0, 0, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let sin = sin + .narrow(0, 0, seq_len)? 
+ .reshape((seq_len, n_embd / 2, 1))?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) +} + +/// Contiguous variant of rope embeddings. +#[derive(Debug, Clone)] +struct RotaryEmb; + +impl candle::CustomOp3 for RotaryEmb { + fn name(&self) -> &'static str { + "rotary-emb" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * d) + .zip(dst.par_chunks_mut(t * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i1 = i_t * d + i_d; + let i2 = i1 + d / 2; + let i_cs = i_t * (d / 2) + i_d; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, h, t, d).into())) + } + + use candle::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + s3: &candle::CudaStorage, + l3: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, WrapErr}; + use candle::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("sin input has to be contiguous"), + Some((o1, o2)) => 
sin.slice(o1..o2), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, + &cos, + &sin, + &dst, + (b * h) as u32, + (t * d) as u32, + d as u32, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use candle::backend::BackendStorage; + use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &candle::MetalStorage, + l_src: &Layout, + cos: &candle::MetalStorage, + l_cos: &Layout, + sin: &candle::MetalStorage, + l_sin: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + candle::bail!( + "dtype mismatch in rope {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + candle::DType::F32 => "rope_f32", + candle::DType::F16 => "rope_f16", + candle::DType::BF16 => "rope_bf16", + dtype => candle::bail!("rope is not implemented for {dtype:?}"), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-i")?; + candle_metal_kernels::call_rope( + device.metal_device(), + &command_buffer, + kernels, + name, + b * h, + t * d, + d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(candle::Error::wrap)?; + let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = sin.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + candle::bail!( + "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + candle::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + candle::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + candle::bail!("sin has to be contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmb) +} + +fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let 
xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?; + let cos = Tensor::cat(&[cos, cos], D::Minus1)?; + let sin = Tensor::cat(&[sin, sin], D::Minus1)?; + let cos = cos.narrow(0, 0, seq_len)?; + let sin = sin.narrow(0, 0, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; + x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)? +} + +/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings. +#[derive(Debug, Clone)] +struct RotaryEmbThd; + +impl candle::CustomOp3 for RotaryEmbThd { + fn name(&self) -> &'static str { + "rotary-emb" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => candle::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, t, h, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * h * d) + .zip(dst.par_chunks_mut(t * h * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i_cs = i_t * (d / 2) + i_d; + for i_h in 0..h { + let i1 = i_t * h * d + i_h * d + i_d; + let i2 = i1 + d / 2; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, t, h, d).into())) + } + + use candle::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + s3: &candle::CudaStorage, + l3: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, WrapErr}; + use candle::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match l_src.contiguous_offsets() { + None => candle::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => candle::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match 
l_sin.contiguous_offsets() { + None => candle::bail!("sin input has to be contiguous"), + Some((o1, o2)) => sin.slice(o1..o2), + }; + let (b, t, h, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope_thd"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use candle::backend::BackendStorage; + use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => candle::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &candle::MetalStorage, + l_src: &Layout, + cos: &candle::MetalStorage, + l_cos: &Layout, + sin: &candle::MetalStorage, + l_sin: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + candle::bail!( + "dtype mismatch in rope {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + candle::DType::F32 => "rope_thd_f32", + candle::DType::F16 => "rope_thd_f16", + candle::DType::BF16 => "rope_thd_bf16", + dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"), + }; + let (b, t, h, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-thd")?; + candle_metal_kernels::call_rope_thd( + device.metal_device(), + &command_buffer, + kernels, + name, + b, + t, + h, + d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(candle::Error::wrap)?; + let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = sin.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + candle::bail!( + "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + candle::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + candle::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + candle::bail!("sin has to be contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd) +} diff --git 
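As a usage sketch outside the patch itself: the fused `rope` kernel above expects a contiguous `(batch, heads, seq, head_dim)` input plus `(seq_len, head_dim / 2)` cos/sin tables, all in the same dtype. The table construction below assumes the conventional rope theta of 10_000, which is a choice of this sketch rather than something the kernel mandates.

```rust
use candle::{DType, Device, Result, Tensor};

// Build (seq_len, head_dim / 2) cos/sin tables for the fused rope kernels.
fn rope_tables(seq_len: usize, head_dim: usize, dev: &Device) -> Result<(Tensor, Tensor)> {
    let theta = 10_000f32;
    let inv_freq: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1f32 / theta.powf(2. * i as f32 / head_dim as f32))
        .collect();
    let inv_freq = Tensor::from_vec(inv_freq, (1, head_dim / 2), dev)?;
    let positions = Tensor::arange(0u32, seq_len as u32, dev)?
        .to_dtype(DType::F32)?
        .reshape((seq_len, 1))?;
    let freqs = positions.matmul(&inv_freq)?; // (seq_len, head_dim / 2)
    Ok((freqs.cos()?, freqs.sin()?))
}

// Apply rotary embeddings to a (batch, heads, seq, head_dim) query or key tensor.
fn apply_rope(q: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)
}
```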
a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index bf090219..7de46044 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { name: &str, hints: B::Hints, ) -> Result { - let path = self.path(name); - self.data - .backend - .get(s.into(), &path, hints, self.data.dtype, &self.data.device) + self.get_with_hints_dtype(s, name, hints, self.data.dtype) } /// Retrieve the tensor associated with the given name at the current path. pub fn get>(&self, s: S, name: &str) -> Result { self.get_with_hints(s, name, Default::default()) } + + /// Retrieve the tensor associated with the given name & dtype at the current path. + pub fn get_with_hints_dtype>( + &self, + s: S, + name: &str, + hints: B::Hints, + dtype: DType, + ) -> Result { + let path = self.path(name); + self.data + .backend + .get(s.into(), &path, hints, dtype, &self.data.device) + } } struct Zeros; diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index 5ca01b37..24a49d06 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -4,11 +4,9 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle::{test_utils::to_vec3_round, Device, Result, Tensor}; +use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor}; -#[test] -fn softmax() -> Result<()> { - let device = &Device::Cpu; +fn softmax(device: &Device) -> Result<()> { let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; let tensor = Tensor::new(data, device)?; let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?; @@ -54,6 +52,31 @@ fn softmax() -> Result<()> { Ok(()) } +fn rms_norm(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?; + let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?; + assert_eq!( + to_vec3_round(&t, 4)?, + &[ + [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], + [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]] + ] + ); + let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], + [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]] + ] + ); + let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + #[test] fn softmax_numerical_stability() -> Result<()> { let dev = &Device::Cpu; @@ -62,3 +85,93 @@ fn softmax_numerical_stability() -> Result<()> { assert_eq!(softmax.to_vec1::()?, &[1f32, 0.]); Ok(()) } + +fn ropei(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = candle_nn::rotary_emb::rope_i(&src, &cos, &sin)?; + let rope2 = candle_nn::rotary_emb::rope_i_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - 
rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +fn rope(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?; + let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +fn rope_thd(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = { + let src = src.transpose(1, 2)?.contiguous()?; + candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)? 
+ }; + let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal); +test_device!(rope, rope_cpu, rope_gpu, rope_metal); +test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal); +test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); +test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index 9a75f802..2f438cda 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-onnx" -version = "0.4.1" +version = "0.5.0" edition = "2021" description = "ONNX support for Candle" @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", package = "candle-core", version = "0.4.1" } -candle-nn = { path = "../candle-nn", version = "0.4.1" } +candle = { path = "../candle-core", package = "candle-core", version = "0.5.0" } +candle-nn = { path = "../candle-nn", version = "0.5.0" } prost = "0.12.1" [build-dependencies] diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index cacb56ca..f7cae31c 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -508,17 +508,33 @@ pub fn simple_eval( values.insert(node.output[0].clone(), xs); } "Gather" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather let xs = get(&node.input[0])?; let indices = get(&node.input[1])?; let axis = get_attr_opt::(node, "axis")?.copied().unwrap_or(0); let axis = xs.normalize_axis(axis)?; - // TODO: Provide an op to handle the ONNX generalized gather op ideally in a - // differentiable way. - let xs = if indices.rank() == 0 { - let index = indices.to_vec0::()? as usize; - xs.narrow(axis, index, 1)?.squeeze(axis)? - } else { - todo!("implement gather for {xs:?} {indices:?} axis {axis}") + + // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices + // tensor directly, but candle does not support tensor indexing at the moment, so + // some workarounds must be done. + let xs = match indices.dims() { + [] => { + let index = indices.to_vec0::()? as usize; + xs.narrow(axis, index, 1)?.squeeze(axis)? + } + [_] => xs.index_select(indices, axis)?, + [first, _] => { + let mut v = Vec::with_capacity(*first); + for i in 0..*first { + v.push(xs.index_select(&indices.get(i)?, axis)?) + } + Tensor::stack(&v, axis)? + } + _ => { + // TODO: Provide an op to handle the ONNX generalized gather op ideally in a + // differentiable way. 
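Outside the ONNX plumbing, the rank-2 fallback used above boils down to one `index_select` per index row followed by a `stack`; a standalone sketch of that workaround:

```rust
use candle::{Result, Tensor};

// Gather slices of `xs` along `axis` for each row of a 2-D integer `indices`
// tensor, mirroring numpy.take / ONNX Gather semantics for this case.
fn gather_rank2(xs: &Tensor, indices: &Tensor, axis: usize) -> Result<Tensor> {
    let (rows, _cols) = indices.dims2()?;
    let mut out = Vec::with_capacity(rows);
    for i in 0..rows {
        out.push(xs.index_select(&indices.get(i)?, axis)?);
    }
    Tensor::stack(&out, axis)
}
```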
+ todo!("implement gather for {xs:?} {indices:?} axis {axis}") + } }; values.insert(node.output[0].clone(), xs); } @@ -776,6 +792,11 @@ pub fn simple_eval( let output = input.reshape(new_shape)?; values.insert(node.output[0].clone(), output); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#identity + "Identity" => { + let input = get(&node.input[0])?; + values.insert(node.output[0].clone(), input.clone()); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index a686f198..18cd53c9 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -4,7 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle::{Device, Result, Tensor}; +use candle::{Device, NdArray, Result, Tensor}; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; @@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> { // #[test] // "Gather" -// #[test] +#[test] +fn test_gather_operation() -> Result<()> { + // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. + test( + &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], + &[[0i64, 1], [1, 2]], + 0, + &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]], + )?; + + // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. + test( + &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]], + &[[0i64, 2]], + 1, + &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]], + )?; + + // all the tests below are generated from numpy.take, which works like + // onnx's Gather operation. + test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?; + + test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?; + + test( + &[[1.0], [2.0], [3.0], [4.0]], + &[3i64, 2], + 0, + &[[4.0], [3.0]], + )?; + + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + 1i64, + 0, + &[[5.0, 6.0], [7.0, 8.0]], + )?; + + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + &[1i64, 0], + 0, + &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]], + )?; + + fn test( + data: impl NdArray, + indices: impl NdArray, + axis: i64, + expected: impl NdArray, + ) -> Result<()> { + let att_axis = AttributeProto { + name: "axis".to_string(), + ref_attr_name: "axis".to_string(), + i: axis, + doc_string: "axis".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Gather".to_string(), + domain: "".to_string(), + attribute: vec![att_axis], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + 
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), + 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), + 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), + 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), + _ => unreachable!(), + }; + + Ok(()) + } + Ok(()) +} // "Shape" #[test] diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 7c6fbd68..88001334 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -20,10 +20,10 @@ candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] -pyo3-build-config = "0.20" +pyo3-build-config = "0.21" [features] default = [] diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index aef0707d..b0f05de5 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -324,6 +324,12 @@ class Tensor: """ pass + def gather(self, index, dim): + """ + Gathers values along an axis specified by dim. + """ + pass + def get(self, index: int) -> Tensor: """ Gets the value at the specified index. diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi new file mode 100644 index 00000000..118c4cff --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/__init__.pyi @@ -0,0 +1,19 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape +from candle import Tensor, DType, QTensor + +@staticmethod +def silu(tensor: Tensor) -> Tensor: + """ + Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. 
+ """ + pass + +@staticmethod +def softmax(tensor: Tensor, dim: int) -> Tensor: + """ + Applies the Softmax function to a given tensor.# + """ + pass diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7b9a7413..0da2c700 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -60,8 +60,8 @@ impl PyDType { impl PyDType { fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult { use std::str::FromStr; - if let Ok(dtype) = ob.extract::<&str>(py) { - let dtype = DType::from_str(dtype) + if let Ok(dtype) = ob.extract::(py) { + let dtype = DType::from_str(&dtype) .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; Ok(Self(dtype)) } else { @@ -116,8 +116,8 @@ impl PyDevice { impl<'source> FromPyObject<'source> for PyDevice { fn extract(ob: &'source PyAny) -> PyResult { - let device: &str = ob.extract()?; - let device = match device { + let device: String = ob.extract()?; + let device = match device.as_str() { "cpu" => PyDevice::Cpu, "cuda" => PyDevice::Cuda, _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?, @@ -265,7 +265,7 @@ impl PyTensor { } else if let Ok(TorchTensor(numpy)) = data.extract::(py) { return PyTensor::new(py, numpy); } else { - let ty = data.as_ref(py).get_type(); + let ty = data.bind(py).get_type(); Err(PyTypeError::new_err(format!( "incorrect type {ty} for tensor" )))? @@ -322,7 +322,7 @@ impl PyTensor { fn to_torch(&self, py: Python<'_>) -> PyResult { let candle_values = self.values(py)?; let torch_tensor: PyObject = py - .import("torch")? + .import_bound("torch")? .getattr("tensor")? .call1((candle_values,))? .extract()?; @@ -333,7 +333,7 @@ impl PyTensor { /// Gets the tensor's shape. /// &RETURNS&: Tuple[int] fn shape(&self, py: Python<'_>) -> PyObject { - PyTuple::new(py, self.0.dims()).to_object(py) + PyTuple::new_bound(py, self.0.dims()).to_object(py) } #[getter] @@ -347,7 +347,7 @@ impl PyTensor { /// Gets the tensor's strides. /// &RETURNS&: Tuple[int] fn stride(&self, py: Python<'_>) -> PyObject { - PyTuple::new(py, self.0.stride()).to_object(py) + PyTuple::new_bound(py, self.0.stride()).to_object(py) } #[getter] @@ -448,6 +448,12 @@ impl PyTensor { Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) } + /// Gathers values along an axis specified by dim. + fn gather(&self, index: &Self, dim: i64) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?)) + } + #[pyo3(text_signature = "(self, rhs:Tensor)")] /// Performs a matrix multiplication between the two tensors. /// &RETURNS&: Tensor @@ -521,7 +527,7 @@ impl PyTensor { } fn extract_indexer( - py_indexer: &PyAny, + py_indexer: &Bound, current_dim: usize, dims: &[usize], index_argument_count: usize, @@ -561,7 +567,7 @@ impl PyTensor { ), current_dim + 1, )) - } else if py_indexer.is_ellipsis() { + } else if py_indexer.is(&py_indexer.py().Ellipsis()) { // Handle '...' e.g. 
tensor[..., 0] if current_dim > 0 { return Err(PyTypeError::new_err( @@ -580,7 +586,7 @@ impl PyTensor { } } - if let Ok(tuple) = idx.downcast::(py) { + if let Ok(tuple) = idx.downcast_bound::(py) { let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count(); if not_none_count > dims.len() { @@ -590,12 +596,12 @@ impl PyTensor { let mut current_dim = 0; for item in tuple.iter() { let (indexer, new_current_dim) = - extract_indexer(item, current_dim, dims, not_none_count)?; + extract_indexer(&item, current_dim, dims, not_none_count)?; current_dim = new_current_dim; indexers.push(indexer); } } else { - let (indexer, _) = extract_indexer(idx.downcast::(py)?, 0, dims, 1)?; + let (indexer, _) = extract_indexer(idx.downcast_bound::(py)?, 0, dims, 1)?; indexers.push(indexer); } @@ -646,7 +652,7 @@ impl PyTensor { /// Add two tensors. /// &RETURNS&: Tensor - fn __add__(&self, rhs: &PyAny) -> PyResult { + fn __add__(&self, rhs: &Bound) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { self.0.broadcast_add(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { @@ -657,13 +663,13 @@ impl PyTensor { Ok(Self(tensor)) } - fn __radd__(&self, rhs: &PyAny) -> PyResult { + fn __radd__(&self, rhs: &Bound) -> PyResult { self.__add__(rhs) } /// Multiply two tensors. /// &RETURNS&: Tensor - fn __mul__(&self, rhs: &PyAny) -> PyResult { + fn __mul__(&self, rhs: &Bound) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { self.0.broadcast_mul(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { @@ -674,13 +680,13 @@ impl PyTensor { Ok(Self(tensor)) } - fn __rmul__(&self, rhs: &PyAny) -> PyResult { + fn __rmul__(&self, rhs: &Bound) -> PyResult { self.__mul__(rhs) } /// Subtract two tensors. /// &RETURNS&: Tensor - fn __sub__(&self, rhs: &PyAny) -> PyResult { + fn __sub__(&self, rhs: &Bound) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { self.0.broadcast_sub(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { @@ -693,7 +699,7 @@ impl PyTensor { /// Divide two tensors. /// &RETURNS&: Tensor - fn __truediv__(&self, rhs: &PyAny) -> PyResult { + fn __truediv__(&self, rhs: &Bound) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { self.0.broadcast_div(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { @@ -705,7 +711,7 @@ impl PyTensor { } /// Rich-compare two tensors. /// &RETURNS&: Tensor - fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult { + fn __richcmp__(&self, rhs: &Bound, op: CompareOp) -> PyResult { let compare = |lhs: &Tensor, rhs: &Tensor| { let t = match op { CompareOp::Eq => lhs.eq(rhs), @@ -951,7 +957,7 @@ impl PyTensor { #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")] /// Performs Tensor dtype and/or device conversion. /// &RETURNS&: Tensor - fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult { + fn to(&self, args: &Bound, kwargs: Option<&Bound>) -> PyResult { let mut device: Option = None; let mut dtype: Option = None; let mut other: Option = None; @@ -1221,7 +1227,7 @@ impl PyQTensor { ///Gets the shape of the tensor. 
/// &RETURNS&: Tuple[int] fn shape(&self, py: Python<'_>) -> PyObject { - PyTuple::new(py, self.0.shape().dims()).to_object(py) + PyTuple::new_bound(py, self.0.shape().dims()).to_object(py) } fn __repr__(&self) -> String { @@ -1259,7 +1265,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult { .into_iter() .map(|(key, value)| (key, PyTensor(value).into_py(py))) .collect::>(); - Ok(res.into_py_dict(py).to_object(py)) + Ok(res.into_py_dict_bound(py).to_object(py)) } #[pyfunction] @@ -1297,7 +1303,7 @@ fn load_ggml( .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))) .collect::<::candle::Result>>() .map_err(wrap_err)?; - let tensors = tensors.into_py_dict(py).to_object(py); + let tensors = tensors.into_py_dict_bound(py).to_object(py); let hparams = [ ("n_vocab", ggml.hparams.n_vocab), ("n_embd", ggml.hparams.n_embd), @@ -1307,7 +1313,7 @@ fn load_ggml( ("n_rot", ggml.hparams.n_rot), ("ftype", ggml.hparams.ftype), ]; - let hparams = hparams.into_py_dict(py).to_object(py); + let hparams = hparams.into_py_dict_bound(py).to_object(py); let vocab = ggml .vocab .token_score_pairs @@ -1345,7 +1351,7 @@ fn load_gguf( gguf_file::Value::Bool(x) => x.into_py(py), gguf_file::Value::String(x) => x.into_py(py), gguf_file::Value::Array(x) => { - let list = pyo3::types::PyList::empty(py); + let list = pyo3::types::PyList::empty_bound(py); for elem in x.iter() { list.append(gguf_value_to_pyobject(elem, py)?)?; } @@ -1365,13 +1371,13 @@ fn load_gguf( }) .collect::<::candle::Result>>() .map_err(wrap_err)?; - let tensors = tensors.into_py_dict(py).to_object(py); + let tensors = tensors.into_py_dict_bound(py).to_object(py); let metadata = gguf .metadata .iter() .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?))) .collect::>>()? 
- .into_py_dict(py) + .into_py_dict_bound(py) .to_object(py); Ok((tensors, metadata)) } @@ -1384,7 +1390,7 @@ fn load_gguf( fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> { use ::candle::quantized::gguf_file; - fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult { + fn pyobject_to_gguf_value(v: &Bound, py: Python<'_>) -> PyResult { let v: gguf_file::Value = if let Ok(x) = v.extract::() { gguf_file::Value::U8(x) } else if let Ok(x) = v.extract::() { @@ -1412,7 +1418,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) } else if let Ok(x) = v.extract::>() { let x = x .into_iter() - .map(|f| pyobject_to_gguf_value(f.as_ref(py), py)) + .map(|f| pyobject_to_gguf_value(f.bind(py), py)) .collect::>>()?; gguf_file::Value::Array(x) } else { @@ -1444,7 +1450,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) Ok(( key.extract::() .map_err(|_| PyErr::new::("keys must be strings"))?, - pyobject_to_gguf_value(value, py)?, + pyobject_to_gguf_value(&value.as_borrowed(), py)?, )) }) .collect::>>()?; @@ -1492,7 +1498,7 @@ fn get_num_threads() -> usize { ::candle::utils::get_num_threads() } -fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?; m.add_function(wrap_pyfunction!(get_num_threads, m)?)?; m.add_function(wrap_pyfunction!(has_accelerate, m)?)?; @@ -1573,7 +1579,7 @@ fn tanh(tensor: PyTensor) -> PyResult { Ok(PyTensor(s)) } -fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(silu, m)?)?; m.add_function(wrap_pyfunction!(softmax, m)?)?; m.add_function(wrap_pyfunction!(max_pool2d, m)?)?; @@ -1585,7 +1591,7 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { } #[cfg(feature = "onnx")] -fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { use onnx::{PyONNXModel, PyONNXTensorDescriptor}; m.add_class::()?; m.add_class::()?; @@ -1593,18 +1599,18 @@ fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { } #[pymodule] -fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { - let utils = PyModule::new(py, "utils")?; - candle_utils(py, utils)?; - m.add_submodule(utils)?; - let nn = PyModule::new(py, "functional")?; - candle_functional_m(py, nn)?; - m.add_submodule(nn)?; +fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + let utils = PyModule::new_bound(py, "utils")?; + candle_utils(py, &utils)?; + m.add_submodule(&utils)?; + let nn = PyModule::new_bound(py, "functional")?; + candle_functional_m(py, &nn)?; + m.add_submodule(&nn)?; #[cfg(feature = "onnx")] { - let onnx = PyModule::new(py, "onnx")?; - candle_onnx_m(py, onnx)?; - m.add_submodule(onnx)?; + let onnx = PyModule::new_bound(py, "onnx")?; + candle_onnx_m(py, &onnx)?; + m.add_submodule(&onnx)?; } m.add_class::()?; m.add_class::()?; diff --git a/candle-pyo3/src/onnx.rs b/candle-pyo3/src/onnx.rs index b9a0eb22..a2e9a087 100644 --- a/candle-pyo3/src/onnx.rs +++ b/candle-pyo3/src/onnx.rs @@ -39,7 +39,7 @@ impl PyONNXTensorDescriptor { /// The shape of the tensor. 
/// &RETURNS&: Tuple[Union[int,str,Any]] fn shape(&self, py: Python) -> PyResult> { - let shape = PyList::empty(py); + let shape = PyList::empty_bound(py); if let Some(d) = &self.0.shape { for dim in d.dim.iter() { if let Some(value) = &dim.value { diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index 165941bd..b0e472e6 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -206,6 +206,8 @@ def write(module, directory, origin, check=False): if check: with open(filename, "r") as f: data = f.read() + print("generated content") + print(pyi_content) assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`" else: with open(filename, "w") as f: @@ -229,6 +231,8 @@ def write(module, directory, origin, check=False): if check: with open(filename, "r") as f: data = f.read() + print("generated content") + print(py_content) assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`" else: with open(filename, "w") as f: diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index b1a567c3..c250a186 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -1,24 +1,36 @@ use candle::{DType, Error, Result, Tensor}; use rand::{distributions::Distribution, SeedableRng}; +#[derive(Clone, PartialEq, Debug)] +pub enum Sampling { + ArgMax, + All { temperature: f64 }, + TopK { k: usize, temperature: f64 }, + TopP { p: f64, temperature: f64 }, + TopKThenTopP { k: usize, p: f64, temperature: f64 }, +} + pub struct LogitsProcessor { rng: rand::rngs::StdRng, - temperature: Option, - top_p: Option, + sampling: Sampling, } impl LogitsProcessor { + pub fn from_sampling(seed: u64, sampling: Sampling) -> Self { + let rng = rand::rngs::StdRng::seed_from_u64(seed); + Self { rng, sampling } + } + pub fn new(seed: u64, temperature: Option, top_p: Option) -> Self { - let temperature = if temperature.map_or(true, |v| v < 1e-7) { - None - } else { - temperature + let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) }); + let sampling = match temperature { + None => Sampling::ArgMax, + Some(temperature) => match top_p { + None => Sampling::All { temperature }, + Some(p) => Sampling::TopP { p, temperature }, + }, }; - Self { - rng: rand::rngs::StdRng::seed_from_u64(seed), - temperature, - top_p, - } + Self::from_sampling(seed, sampling) } fn sample_argmax(&mut self, logits: Tensor) -> Result { @@ -38,14 +50,14 @@ impl LogitsProcessor { Ok(next_token) } + /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed + /// probability top_p. This way we never sample tokens that have very low probabilities and are + /// less likely to go "off the rails". fn sample_topp(&mut self, prs: &mut Vec, top_p: f32) -> Result { - // top-p sampling (or "nucleus sampling") samples from the smallest set of - // tokens that exceed probability top_p. This way we never sample tokens that - // have very low probabilities and are less likely to go "off the rails". let mut argsort_indices = (0..prs.len()).collect::>(); // Sort by descending probability. - argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap()); + argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i])); // Clamp smaller probabilities to zero. let mut cumsum = 0.; @@ -60,23 +72,78 @@ impl LogitsProcessor { self.sample_multinomial(prs) } + // top-k sampling samples from the k tokens with the largest probabilities. 
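For context, a hedged usage sketch of the `Sampling` enum and `from_sampling` constructor introduced in this change; the seed and parameter values are placeholders, not recommendations from the patch:

```rust
use candle::{Result, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};

// Top-k filtering (k = 64) followed by nucleus sampling (p = 0.9) at temperature 0.8.
fn next_token(logits: &Tensor) -> Result<u32> {
    let sampling = Sampling::TopKThenTopP { k: 64, p: 0.9, temperature: 0.8 };
    let mut processor = LogitsProcessor::from_sampling(42, sampling);
    processor.sample(logits)
}
```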
+ fn sample_topk(&mut self, prs: &mut Vec, top_k: usize) -> Result { + if top_k >= prs.len() { + self.sample_multinomial(prs) + } else { + let mut argsort_indices = (0..prs.len()).collect::>(); + let (indices, _, _) = + argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); + let prs = indices.iter().map(|&i| prs[i]).collect::>(); + let index = self.sample_multinomial(&prs)?; + Ok(indices[index as usize] as u32) + } + } + + // top-k sampling samples from the k tokens with the largest probabilities. + // then top-p sampling. + fn sample_topk_topp(&mut self, prs: &mut Vec, top_k: usize, top_p: f32) -> Result { + if top_k >= prs.len() { + self.sample_topp(prs, top_p) + } else { + let mut argsort_indices = (0..prs.len()).collect::>(); + let (indices, _, _) = + argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); + let mut prs = indices.iter().map(|&i| prs[i]).collect::>(); + let sum_p = prs.iter().sum::(); + let index = if top_p <= 0.0 || top_p >= sum_p { + self.sample_multinomial(&prs)? + } else { + self.sample_topp(&mut prs, top_p)? + }; + Ok(indices[index as usize] as u32) + } + } + pub fn sample(&mut self, logits: &Tensor) -> Result { + self.sample_f(logits, |_| {}) + } + + pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result { let logits = logits.to_dtype(DType::F32)?; - let next_token = match self.temperature { - None => self.sample_argmax(logits)?, - Some(temperature) => { - let logits = &(&logits / temperature)?; - let prs = candle_nn::ops::softmax_last_dim(logits)?; - let mut prs: Vec = prs.to_vec1()?; - let top_p = self.top_p.unwrap_or(1.); - if top_p <= 0.0 || top_p >= 1.0 { + let prs = |temperature: f64| -> Result> { + let logits = (&logits / temperature)?; + let prs = candle_nn::ops::softmax_last_dim(&logits)?; + let mut prs = prs.to_vec1()?; + f(&mut prs); + Ok(prs) + }; + + let next_token = match &self.sampling { + Sampling::ArgMax => self.sample_argmax(logits)?, + Sampling::All { temperature } => { + let prs = prs(*temperature)?; + self.sample_multinomial(&prs)? + } + Sampling::TopP { p, temperature } => { + let mut prs = prs(*temperature)?; + if *p <= 0.0 || *p >= 1.0 { // simply sample from the predicted probability distribution self.sample_multinomial(&prs)? } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero - self.sample_topp(&mut prs, top_p as f32)? + self.sample_topp(&mut prs, *p as f32)? } } + Sampling::TopK { k, temperature } => { + let mut prs = prs(*temperature)?; + self.sample_topk(&mut prs, *k)? + } + Sampling::TopKThenTopP { k, p, temperature } => { + let mut prs = prs(*temperature)?; + self.sample_topk_topp(&mut prs, *k, *p as f32)? + } }; Ok(next_token) } diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs new file mode 100644 index 00000000..9613fdab --- /dev/null +++ b/candle-transformers/src/models/clip/mod.rs @@ -0,0 +1,154 @@ +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/openai/CLIP +//! 
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip +use self::{ + text_model::{Activation, ClipTextTransformer}, + vision_model::ClipVisionTransformer, +}; +use candle::{Result, Tensor, D}; + +pub mod text_model; +pub mod vision_model; + +#[derive(Clone, Debug)] +pub struct ClipModel { + text_model: ClipTextTransformer, + vision_model: ClipVisionTransformer, + visual_projection: candle_nn::Linear, + text_projection: candle_nn::Linear, + logit_scale: Tensor, +} + +#[derive(Clone, Debug)] +pub enum EncoderConfig { + Text(text_model::ClipTextConfig), + Vision(vision_model::ClipVisionConfig), +} + +impl EncoderConfig { + pub fn embed_dim(&self) -> usize { + match self { + Self::Text(c) => c.embed_dim, + Self::Vision(c) => c.embed_dim, + } + } + + pub fn num_attention_heads(&self) -> usize { + match self { + Self::Text(c) => c.num_attention_heads, + Self::Vision(c) => c.num_attention_heads, + } + } + + pub fn intermediate_size(&self) -> usize { + match self { + Self::Text(c) => c.intermediate_size, + Self::Vision(c) => c.intermediate_size, + } + } + + pub fn num_hidden_layers(&self) -> usize { + match self { + Self::Text(c) => c.num_hidden_layers, + Self::Vision(c) => c.num_hidden_layers, + } + } + + pub fn activation(&self) -> Activation { + match self { + Self::Text(_c) => Activation::QuickGelu, + Self::Vision(c) => c.activation, + } + } +} + +#[derive(Clone, Debug)] +pub struct ClipConfig { + pub text_config: text_model::ClipTextConfig, + pub vision_config: vision_model::ClipVisionConfig, + pub logit_scale_init_value: f32, + pub image_size: usize, +} + +impl ClipConfig { + // base image size is 224, model size is 600Mb + pub fn vit_base_patch32() -> Self { + let text_config = text_model::ClipTextConfig::vit_base_patch32(); + let vision_config = vision_model::ClipVisionConfig::vit_base_patch32(); + + Self { + text_config, + vision_config, + logit_scale_init_value: 2.6592, + image_size: 224, + } + } +} + +impl ClipModel { + pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?; + + let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?; + + let visual_projection = candle_nn::linear_no_bias( + c.vision_config.embed_dim, + c.vision_config.projection_dim, + vs.pp("visual_projection"), + )?; + + let text_projection = candle_nn::linear_no_bias( + c.text_config.embed_dim, + c.text_config.projection_dim, + vs.pp("text_projection"), + )?; + + // originally nn.Parameter + let logit_scale = if vs.contains_tensor("logit_scale") { + vs.get(&[], "logit_scale")? + } else { + Tensor::new(&[c.logit_scale_init_value], vs.device())? + }; + + Ok(Self { + text_model, + vision_model, + visual_projection, + text_projection, + logit_scale, + }) + } + + pub fn get_text_features(&self, input_ids: &Tensor) -> Result { + input_ids + .apply(&self.text_model)? + .apply(&self.text_projection) + } + + pub fn get_image_features(&self, pixel_values: &Tensor) -> Result { + pixel_values + .apply(&self.vision_model)? 
+ .apply(&self.visual_projection) + } + + pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> { + let image_features = self.get_image_features(pixel_values)?; + let text_features = self.get_text_features(input_ids)?; + let image_features_normalized = div_l2_norm(&image_features)?; + let text_features_normalized = div_l2_norm(&text_features)?; + let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; + let logit_scale = self.logit_scale.exp()?; + let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; + let logits_per_image = logits_per_text.t()?; + Ok((logits_per_text, logits_per_image)) + } +} + +pub fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs new file mode 100644 index 00000000..d3ba26ff --- /dev/null +++ b/candle-transformers/src/models/clip/text_model.rs @@ -0,0 +1,333 @@ +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/openai/CLIP +//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip + +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; + +use super::EncoderConfig; + +#[derive(Debug, Clone, Copy)] +pub enum Activation { + QuickGelu, +} + +impl Module for Activation { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, + } + } +} + +#[derive(Debug, Clone)] +pub struct ClipTextConfig { + pub vocab_size: usize, + pub embed_dim: usize, + pub activation: Activation, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub pad_with: Option, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + #[allow(dead_code)] + pub projection_dim: usize, +} + +impl ClipTextConfig { + // The config details can be found in the "text_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn vit_base_patch32() -> Self { + Self { + vocab_size: 49408, + embed_dim: 512, + intermediate_size: 2048, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 8, + projection_dim: 512, + activation: Activation::QuickGelu, + } + } +} + +// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model. 
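A minimal sketch of driving the `ClipModel` added above end to end, assuming the new module is exported as `candle_transformers::models::clip` and that ViT-B/32 weights sit in a local `model.safetensors` file (both are placeholders of this sketch):

```rust
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::clip::{ClipConfig, ClipModel};

// Score tokenized prompts against preprocessed images and return per-text probabilities.
fn clip_probs(pixel_values: &Tensor, input_ids: &Tensor, dev: &Device) -> Result<Tensor> {
    let cfg = ClipConfig::vit_base_patch32();
    // SAFETY: memory-maps a trusted local weights file.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, dev)?
    };
    let model = ClipModel::new(vb, &cfg)?;
    let (logits_per_text, _logits_per_image) = model.forward(pixel_values, input_ids)?;
    candle_nn::ops::softmax(&logits_per_text, D::Minus1)
}
```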
+// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142 +#[derive(Clone, Debug)] +struct ClipTextEmbeddings { + token_embedding: candle_nn::Embedding, + position_embedding: candle_nn::Embedding, + position_ids: Tensor, +} + +impl ClipTextEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + let token_embedding = + candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + let position_embedding: nn::Embedding = candle_nn::embedding( + c.max_position_embeddings, + c.embed_dim, + vs.pp("position_embedding"), + )?; + let position_ids = + Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?; + Ok(ClipTextEmbeddings { + token_embedding, + position_embedding, + position_ids, + }) + } +} + +impl Module for ClipTextEmbeddings { + fn forward(&self, input_ids: &Tensor) -> Result { + let seq_length = input_ids.dim(D::Minus1)?; + let inputs_embeds = self.token_embedding.forward(input_ids)?; + let position_ids = self.position_ids.narrow(1, 0, seq_length)?; + let position_embedding = self.position_embedding.forward(&position_ids)?; + inputs_embeds.broadcast_add(&position_embedding) + } +} + +#[derive(Clone, Debug)] +struct ClipAttention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: candle_nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ClipAttention { + fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let embed_dim = c.embed_dim(); + let num_attention_heads = c.num_attention_heads(); + let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?; + let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?; + let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?; + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(ClipAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + + let query_states = (self.q_proj.forward(xs)? * self.scale)?; + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&query_states, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + + let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask { + attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)? + .reshape((bsz * self.num_attention_heads, seq_len, src_len))? 
+ } else { + attn_weights + }; + + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Clone, Debug)] +struct ClipMlp { + fc1: candle_nn::Linear, + fc2: candle_nn::Linear, + activation: Activation, +} + +impl ClipMlp { + fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?; + let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?; + + Ok(ClipMlp { + fc1, + fc2, + activation: c.activation(), + }) + } +} + +impl ClipMlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) + } +} + +#[derive(Clone, Debug)] +struct ClipEncoderLayer { + self_attn: ClipAttention, + layer_norm1: candle_nn::LayerNorm, + mlp: ClipMlp, + layer_norm2: candle_nn::LayerNorm, +} + +impl ClipEncoderLayer { + fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; + let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?; + let mlp = ClipMlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?; + + Ok(ClipEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Clone, Debug)] +pub struct ClipEncoder { + layers: Vec, +} + +impl ClipEncoder { + pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result { + let vs = vs.pp("layers"); + let mut layers: Vec = Vec::new(); + for index in 0..c.num_hidden_layers() { + let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + layers.push(layer) + } + Ok(ClipEncoder { layers }) + } + + pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + } + Ok(xs) + } +} + +/// A CLIP transformer based model. 
+#[derive(Clone, Debug)] +pub struct ClipTextTransformer { + embeddings: ClipTextEmbeddings, + encoder: ClipEncoder, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipTextTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; + let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?; + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?; + Ok(ClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + // TODO: rewrrite to newer version + fn build_causal_attention_mask( + bsz: usize, + seq_len: usize, + mask_after: usize, + device: &Device, + ) -> Result { + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| { + (0..seq_len).map(move |j| { + if j > i || j > mask_after { + f32::MIN + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; + mask.broadcast_as((bsz, 1, seq_len, seq_len)) + } + + pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result { + let (bsz, seq_len) = input_ids.dims2()?; + let input_ids = self.embeddings.forward(input_ids)?; + let causal_attention_mask = + Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?; + let input_ids = self + .encoder + .forward(&input_ids, Some(&causal_attention_mask))?; + self.final_layer_norm.forward(&input_ids) + } +} + +impl Module for ClipTextTransformer { + fn forward(&self, input_ids: &Tensor) -> Result { + let output = self.forward_with_mask(input_ids, usize::MAX)?; + let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?; + + let mut indices = Vec::new(); + for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::()?.iter().enumerate() { + let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?; + indices.push(index); + } + Tensor::cat(&indices, 0) + } +} diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs new file mode 100644 index 00000000..88992434 --- /dev/null +++ b/candle-transformers/src/models/clip/vision_model.rs @@ -0,0 +1,147 @@ +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/openai/CLIP +//! 
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip + +use candle::{IndexOp, Result, Shape, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; +use nn::Conv2dConfig; + +use super::{ + text_model::{Activation, ClipEncoder}, + EncoderConfig, +}; + +#[derive(Debug, Clone)] +pub struct ClipVisionConfig { + pub embed_dim: usize, + pub activation: Activation, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + #[allow(dead_code)] + pub projection_dim: usize, + pub num_channels: usize, + pub image_size: usize, + pub patch_size: usize, +} + +impl ClipVisionConfig { + // The config details can be found in the "vision_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn vit_base_patch32() -> Self { + Self { + embed_dim: 768, + activation: Activation::QuickGelu, + intermediate_size: 3072, + num_hidden_layers: 12, + num_attention_heads: 12, + projection_dim: 512, + num_channels: 3, + image_size: 224, + patch_size: 32, + } + } +} + +// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112 +#[derive(Clone, Debug)] +struct ClipVisionEmbeddings { + patch_embedding: candle_nn::Conv2d, + position_ids: Tensor, + class_embedding: Tensor, + position_embedding: candle_nn::Embedding, +} + +impl ClipVisionEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result { + // originally nn.Parameter + let class_embedding = if vs.contains_tensor("class_embedding") { + vs.get(c.embed_dim, "class_embedding")? + } else { + Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())? + }; + + let num_patches = (c.image_size / c.patch_size).pow(2); + let num_positions = num_patches + 1; + let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?; + + let conv2dconfig = Conv2dConfig { + stride: c.patch_size, + ..Default::default() + }; + let position_embedding = + candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?; + let patch_embedding = candle_nn::conv2d_no_bias( + c.num_channels, + c.embed_dim, + c.patch_size, + conv2dconfig, + vs.pp("patch_embedding"), + )?; + Ok(Self { + patch_embedding, + position_ids, + class_embedding, + position_embedding, + }) + } +} + +impl Module for ClipVisionEmbeddings { + fn forward(&self, pixel_values: &Tensor) -> Result { + let batch_size = pixel_values.shape().dims(); + let patch_embeds = self + .patch_embedding + .forward(pixel_values)? + .flatten_from(2)? 
+ .transpose(1, 2)?; + let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?)); + let class_embeds = self.class_embedding.expand(shape)?; + let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?; + let position_embedding = self.position_embedding.forward(&self.position_ids)?; + embeddings.broadcast_add(&position_embedding) + } +} + +// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743 +#[derive(Clone, Debug)] +pub struct ClipVisionTransformer { + embeddings: ClipVisionEmbeddings, + encoder: ClipEncoder, + pre_layer_norm: candle_nn::LayerNorm, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipVisionTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result { + let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?; + let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?; + let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?; + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?; + Ok(Self { + embeddings, + encoder, + final_layer_norm, + pre_layer_norm, + }) + } +} + +impl Module for ClipVisionTransformer { + fn forward(&self, pixel_values: &Tensor) -> Result { + let hidden_states = pixel_values + .apply(&self.embeddings)? + .apply(&self.pre_layer_norm)?; + + let encoder_outputs = self.encoder.forward(&hidden_states, None)?; + // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787 + // pooled_output = encoder_outputs[:, 0, :] + let pooled_output = encoder_outputs.i((.., 0, ..))?; + self.final_layer_norm.forward(&pooled_output) + } +} diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index 0edc8494..757aa88a 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -52,8 +52,8 @@ impl Module for Attention { .transpose(0, 1)? // 20134 .transpose(2, 3)?; // 20314 let q = (qkv.i(0)? * self.scale)?; - let k = qkv.i(1)?; - let v = qkv.i(2)?; + let k = qkv.i(1)?.contiguous()?; + let v = qkv.i(2)?.contiguous()?; let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; self.proj.forward(&attn) diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 86cf8451..e9d4af7e 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,5 +1,6 @@ use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; +use serde::Deserialize; const MAX_SEQ_LEN: usize = 5000; @@ -18,7 +19,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { } // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py -#[derive(Debug)] +#[derive(Clone, Debug, Deserialize)] pub struct Config { pub vocab_size: usize, pub hidden_size: usize, @@ -178,7 +179,9 @@ impl FalconRotaryEmbedding { fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let on_true = Tensor::new(on_true, on_false.device())? 
+ .to_dtype(on_false.dtype())? + .broadcast_as(shape.dims())?; let m = mask.where_cond(&on_true, on_false)?; Ok(m) } @@ -247,7 +250,7 @@ impl FalconAttention { } } - fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result { + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result { let fused_qkv = self.query_key_value.forward(x)?; let head_dim = self.head_dim; let (query, key, value) = self.split_heads(&fused_qkv)?; @@ -267,7 +270,6 @@ impl FalconAttention { (query, key) }; let (mut key, mut value) = (key, value); - let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?; if self.use_cache { if let Some((cache_k, cache_v)) = &self.kv_cache { // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for @@ -293,13 +295,18 @@ impl FalconAttention { // Only handle the case where alibi is None here, and non-flash attention. let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?; - let attention_scores = candle_nn::ops::softmax( - &attention_scores - .broadcast_add(&mask.squeeze(1)?)? - .to_dtype(DType::F32)?, - D::Minus1, - )? - .to_dtype(x.dtype())?; + let attention_scores = match mask { + None => attention_scores, + Some(mask) => { + let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)? + .to_dtype(query.dtype())?; + attention_scores.broadcast_add(&mask.squeeze(1)?)? + } + }; + + let attention_scores = + candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)? + .to_dtype(x.dtype())?; let attn_output = attention_scores .matmul(&value)? .reshape((b_sz, self.num_heads, seq_len, head_dim))? @@ -372,7 +379,7 @@ impl FalconDecoderLayer { }) } - fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result { + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result { let residual = x.clone(); let ln_attn = self.inp_layernorm.forward(x)?; let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?; @@ -457,9 +464,13 @@ impl Falcon { Some((k, _)) => k.dim(1)?, None => 0, }; - let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?; + let causal_mask = if seq_len <= 1 { + None + } else { + Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?) + }; for block in self.blocks.iter_mut() { - hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?; + hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?; } let hidden_state = self.ln_f.forward(&hidden_state)?; let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?; diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 282d5eb2..15e4dccb 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use candle::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{linear_b as linear, Linear, VarBuilder}; +use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder}; fn default_max_position_embeddings() -> usize { 4096 @@ -11,7 +11,9 @@ fn default_max_position_embeddings() -> usize { pub struct Config { pub attention_bias: bool, pub head_dim: usize, - pub hidden_act: candle_nn::Activation, + // The code gemma configs include both hidden_act and hidden_activation. 
+ pub hidden_act: Option, + pub hidden_activation: Option, pub hidden_size: usize, pub intermediate_size: usize, pub num_attention_heads: usize, @@ -25,6 +27,16 @@ pub struct Config { pub max_position_embeddings: usize, } +impl Config { + fn hidden_act(&self) -> Result { + match (self.hidden_act, self.hidden_activation) { + (None, Some(act)) | (Some(act), None) => Ok(act), + (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"), + (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"), + } + } +} + #[derive(Debug, Clone)] struct RmsNorm { weight: Tensor, @@ -126,7 +138,7 @@ impl MLP { gate_proj, up_proj, down_proj, - act_fn: cfg.hidden_act, + act_fn: cfg.hidden_act()?, }) } } diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index a091d3eb..73671cdc 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,4 +1,4 @@ -use super::with_tracing::{linear_no_bias as linear, Linear}; +use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use std::collections::HashMap; @@ -133,25 +133,6 @@ impl Cache { } } -#[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn load(size: usize, eps: f64, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Debug, Clone)] struct CausalSelfAttention { q_proj: Linear, @@ -259,8 +240,12 @@ impl CausalSelfAttention { let k = k.to_dtype(DType::F32)?; let v = v.to_dtype(DType::F32)?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = if seq_len == 1 { + att + } else { + let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; + masked_fill(&att, &mask, f32::NEG_INFINITY)? + }; let att = candle_nn::ops::softmax(&att, D::Minus1)?; // Convert to contiguous as matmul doesn't support strided vs for now. att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? 
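The hunk above drops the causal mask on the single-token decode path. A rough sketch of why this is safe with the cache layout used here, where the lone query is the newest position and every cached key lies in its past:

    // cache.mask(1) is a 1x1 tensor holding a single 0; broadcast over
    // (b, heads, 1, kv_len) it is an all-zero mask, so masked_fill would be a
    // no-op. Returning `att` untouched saves building and applying that tensor
    // on every generated token, while prompts (seq_len > 1) keep the
    // triangular mask exactly as before.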
@@ -377,8 +362,8 @@ impl Block { let span = tracing::span!(tracing::Level::TRACE, "block"); let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?; - let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let rms_2 = RmsNorm::load( + let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let rms_2 = RmsNorm::new( cfg.hidden_size, cfg.rms_norm_eps, vb.pp("post_attention_layernorm"), @@ -409,7 +394,7 @@ impl Llama { x = block.forward(&x, index_pos, block_idx, cache)?; } let x = self.ln_f.forward(&x)?; - let x = x.i((.., seq_len - 1, ..))?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; let logits = self.lm_head.forward(&x)?; logits.to_dtype(DType::F32) } @@ -417,7 +402,7 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; - let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; + let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap()) .collect(); diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 7b4f120b..bba8b666 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -194,8 +194,12 @@ impl CausalSelfAttention { let v = v.transpose(1, 2)?.contiguous()?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = if seq_len <= 1 { + att + } else { + let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; + masked_fill(&att, &mask, f32::NEG_INFINITY)? + }; let att = candle_nn::ops::softmax(&att, D::Minus1)?; // Convert to contiguous as matmul doesn't support strided vs for now. let y = att.matmul(&v.contiguous()?)?; diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index 81828ad5..a75ee87a 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -1,4 +1,3 @@ -#![allow(unused)] /// A fast implementation of mamba for inference only. 
/// This is based on: https://github.com/LaurentMazare/mamba.rs use crate::models::with_tracing::{linear, linear_no_bias, Linear}; @@ -10,10 +9,10 @@ const D_STATE: usize = 16; #[derive(Debug, Clone, serde::Deserialize)] pub struct Config { - d_model: usize, - n_layer: usize, - vocab_size: usize, - pad_vocab_size_multiple: usize, + pub d_model: usize, + pub n_layer: usize, + pub vocab_size: usize, + pub pad_vocab_size_multiple: usize, } impl Config { @@ -38,12 +37,12 @@ pub struct State { } impl State { - pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result { + pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result { let mut hs = Vec::with_capacity(cfg.n_layer); let mut prev_xs = Vec::with_capacity(cfg.n_layer); for _i in 0..cfg.n_layer { - let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?; - let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?; + let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?; + let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?; hs.push(h); prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]); } @@ -121,15 +120,15 @@ impl MambaBlock { // Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf let x_proj = self.x_proj.forward(&proj_for_conv)?; - let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?; + let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?.contiguous()?; let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?; let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?; let delta = delta.apply(&self.dt_proj)?; // softplus let delta = (delta.exp()? + 1.)?.log()?; - let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?; - let d = self.d.to_dtype(candle::DType::F32)?; + let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?; + let d = self.d.to_dtype(delta.dtype())?; // Selective scan part // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t @@ -178,6 +177,7 @@ pub struct Model { layers: Vec, norm_f: RmsNorm, lm_head: Linear, + dtype: DType, } impl Model { @@ -196,6 +196,7 @@ impl Model { layers, norm_f, lm_head, + dtype: vb.dtype(), }) } @@ -208,4 +209,8 @@ impl Model { state.pos += 1; xs.apply(&self.norm_f)?.apply(&self.lm_head) } + + pub fn dtype(&self) -> DType { + self.dtype + } } diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index 35cb30c7..43de594f 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -2,7 +2,7 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D}; use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; // Equivalent to torch.repeat_interleave -fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result { +pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result { let img = img.unsqueeze(dim + 1)?; let mut dims = img.dims().to_vec(); dims[dim + 1] = repeats; @@ -181,6 +181,7 @@ pub mod tokenizers { pub end_of_text: usize, pub offset: usize, pub ranks: HashMap, Rank>, + span: tracing::Span, } impl BPE { @@ -231,6 +232,7 @@ pub mod tokenizers { end_of_text, offset, ranks, + span: tracing::span!(tracing::Level::TRACE, "bpe"), }) } @@ -310,6 +312,7 @@ pub mod tokenizers { } pub fn encode(&self, text: &str) -> Result> { + let _enter = self.span.enter(); let mut bpe_tokens: Vec = Vec::new(); for word in self.re.find_iter(text) { let word = word.map_err(E::wrap)?; @@ 
-426,6 +429,7 @@ pub mod gpt { c_attn: Linear, c_proj: Linear, n_head: usize, + span: tracing::Span, } impl SelfAttention { @@ -444,12 +448,14 @@ pub mod gpt { c_attn, c_proj, n_head: cfg.n_head, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), }) } } impl Module for SelfAttention { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let (b, t, c) = xs.dims3()?; let c_x = xs .apply(&self.c_attn)? @@ -474,11 +480,13 @@ pub mod gpt { Gelu { c_fc: Linear, c_proj: Linear, + span: tracing::Span, }, Swiglu { w1: Linear, w3: Linear, c_proj: Linear, + span: tracing::Span, }, } @@ -489,7 +497,11 @@ pub mod gpt { NonLinearityType::Gelu => { let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?; let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?; - Self::Gelu { c_fc, c_proj } + Self::Gelu { + c_fc, + c_proj, + span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"), + } } NonLinearityType::Swiglu => { let hidden_dim = (2 * hidden_dim) / 3; @@ -502,7 +514,12 @@ pub mod gpt { let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?; let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?; let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?; - Self::Swiglu { w1, w3, c_proj } + Self::Swiglu { + w1, + w3, + c_proj, + span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"), + } } }; Ok(slf) @@ -512,8 +529,17 @@ pub mod gpt { impl Module for MLP { fn forward(&self, xs: &Tensor) -> Result { match self { - Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj), - Self::Swiglu { w1, w3, c_proj } => { + Self::Gelu { c_fc, c_proj, span } => { + let _enter = span.enter(); + xs.apply(c_fc)?.gelu()?.apply(c_proj) + } + Self::Swiglu { + w1, + w3, + c_proj, + span, + } => { + let _enter = span.enter(); let w1 = xs.apply(w1)?; let w3 = xs.apply(w3)?; (w1.silu()? 
* w3)?.apply(c_proj) @@ -528,6 +554,7 @@ pub mod gpt { ln_2: Norm, attn: SelfAttention, mlp: MLP, + span: tracing::Span, } impl Block { @@ -541,12 +568,14 @@ pub mod gpt { ln_2, attn, mlp, + span: tracing::span!(tracing::Level::TRACE, "gpt-block"), }) } } impl Module for Block { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?; let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?; Ok(xs) @@ -563,6 +592,7 @@ pub mod gpt { lm_heads: Vec, cfg: Config, dtype: DType, + span: tracing::Span, } impl Model { @@ -598,6 +628,7 @@ pub mod gpt { lm_heads, cfg, dtype: vb.dtype(), + span: tracing::span!(tracing::Level::TRACE, "gpt"), }) } @@ -606,6 +637,7 @@ pub mod gpt { } pub fn forward(&self, idx: &Tensor) -> Result> { + let _enter = self.span.enter(); let device = idx.device(); let (b, _num_hierarchies, t) = idx.dims3()?; let pos = Tensor::arange(0u32, t as u32, device)?; @@ -664,15 +696,15 @@ pub mod transformer { } } - fn n_local_heads(&self) -> usize { + pub(crate) fn n_local_heads(&self) -> usize { self.n_local_heads.unwrap_or(self.n_head) } - fn head_dim(&self) -> usize { + pub(crate) fn head_dim(&self) -> usize { self.dim / self.n_head } - fn intermediate_size(&self) -> usize { + pub(crate) fn intermediate_size(&self) -> usize { match self.intermediate_size { Some(intermediate_size) => intermediate_size, None => { @@ -689,6 +721,7 @@ pub mod transformer { w1: Linear, w2: Linear, w3: Linear, + span: tracing::Span, } impl FeedForward { @@ -697,12 +730,18 @@ pub mod transformer { let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?; let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?; let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?; - Ok(Self { w1, w2, w3 }) + Ok(Self { + w1, + w2, + w3, + span: tracing::span!(tracing::Level::TRACE, "feed-forward"), + }) } } impl Module for FeedForward { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? 
* xs.apply(&self.w3))?; swiglu.apply(&self.w2) } @@ -718,6 +757,7 @@ pub mod transformer { head_dim: usize, n_head: usize, kv_cache: Option<(Tensor, Tensor)>, + span: tracing::Span, } impl Attention { @@ -736,10 +776,12 @@ pub mod transformer { head_dim, n_head: cfg.n_head, kv_cache: None, + span: tracing::span!(tracing::Level::TRACE, "feed-forward"), }) } fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result { + let _enter = self.span.enter(); let (b_sz, seqlen, _) = xs.dims3()?; let qkv = xs.apply(&self.wqkv)?; @@ -793,6 +835,7 @@ pub mod transformer { feed_forward: FeedForward, ffn_norm: RmsNorm, attention_norm: RmsNorm, + span: tracing::Span, } impl Block { @@ -806,10 +849,12 @@ pub mod transformer { feed_forward, ffn_norm, attention_norm, + span: tracing::span!(tracing::Level::TRACE, "block"), }) } fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result { + let _enter = self.span.enter(); let hs = xs.apply(&self.attention_norm)?; let hs = (xs + self.attention.forward(&hs, pos, mask))?; &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward) @@ -829,6 +874,7 @@ pub mod transformer { norm: RmsNorm, output: Linear, spk_cond_mask: Tensor, + span: tracing::Span, } impl Model { @@ -865,6 +911,7 @@ pub mod transformer { norm, output, spk_cond_mask, + span: tracing::span!(tracing::Level::TRACE, "transformer"), }) } @@ -875,6 +922,7 @@ pub mod transformer { } pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result { + let _enter = self.span.enter(); let (_b_sz, seqlen) = xs.dims2()?; let mask: Vec<_> = (0..seqlen) .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) @@ -905,14 +953,19 @@ pub mod adapters { // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py pub struct TiltedEncodec { end_of_audio_token: u32, + span: tracing::Span, } impl TiltedEncodec { pub fn new(end_of_audio_token: u32) -> Self { - Self { end_of_audio_token } + Self { + end_of_audio_token, + span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"), + } } pub fn decode(&self, tokens: &[Vec]) -> (Vec, Vec>) { + let _enter = self.span.enter(); let mut text_ids = vec![]; let mut extracted_audio_ids = vec![]; let mut min_audio_ids_len = usize::MAX; @@ -941,14 +994,19 @@ pub mod adapters { // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4 pub struct FlattenedInterleavedEncodec2Codebook { end_of_audio_token: u32, + span: tracing::Span, } impl FlattenedInterleavedEncodec2Codebook { pub fn new(end_of_audio_token: u32) -> Self { - Self { end_of_audio_token } + Self { + end_of_audio_token, + span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"), + } } pub fn decode(&self, tokens: &[u32]) -> (Vec, Vec, Vec) { + let _enter = self.span.enter(); let mut text_ids = vec![]; let mut audio_ids1 = vec![]; let mut audio_ids2 = vec![]; diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 2809ae0a..d899c712 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -1,23 +1,28 @@ -use crate::models::with_tracing::{linear_no_bias, Linear}; +use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; 
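// Note on the Config changes below: deriving serde::Deserialize lets the
// struct be read straight from a checkpoint's config.json instead of being
// hard-coded. A minimal sketch of that usage (serde_json as the parser is an
// assumption of this note, not a dependency of this file):
//
//     let raw = std::fs::read_to_string("config.json")?;
//     let cfg: Config = serde_json::from_str(&raw)?;
//
// sliding_window becomes optional because some Mistral-style configs set it
// to null or leave it out, and use_flash_attn gets a serde default since it
// does not appear in upstream config files at all.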
-#[derive(Debug, Clone, PartialEq)] +fn default_use_flash_attn() -> bool { + false +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] pub struct Config { - pub(crate) vocab_size: usize, - pub(crate) hidden_size: usize, - pub(crate) intermediate_size: usize, - pub(crate) num_hidden_layers: usize, - pub(crate) num_attention_heads: usize, - pub(crate) num_key_value_heads: usize, - pub(crate) hidden_act: Activation, - pub(crate) max_position_embeddings: usize, - pub(crate) rms_norm_eps: f64, - pub(crate) rope_theta: f64, - pub(crate) sliding_window: usize, - pub(crate) use_flash_attn: bool, + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub hidden_act: Activation, + pub max_position_embeddings: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub sliding_window: Option, + #[serde(default = "default_use_flash_attn")] + pub use_flash_attn: bool, } impl Config { @@ -34,7 +39,7 @@ impl Config { max_position_embeddings: 32768, rms_norm_eps: 1e-5, rope_theta: 10_000., - sliding_window: 4096, + sliding_window: Some(4096), use_flash_attn, } } @@ -53,7 +58,7 @@ impl Config { max_position_embeddings: 32768, rms_norm_eps: 1e-5, rope_theta: 10_000., - sliding_window: 4096, + sliding_window: Some(4096), use_flash_attn, } } @@ -71,53 +76,26 @@ impl Config { max_position_embeddings: 32768, rms_norm_eps: 1e-5, rope_theta: 10_000., - sliding_window: 4096, + sliding_window: Some(4096), use_flash_attn, } } } -#[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let rope_theta = cfg.rope_theta as f32; let dim = cfg.hidden_size / cfg.num_attention_heads; let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) .step_by(2) - .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; @@ -125,7 +103,6 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -141,10 +118,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? 
+ rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -267,10 +242,12 @@ impl Attention { let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let key_states = key_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let value_states = value_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; @@ -374,7 +351,7 @@ pub struct Model { layers: Vec, norm: RmsNorm, lm_head: Linear, - sliding_window: usize, + sliding_window: Option, device: Device, dtype: DType, } @@ -406,15 +383,14 @@ impl Model { fn prepare_decoder_attention_mask( &self, - b_size: usize, tgt_len: usize, seqlen_offset: usize, ) -> Result { - // Sliding window mask? + let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1); let mask: Vec<_> = (0..tgt_len) .flat_map(|i| { (0..tgt_len).map(move |j| { - if i < j || j + self.sliding_window < i { + if i < j || j + sliding_window < i { f32::NEG_INFINITY } else { 0. @@ -429,16 +405,16 @@ impl Model { } else { mask }; - mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))? .to_dtype(self.dtype) } pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { - let (b_size, seq_len) = input_ids.dims2()?; + let (_b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { None } else { - let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?; Some(mask) }; let mut xs = self.embed_tokens.forward(input_ids)?; diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index f7eb0abe..700829e3 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -126,18 +126,11 @@ impl Module for Embedding { } } -fn get_mask(size: usize, device: &Device) -> Result { +fn get_mask(size: usize, dtype: DType, device: &Device) -> Result { let mask: Vec<_> = (0..size) - .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. 
})) .collect(); - Tensor::from_slice(&mask, (size, size), device) -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) + Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype) } #[derive(Debug, Clone)] @@ -147,7 +140,7 @@ struct RotaryEmbedding { } impl RotaryEmbedding { - fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result { + fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result { let inv_freq: Vec<_> = (0..dim) .step_by(2) .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) @@ -159,8 +152,8 @@ impl RotaryEmbedding { .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, }) } @@ -175,30 +168,14 @@ impl RotaryEmbedding { } let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?; let rotary_dim = rotary_dim * 2; - let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?; + let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?; let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?; - let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?; + let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?; let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?; - let q12 = q_rot.chunk(2, D::Minus1)?; - let k12 = k_rot.chunk(2, D::Minus1)?; - let (q1, q2) = (&q12[0], &q12[1]); - let (k1, k2) = (&k12[0], &k12[1]); - let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; - let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; - let q_rot = Tensor::cat( - &[ - (q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?, - (q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?, - ], - D::Minus1, - )?; - let k_rot = Tensor::cat( - &[ - (k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?, - (k1.broadcast_mul(&s)? 
+ k2.broadcast_mul(&c)?)?, - ], - D::Minus1, - )?; + let c = self.cos.narrow(0, seqlen_offset, seqlen)?; + let s = self.sin.narrow(0, seqlen_offset, seqlen)?; + let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?; + let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?; let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?; let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?; let v = qkv.i((.., .., 2))?; @@ -212,6 +189,7 @@ struct MLP { fc1: Linear, fc2: Linear, act: Activation, + span: tracing::Span, } impl MLP { @@ -223,12 +201,14 @@ impl MLP { fc1, fc2, act: cfg.activation_function, + span: tracing::span!(tracing::Level::TRACE, "mlp"), }) } } impl Module for MLP { fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2) } } @@ -263,9 +243,11 @@ struct MHA { rotary_emb: RotaryEmbedding, kv_cache: Option<(Tensor, Tensor)>, head_dim: usize, - n_head: usize, softmax_scale: f64, span: tracing::Span, + span_rope: tracing::Span, + span_mask: tracing::Span, + span_softmax: tracing::Span, } impl MHA { @@ -274,17 +256,20 @@ impl MHA { let op_size = cfg.n_embd; let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?; let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?; - let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?; + let rotary_emb = + RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?; let softmax_scale = 1f64 / (head_dim as f64).sqrt(); Ok(Self { wqkv, out_proj, head_dim, - n_head: cfg.n_head, kv_cache: None, rotary_emb, softmax_scale, span: tracing::span!(tracing::Level::TRACE, "mha"), + span_rope: tracing::span!(tracing::Level::TRACE, "rope"), + span_mask: tracing::span!(tracing::Level::TRACE, "mask"), + span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"), }) } @@ -300,7 +285,10 @@ impl MHA { Some((prev_k, _)) => prev_k.dim(1)?, }; // In the python implementation, a single tensor is returned with the third axis of size 3. - let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?; + let (q, k, v) = { + let _enter = self.span_rope.enter(); + self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)? + }; let (k, v) = match &self.kv_cache { None => (k, v), Some((prev_k, prev_v)) => { @@ -320,13 +308,15 @@ impl MHA { // scores = scores + causal_mask.to(dtype=scores.dtype) let attn_weights = match mask { None => attn_weights, - Some(mask) => masked_fill( - &attn_weights, - &mask.broadcast_left(b_size * self.n_head)?, - f32::NEG_INFINITY, - )?, + Some(mask) => { + let _enter = self.span_mask.enter(); + attn_weights.broadcast_add(mask)? + } + }; + let attn_weights = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn_weights)? }; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; // output = torch.einsum('bhts,bshd->bthd', attention_drop, v) // attn_weights: b*h,t,s, v: b*h,s,d @@ -430,7 +420,7 @@ impl MixFormerSequentialForCausalLM { let mask = if seq_len <= 1 { None } else { - Some(get_mask(seq_len, xs.device())?) + Some(get_mask(seq_len, xs.dtype(), xs.device())?) }; for block in self.blocks.iter_mut() { xs = block.forward(&xs, mask.as_ref())? 
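The mask handling above switches from a u8 mask plus masked_fill to an additive bias: get_mask now builds 0 / -inf entries, casts them to the activation dtype, and the MHA broadcast_adds them to the attention scores before softmax. A small sketch of the equivalence for a 3-token prompt:

    // old: masked_fill(scores, mask[j > i], -inf)
    // new: scores + [[0, -inf, -inf],
    //                [0,    0, -inf],
    //                [0,    0,    0]]
    // Adding -inf ahead of softmax zeroes out the same future positions, but
    // skips the where_cond kernel and works directly on f16/bf16 activations.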
@@ -438,6 +428,30 @@ impl MixFormerSequentialForCausalLM { xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1) } + pub fn forward_with_img( + &mut self, + bos_token: &Tensor, + xs: &Tensor, + img_embeds: &Tensor, + ) -> Result { + let _enter = self.span.enter(); + let xs = xs.apply(&self.embedding)?; + let bos_token = bos_token.apply(&self.embedding)?; + // Python implementation sequence order is + // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56 + let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?; + let (_b_size, seq_len, _embds) = xs.dims3()?; + let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?); + for block in self.blocks.iter_mut() { + xs = block.forward(&xs, mask.as_ref())? + } + let xs = xs + .narrow(1, seq_len - 1, 1)? + .apply(&self.head)? + .squeeze(1)?; + Ok(xs) + } + pub fn clear_kv_cache(&mut self) { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs index ede74d3f..f69c68e3 100644 --- a/candle-transformers/src/models/mixtral.rs +++ b/candle-transformers/src/models/mixtral.rs @@ -1,4 +1,4 @@ -use crate::models::with_tracing::{linear_no_bias, Linear}; +use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; /// Mixtral Model /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py /// https://mistral.ai/news/mixtral-of-experts/ @@ -48,27 +48,6 @@ impl Config { } } -#[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 66e06e0e..3514e648 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -3,6 +3,7 @@ pub mod bigcode; pub mod blip; pub mod blip_text; pub mod chatglm; +pub mod clip; pub mod convmixer; pub mod convnext; pub mod dinov2; @@ -23,6 +24,7 @@ pub mod mistral; pub mod mixformer; pub mod mixtral; pub mod mobileone; +pub mod moondream; pub mod mpt; pub mod persimmon; pub mod phi; @@ -30,14 +32,17 @@ pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; pub mod quantized_llama2_c; +pub mod quantized_metavoice; pub mod quantized_mistral; pub mod quantized_mixformer; +pub mod quantized_moondream; pub mod quantized_mpt; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; pub mod quantized_stable_lm; pub mod quantized_t5; pub mod qwen2; +pub mod qwen2_moe; pub mod repvgg; pub mod resnet; pub mod rwkv_v5; diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs new file mode 100644 index 00000000..7ad8c921 --- /dev/null +++ b/candle-transformers/src/models/moondream.rs @@ -0,0 +1,327 @@ +use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel}; +use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear}; +use candle::{IndexOp, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + 
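+// Rough map of how the pieces below fit together: pixels go through the
+// VisionEncoder (a linear patch embedding over 14x14 patches, a stack of ViT
+// blocks, then a projection MLP into the text-model width) and the resulting
+// image embeddings are spliced between the BOS token and the prompt by the
+// text model's forward_with_img. Config accordingly just pairs a PhiConfig
+// with a VisionConfig.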
+pub struct Config { + pub phi_config: PhiConfig, + pub vision_config: VisionConfig, +} + +impl Config { + pub fn v2() -> Self { + Self { + phi_config: PhiConfig::v1_5(), + vision_config: VisionConfig::v2(), + } + } +} + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?; + candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v) +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct VisionConfig { + pub(crate) image_embedding_dim: usize, + pub(crate) model_dim: usize, + pub(crate) hidden_dim: usize, + pub(crate) hidden_features: usize, + pub(crate) embed_len: usize, + pub(crate) embed_dim: usize, + pub(crate) num_blocks: usize, + pub(crate) num_heads: usize, + pub(crate) act: candle_nn::Activation, +} + +impl VisionConfig { + pub fn v2() -> Self { + Self { + image_embedding_dim: 1152, + model_dim: 2048, + hidden_dim: 2048 * 4, + hidden_features: 4304, + embed_len: 729, + embed_dim: 1152, + num_blocks: 27, + num_heads: 16, + act: candle_nn::Activation::GeluPytorchTanh, + } + } +} + +#[derive(Debug, Clone)] +struct LinearPatchEmbedding { + linear: Linear, +} + +impl LinearPatchEmbedding { + fn new(vb: VarBuilder) -> Result { + let linear = linear_b(588, 1152, true, vb.pp("linear"))?; + Ok(Self { linear }) + } +} + +impl Module for LinearPatchEmbedding { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.linear) + } +} + +#[derive(Debug, Clone)] +struct Attention { + num_heads: usize, + head_dim: usize, + qkv: Linear, + proj: Linear, + span: tracing::Span, +} + +impl Attention { + pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result { + let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?; + let proj = linear_b(dim, dim, true, vb.pp("proj"))?; + Ok(Self { + num_heads, + head_dim: dim / num_heads, + qkv, + proj, + span: tracing::span!(tracing::Level::TRACE, "vit-attn"), + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let (b, n, c) = xs.dims3()?; + let qkv = xs + .apply(&self.qkv)? + .reshape((b, n, 3, self.num_heads, self.head_dim))? + .permute((2, 0, 3, 1, 4))?; + let (q, k, v) = ( + qkv.i(0)?.contiguous()?, + qkv.i(1)?.contiguous()?, + qkv.i(2)?.contiguous()?, + ); + scaled_dot_product_attention(&q, &k, &v)? + .transpose(1, 2)? + .reshape((b, n, c))? 
+ .apply(&self.proj) + } +} + +#[derive(Debug, Clone)] +struct VitBlock { + attn: Attention, + mlp: Mlp, + norm1: LayerNorm, + norm2: LayerNorm, + span: tracing::Span, +} + +impl VitBlock { + fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result { + let attn = Attention::new(vb.pp("attn"), dim, num_heads)?; + let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?; + let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + Ok(Self { + attn, + mlp, + norm1, + norm2, + span: tracing::span!(tracing::Level::TRACE, "vit-block"), + }) + } +} + +impl Module for VitBlock { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let ys = xs.apply(&self.norm1)?.apply(&self.attn)?; + let xs = (xs + &ys)?; + let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?; + let xs = (&xs + &ys)?; + Ok(xs) + } +} + +#[derive(Debug, Clone)] +struct VisionTransformer { + patch_embed: LinearPatchEmbedding, + pos_embed: Tensor, + blocks: Vec, + norm: LayerNorm, + span: tracing::Span, +} + +impl VisionTransformer { + fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?; + let pos_embed = vb.get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?; + let blocks = (0..cfg.num_blocks) + .map(|i| { + VitBlock::new( + vb.pp(&format!("blocks.{}", i)), + cfg.embed_dim, + cfg.num_heads, + cfg, + ) + }) + .collect::>()?; + let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?; + Ok(Self { + patch_embed, + pos_embed, + blocks, + norm, + span: tracing::span!(tracing::Level::TRACE, "vit"), + }) + } +} + +impl Module for VisionTransformer { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let mut xs = (&xs.apply(&self.patch_embed)? 
+ &self.pos_embed)?; + for block in self.blocks.iter() { + xs = xs.apply(block)?; + } + xs.apply(&self.norm) + } +} + +#[derive(Debug, Clone)] +pub struct Encoder { + model: VisionTransformer, +} + +impl Encoder { + fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?; + Ok(Self { model }) + } +} + +impl Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.model) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + fc1: Linear, + act: candle_nn::Activation, + fc2: Linear, + span: tracing::Span, +} + +impl Mlp { + fn new( + vb: VarBuilder, + in_features: usize, + hidden_features: usize, + out_features: usize, + act: candle_nn::Activation, + ) -> Result { + let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?; + let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?; + Ok(Self { + fc1, + act, + fc2, + span: tracing::span!(tracing::Level::TRACE, "mlp"), + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2) + } +} + +#[derive(Debug, Clone)] +struct VisionProjection { + mlp: Mlp, +} + +impl VisionProjection { + fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let mlp = Mlp::new( + vb.pp("mlp"), + cfg.image_embedding_dim, + cfg.hidden_dim, + cfg.model_dim, + cfg.act, + )?; + Ok(Self { mlp }) + } +} + +impl Module for VisionProjection { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.mlp) + } +} + +#[derive(Debug, Clone)] +pub struct VisionEncoder { + encoder: Encoder, + projection: VisionProjection, +} + +impl VisionEncoder { + pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let encoder = Encoder::new(cfg, vb.pp("encoder"))?; + let projection = VisionProjection::new(cfg, vb.pp("projection"))?; + Ok(Self { + encoder, + projection, + }) + } +} + +impl Module for VisionEncoder { + fn forward(&self, xs: &Tensor) -> Result { + let (b, c, hp1, wp2) = xs.dims4()?; + let (p1, p2) = (14, 14); + let h = hp1 / p1; + let w = wp2 / p2; + xs.reshape((b, c, h, p1, h, p2))? + .permute((0, 2, 4, 1, 3, 5))? + .reshape((b, h * w, c * p1 * p2))? + .apply(&self.encoder)? 
+ .apply(&self.projection) + } +} + +pub struct Model { + pub text_model: PhiModel, + pub vision_encoder: VisionEncoder, +} + +impl Model { + pub fn new(config: &Config, vb: VarBuilder) -> Result { + let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?; + let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?; + Ok(Self { + text_model, + vision_encoder, + }) + } + + pub fn vision_encoder(&self) -> &VisionEncoder { + &self.vision_encoder + } + + pub fn text_model(&mut self) -> &mut PhiModel { + &mut self.text_model + } +} diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 94324149..e1519b2d 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -1,32 +1,13 @@ use std::collections::HashMap; +use crate::quantized_nn::RmsNorm; use candle::quantized::QTensor; use candle::quantized::{ggml_file, gguf_file}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module}; pub const MAX_SEQ_LEN: usize = 4096; -#[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::LayerNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(scale: QTensor, eps: f32) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&scale.device())?; - let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - // QMatMul wrapper adding some tracing. #[derive(Debug, Clone)] struct QMatMul { @@ -173,34 +154,20 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result Result { let _enter = self.span_rot.enter(); - let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; - let cos = self - .cos - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let sin = self - .sin - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - // This mimics the llama.cpp behavior. - // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 - // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. - // The resulting y0 and y1 are also interleaved with: - // y0 = x0*cos - x1*sin - // y1 = x0*sin + x1*cos - let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; - let x0 = x.narrow(D::Minus1, 0, 1)?; - let x1 = x.narrow(D::Minus1, 1, 1)?; - let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; - let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; - let rope = Tensor::cat(&[y0, y1], D::Minus1)?; - let rope = rope.flatten_from(D::Minus2)?; - Ok(rope) + let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + // The call to contiguous below is only necessary when processing the prompt. + // When the seq_len is 1 in the inference loop, this is a no-op. 
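+ // rope_i is the interleaved ("llama.cpp style") rotation that the removed
+ // code spelled out by hand: head_dim is read as consecutive (x0, x1) pairs
+ // and each pair is rotated as
+ //   y0 = x0*cos - x1*sin
+ //   y1 = x0*sin + x1*cos
+ // with cos/sin rows of shape (seq_len, head_dim / 2) selected via index_pos.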
+ candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin) } - fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result { + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result { let _enter = self.span_attn.enter(); let (b_sz, seq_len, n_embd) = x.dims3()?; let q = self.attention_wq.forward(x)?; @@ -215,7 +182,11 @@ impl LayerWeights { .transpose(1, 2)?; let v = v .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + // This call to contiguous ensures that the fast kernel can be called below. It's + // actually a no-op except when processing the initial prompt so has no significant + // impact on performance. + .contiguous()?; let q = self.apply_rotary_emb(&q, index_pos)?; let k = self.apply_rotary_emb(&k, index_pos)?; @@ -226,8 +197,8 @@ impl LayerWeights { if index_pos == 0 { (k, v) } else { - let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; - let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; + let k = Tensor::cat(&[k_cache, &k], 2)?; + let v = Tensor::cat(&[v_cache, &v], 2)?; (k, v) } } @@ -239,8 +210,13 @@ impl LayerWeights { let v = self.repeat_kv(v)?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = mask.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, &self.neg_inf)?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; let att = candle_nn::ops::softmax_last_dim(&att)?; // Convert to contiguous as matmul doesn't support strided vs for now. let y = att.matmul(&v.contiguous()?)?; @@ -301,7 +277,7 @@ impl ModelWeights { let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; - let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; + let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?; let output = ct.remove("output.weight")?; let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); for layer_idx in 0..ct.hparams.n_layer { @@ -330,9 +306,9 @@ impl ModelWeights { attention_wk: QMatMul::from_qtensor(attention_wk)?, attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, - attention_norm: RmsNorm::new(attention_norm, 1e-5)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?, mlp_or_moe, - ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?, n_head: ct.hparams.n_head as usize, n_kv_head: ct.hparams.n_head as usize / gqa, head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, @@ -381,7 +357,7 @@ impl ModelWeights { let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. - let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? 
as f64; let rope_freq_base = md_get("llama.rope.freq_base") .and_then(|m| m.to_f32()) @@ -391,7 +367,7 @@ impl ModelWeights { let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let norm = RmsNorm::new( + let norm = RmsNorm::from_qtensor( ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps, )?; @@ -450,9 +426,9 @@ impl ModelWeights { attention_wk: QMatMul::from_qtensor(attention_wk)?, attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, - attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?, mlp_or_moe, - ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?, n_head: head_count, n_kv_head: head_count_kv, head_dim: embedding_length / head_count, @@ -493,14 +469,18 @@ impl ModelWeights { pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result { let (_b_sz, seq_len) = x.dims2()?; - let mask = self.mask(seq_len, x.device())?; + let mask = if seq_len == 1 { + None + } else { + Some(self.mask(seq_len, x.device())?) + }; let _enter = self.span.enter(); let mut layer_in = self.tok_embeddings.forward(x)?; for layer in self.layers.iter_mut() { let x = layer_in; let residual = &x; let x = layer.attention_norm.forward(&x)?; - let attn = layer.forward_attn(&x, &mask, index_pos)?; + let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?; let x = (attn + residual)?; // MLP diff --git a/candle-transformers/src/models/quantized_llama2_c.rs b/candle-transformers/src/models/quantized_llama2_c.rs index b43ca9ff..cbb8aad8 100644 --- a/candle-transformers/src/models/quantized_llama2_c.rs +++ b/candle-transformers/src/models/quantized_llama2_c.rs @@ -71,8 +71,12 @@ impl CausalSelfAttention { let v = v.transpose(1, 2)?.contiguous()?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = if seq_len <= 1 { + att + } else { + let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; + masked_fill(&att, &mask, f32::NEG_INFINITY)? + }; let att = candle_nn::ops::softmax(&att, D::Minus1)?; // Convert to contiguous as matmul doesn't support strided vs for now. 
let y = att.matmul(&v.contiguous()?)?; diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs new file mode 100644 index 00000000..947ab750 --- /dev/null +++ b/candle-transformers/src/models/quantized_metavoice.rs @@ -0,0 +1,243 @@ +use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm}; +pub use crate::quantized_var_builder::VarBuilder; + +use crate::models::metavoice::repeat_interleave; +use candle::{Module, Result, Tensor, D}; + +pub mod transformer { + use super::*; + + type Config = crate::models::metavoice::transformer::Config; + + #[derive(Debug, Clone)] + struct FeedForward { + w1: Linear, + w2: Linear, + w3: Linear, + span: tracing::Span, + } + + impl FeedForward { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let i_size = cfg.intermediate_size(); + let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?; + let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?; + let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?; + Ok(Self { + w1, + w2, + w3, + span: tracing::span!(tracing::Level::TRACE, "feed-forward"), + }) + } + } + + impl Module for FeedForward { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?; + swiglu.apply(&self.w2) + } + } + + #[derive(Debug, Clone)] + struct Attention { + wqkv: Linear, + wo: Linear, + dim: usize, + kv_size: usize, + n_local_heads: usize, + head_dim: usize, + n_head: usize, + kv_cache: Option<(Tensor, Tensor)>, + span: tracing::Span, + } + + impl Attention { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let n_local_heads = cfg.n_local_heads(); + let head_dim = cfg.head_dim(); + let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim; + let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?; + let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?; + Ok(Self { + wqkv, + wo, + dim: cfg.dim, + kv_size: n_local_heads * head_dim, + n_local_heads, + head_dim, + n_head: cfg.n_head, + kv_cache: None, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let (b_sz, seqlen, _) = xs.dims3()?; + + let qkv = xs.apply(&self.wqkv)?; + let q = qkv.narrow(D::Minus1, 0, self.dim)?; + let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?; + let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?; + let q = q + .reshape((b_sz, seqlen, self.n_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))? + .transpose(1, 2)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((prev_k, prev_v)) => { + let k = Tensor::cat(&[prev_k, &k], 2)?; + let v = Tensor::cat(&[prev_v, &v], 2)?; + (k, v) + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?; + let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?; + + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + + let attn_weights = attn_weights.broadcast_add(mask)?; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&v)?; + attn_output + .transpose(1, 2)? 
+ .reshape((b_sz, seqlen, self.dim))? + .apply(&self.wo) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } + } + + #[derive(Debug, Clone)] + struct Block { + attention: Attention, + feed_forward: FeedForward, + ffn_norm: RmsNorm, + attention_norm: RmsNorm, + span: tracing::Span, + } + + impl Block { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let attention = Attention::new(cfg, vb.pp("attention"))?; + let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?; + let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?; + let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?; + Ok(Self { + attention, + feed_forward, + ffn_norm, + attention_norm, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let hs = xs.apply(&self.attention_norm)?; + let hs = (xs + self.attention.forward(&hs, pos, mask))?; + &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward) + } + + fn clear_kv_cache(&mut self) { + self.attention.clear_kv_cache() + } + } + + #[derive(Debug, Clone)] + pub struct Model { + tok_embeddings: Embedding, + pos_embeddings: Embedding, + speaker_cond_pos: Linear, + layers: Vec, + norm: RmsNorm, + output: Linear, + spk_cond_mask: Tensor, + span: tracing::Span, + } + + impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?; + let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?; + let speaker_cond_pos = linear_b( + cfg.speaker_emb_dim, + cfg.dim, + false, + vb.pp("speaker_cond_pos"), + )?; + let mut layers = Vec::with_capacity(cfg.n_layer); + let vb_l = vb.pp("layers"); + for layer_idx in 0..cfg.n_layer { + let layer = Block::new(cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?; + let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?; + let spk_cond_mask = Tensor::cat( + &[ + Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?, + Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?, + ], + 0, + )?; + Ok(Self { + tok_embeddings, + pos_embeddings, + speaker_cond_pos, + layers, + norm, + output, + spk_cond_mask, + span: tracing::span!(tracing::Level::TRACE, "qtransformer"), + }) + } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } + + pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result { + let _enter = self.span.enter(); + let (_b_sz, seqlen) = xs.dims2()?; + let mask: Vec<_> = (0..seqlen) + .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(); + let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?; + let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?; + let tok_embeddings = xs.apply(&self.tok_embeddings)?; + let pos_embeddings = input_pos.apply(&self.pos_embeddings)?; + let mut xs = tok_embeddings + .broadcast_add(&pos_embeddings)? + .broadcast_add( + &spk_emb + .apply(&self.speaker_cond_pos)? + .broadcast_mul(&self.spk_cond_mask)?, + )?; + let mask = mask.to_dtype(xs.dtype())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, pos, &mask)? + } + xs.narrow(1, seqlen - 1, 1)? + .contiguous()? + .apply(&self.norm)? 
+ .apply(&self.output) + } + } +} diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index f2cb3b27..e37785de 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -12,20 +12,14 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(cfg: &Config, dev: &Device) -> Result { + let rope_theta = cfg.rope_theta as f32; let dim = cfg.hidden_size / cfg.num_attention_heads; let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) .step_by(2) - .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; @@ -33,7 +27,6 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -49,10 +42,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -157,10 +148,12 @@ impl Attention { let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let key_states = key_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let value_states = value_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; @@ -257,7 +250,7 @@ pub struct Model { layers: Vec, norm: RmsNorm, lm_head: Linear, - sliding_window: usize, + sliding_window: Option, device: Device, } @@ -287,15 +280,14 @@ impl Model { fn prepare_decoder_attention_mask( &self, - b_size: usize, tgt_len: usize, seqlen_offset: usize, ) -> Result { - // Sliding window mask? + let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1); let mask: Vec<_> = (0..tgt_len) .flat_map(|i| { (0..tgt_len).map(move |j| { - if i < j || j + self.sliding_window < i { + if i < j || j + sliding_window < i { f32::NEG_INFINITY } else { 0. @@ -310,16 +302,16 @@ impl Model { } else { mask }; - mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))? 
.to_dtype(DType::F32) } pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { - let (b_size, seq_len) = input_ids.dims2()?; + let (_b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { None } else { - let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?; Some(mask) }; let mut xs = self.embed_tokens.forward(input_ids)?; @@ -327,6 +319,7 @@ impl Model { xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? } xs.narrow(1, seq_len - 1, 1)? + .contiguous()? .apply(&self.norm)? .apply(&self.lm_head) } diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index 882f4cf8..fa72672a 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -337,6 +337,30 @@ impl MixFormerSequentialForCausalLM { xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1) } + pub fn forward_with_img( + &mut self, + bos_token: &Tensor, + xs: &Tensor, + img_embeds: &Tensor, + ) -> Result { + let _enter = self.span.enter(); + let xs = xs.apply(&self.embedding)?; + let bos_token = bos_token.apply(&self.embedding)?; + // Python implementation sequence order is + // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56 + let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?; + let (_b_size, seq_len, _embds) = xs.dims3()?; + let mask = Some(get_mask(seq_len, xs.device())?); + for block in self.blocks.iter_mut() { + xs = block.forward(&xs, mask.as_ref())? + } + let xs = xs + .narrow(1, seq_len - 1, 1)? + .apply(&self.head)? + .squeeze(1)?; + Ok(xs) + } + pub fn clear_kv_cache(&mut self) { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } diff --git a/candle-transformers/src/models/quantized_moondream.rs b/candle-transformers/src/models/quantized_moondream.rs new file mode 100644 index 00000000..1b125d93 --- /dev/null +++ b/candle-transformers/src/models/quantized_moondream.rs @@ -0,0 +1,271 @@ +use crate::models::moondream::{Config, VisionConfig}; +use crate::models::quantized_mixformer::MixFormerSequentialForCausalLM as PhiModel; +use crate::quantized_nn::{layer_norm, linear_b, Linear}; +use crate::quantized_var_builder::VarBuilder; +use candle::{IndexOp, Module, Result, Tensor, D}; + +fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result { + let dim = q.dim(D::Minus1)?; + let scale_factor = 1.0 / (dim as f64).sqrt(); + let attn_weights = (q.matmul(&k.t()?)? 
* scale_factor)?; + candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v) +} + +#[derive(Debug, Clone)] +struct LinearPatchEmbedding { + linear: Linear, +} + +impl LinearPatchEmbedding { + fn new(vb: VarBuilder) -> Result { + let linear = linear_b(588, 1152, true, vb.pp("linear"))?; + Ok(Self { linear }) + } +} + +impl Module for LinearPatchEmbedding { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.linear) + } +} + +#[derive(Debug, Clone)] +struct Attention { + num_heads: usize, + head_dim: usize, + qkv: Linear, + proj: Linear, +} + +impl Attention { + pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result { + let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?; + let proj = linear_b(dim, dim, true, vb.pp("proj"))?; + Ok(Self { + num_heads, + head_dim: dim / num_heads, + qkv, + proj, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result { + let (b, n, c) = xs.dims3()?; + let qkv = xs + .apply(&self.qkv)? + .reshape((b, n, 3, self.num_heads, self.head_dim))? + .permute((2, 0, 3, 1, 4))?; + let (q, k, v) = ( + qkv.i(0)?.contiguous()?, + qkv.i(1)?.contiguous()?, + qkv.i(2)?.contiguous()?, + ); + scaled_dot_product_attention(&q, &k, &v)? + .transpose(1, 2)? + .reshape((b, n, c))? + .apply(&self.proj) + } +} + +#[derive(Debug, Clone)] +struct VitBlock { + attn: Attention, + mlp: Mlp, + norm1: candle_nn::LayerNorm, + norm2: candle_nn::LayerNorm, +} + +impl VitBlock { + fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result { + let attn = Attention::new(vb.pp("attn"), dim, num_heads)?; + let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?; + let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + Ok(Self { + attn, + mlp, + norm1, + norm2, + }) + } +} + +impl Module for VitBlock { + fn forward(&self, xs: &Tensor) -> Result { + let ys = xs.apply(&self.norm1)?.apply(&self.attn)?; + let xs = (xs + &ys)?; + let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?; + let xs = (&xs + &ys)?; + Ok(xs) + } +} + +#[derive(Debug, Clone)] +struct VisionTransformer { + patch_embed: LinearPatchEmbedding, + pos_embed: Tensor, + blocks: Vec, + norm: candle_nn::LayerNorm, +} + +impl VisionTransformer { + fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?; + let pos_embed = vb + .get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")? + .dequantize(vb.device())?; + let blocks = (0..cfg.num_blocks) + .map(|i| { + VitBlock::new( + vb.pp(format!("blocks.{}", i)), + cfg.embed_dim, + cfg.num_heads, + cfg, + ) + }) + .collect::>()?; + let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?; + Ok(Self { + patch_embed, + pos_embed, + blocks, + norm, + }) + } +} + +impl Module for VisionTransformer { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = (&xs.apply(&self.patch_embed)? 
+ &self.pos_embed)?; + for block in self.blocks.iter() { + xs = xs.apply(block)?; + } + xs.apply(&self.norm) + } +} + +#[derive(Debug, Clone)] +pub struct Encoder { + model: VisionTransformer, +} + +impl Encoder { + fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?; + Ok(Self { model }) + } +} + +impl Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.model) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + fc1: Linear, + act: candle_nn::Activation, + fc2: Linear, +} + +impl Mlp { + fn new( + vb: VarBuilder, + in_features: usize, + hidden_features: usize, + out_features: usize, + act: candle_nn::Activation, + ) -> Result { + let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?; + let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?; + Ok(Self { fc1, act, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2) + } +} + +#[derive(Debug, Clone)] +struct VisionProjection { + mlp: Mlp, +} + +impl VisionProjection { + fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let mlp = Mlp::new( + vb.pp("mlp"), + cfg.image_embedding_dim, + cfg.hidden_dim, + cfg.model_dim, + cfg.act, + )?; + Ok(Self { mlp }) + } +} + +impl Module for VisionProjection { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.mlp) + } +} + +#[derive(Debug, Clone)] +pub struct VisionEncoder { + encoder: Encoder, + projection: VisionProjection, +} + +impl VisionEncoder { + pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { + let encoder = Encoder::new(cfg, vb.pp("encoder"))?; + let projection = VisionProjection::new(cfg, vb.pp("projection"))?; + Ok(Self { + encoder, + projection, + }) + } +} + +impl Module for VisionEncoder { + fn forward(&self, xs: &Tensor) -> Result { + let (b, c, hp1, wp2) = xs.dims4()?; + let (p1, p2) = (14, 14); + let h = hp1 / p1; + let w = wp2 / p2; + xs.reshape((b, c, h, p1, h, p2))? + .permute((0, 2, 4, 1, 3, 5))? + .reshape((b, h * w, c * p1 * p2))? + .apply(&self.encoder)? 
+ .apply(&self.projection) + } +} + +pub struct Model { + pub text_model: PhiModel, + pub vision_encoder: VisionEncoder, +} + +impl Model { + pub fn new(config: &Config, vb: VarBuilder) -> Result { + let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?; + let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?; + Ok(Self { + text_model, + vision_encoder, + }) + } + + pub fn vision_encoder(&self) -> &VisionEncoder { + &self.vision_encoder + } + + pub fn text_model(&mut self) -> &mut PhiModel { + &mut self.text_model + } +} diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 26431b7d..9a12eba5 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -1,4 +1,4 @@ -use crate::models::with_tracing::{linear, linear_no_bias, Linear}; +use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -21,27 +21,6 @@ pub struct Config { pub hidden_act: Activation, } -#[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs new file mode 100644 index 00000000..d6566e90 --- /dev/null +++ b/candle-transformers/src/models/qwen2_moe.rs @@ -0,0 +1,488 @@ +use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub sliding_window: usize, + pub max_window_layers: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub use_sliding_window: bool, + pub hidden_act: Activation, + pub decoder_sparse_step: usize, + pub moe_intermediate_size: usize, + pub shared_expert_intermediate_size: usize, + pub num_experts_per_tok: usize, + pub num_experts: usize, + pub norm_topk_prob: bool, +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = 
Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl MLP { + fn new(intermediate_sz: usize, cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; + let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; + let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + kv_cache: None, + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result { + let n_rep = self.num_kv_groups; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? 
+ .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?.contiguous()?; + let value_states = self.repeat_kv(value_states)?.contiguous()?; + + let attn_output = { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +// https://github.com/huggingface/transformers/blob/536ea2aca234fb48c5c69769431d643b0d93b233/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L800 +#[derive(Debug, Clone)] +struct SparseMoeBlock { + gate: Linear, + experts: Vec, + shared_expert: MLP, + shared_expert_gate: Linear, + norm_topk_prob: bool, + num_experts_per_tok: usize, +} + +impl SparseMoeBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?; + let mut experts = Vec::with_capacity(cfg.num_experts); + let vb_e = vb.pp("experts"); + for idx in 0..cfg.num_experts { + let expert = MLP::new(cfg.moe_intermediate_size, cfg, vb_e.pp(idx))?; + experts.push(expert) + } + let shared_expert = MLP::new( + cfg.shared_expert_intermediate_size, + cfg, + vb.pp("shared_expert"), + )?; + let shared_expert_gate = linear_no_bias(cfg.hidden_size, 1, vb.pp("shared_expert_gate"))?; + Ok(Self { + gate, + experts, + shared_expert, + shared_expert_gate, + norm_topk_prob: cfg.norm_topk_prob, + num_experts_per_tok: cfg.num_experts_per_tok, + }) + } +} + +impl Module for SparseMoeBlock { + fn forward(&self, xs: &Tensor) -> Result { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = xs.apply(&self.gate)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. 
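+ // Routing is done row by row: for each token the expert scores are sorted in
+ // decreasing order and only the top `num_experts_per_tok` experts are kept.
+ // `top_x[e]` collects the row indices routed to expert `e`, while
+ // `selected_experts[e]` collects the matching routing weights (renormalized to
+ // sum to 1 when `norm_topk_prob` is set). Those weights later scale each
+ // expert's output before it is accumulated into `ys` via `index_add`.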
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. + let mut top_x = vec![vec![]; self.experts.len()]; + let mut selected_experts = vec![vec![]; self.experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = if self.norm_topk_prob { + rw[expert_idx] / sum_routing_weights + } else { + rw[expert_idx] + }; + selected_experts[expert_idx].push(routing_weight) + } + } + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in self.experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_experts = + Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))? + .to_dtype(xs.dtype())?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + let shared_expert_output = xs.apply(&self.shared_expert)?; + let shared_expert_output = shared_expert_output.broadcast_mul(&candle_nn::ops::sigmoid( + &xs.apply(&self.shared_expert_gate)?, + )?)?; + let ys = (ys + shared_expert_output)?; + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } +} + +#[derive(Debug, Clone)] +enum MlpOrMoeBlock { + Mlp(MLP), + MoeBlock(SparseMoeBlock), +} + +impl Module for MlpOrMoeBlock { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::MoeBlock(m) => m.forward(xs), + Self::Mlp(m) => m.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MlpOrMoeBlock, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new( + layer_idx: usize, + rotary_emb: Arc, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0 { + MlpOrMoeBlock::MoeBlock(SparseMoeBlock::new(cfg, vb.pp("mlp"))?) + } else { + MlpOrMoeBlock::Mlp(MLP::new(cfg.intermediate_size, cfg, vb.pp("mlp"))?) 
+ }; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + lm_head: Linear, + sliding_window: usize, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(layer_idx, rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.sliding_window, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + // Sliding window mask? + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + self.sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? 
+ .apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } +} diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs index 16e8a4e8..258fb5aa 100644 --- a/candle-transformers/src/models/segment_anything/prompt_encoder.rs +++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs @@ -218,7 +218,8 @@ impl PromptEncoder { (Some(se_points), None) => se_points, (None, Some(se_boxes)) => se_boxes, (None, None) => { - Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)? + let dev = self.no_mask_embed.embeddings().device(); + Tensor::zeros((1, 0, self.embed_dim), DType::F32, dev)? } }; diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 07ce0fe4..4d5a7c47 100644 --- a/candle-transformers/src/models/stable_diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -533,7 +533,9 @@ impl Module for AttentionBlock { let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?; let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?; - let xs = attention_probs.matmul(&value_states.contiguous()?)?; + // TODO: revert the call to force_contiguous once the three matmul kernels have been + // adapted to handle layout with some dims set to 1. + let xs = attention_probs.matmul(&value_states)?; let xs = xs.to_dtype(in_dtype)?; let xs = xs.transpose(1, 2)?.contiguous()?; let xs = xs.flatten_from(D::Minus2)?; diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs index 0f0441e0..94f8ab86 100644 --- a/candle-transformers/src/models/stable_diffusion/schedulers.rs +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs @@ -5,7 +5,7 @@ //! inference speed and quality. 
use candle::{Result, Tensor}; -pub trait SchedulerConfig: std::fmt::Debug { +pub trait SchedulerConfig: std::fmt::Debug + Send + Sync { fn build(&self, inference_steps: usize) -> Result>; } diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 5dc44cb5..f4b5b4b0 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -70,26 +70,26 @@ where #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { - vocab_size: usize, - d_model: usize, - d_kv: usize, - d_ff: usize, - num_layers: usize, - num_decoder_layers: Option, - num_heads: usize, - relative_attention_num_buckets: usize, + pub vocab_size: usize, + pub d_model: usize, + pub d_kv: usize, + pub d_ff: usize, + pub num_layers: usize, + pub num_decoder_layers: Option, + pub num_heads: usize, + pub relative_attention_num_buckets: usize, #[serde(default = "default_relative_attention_max_distance")] - relative_attention_max_distance: usize, - dropout_rate: f64, - layer_norm_epsilon: f64, - initializer_factor: f64, + pub relative_attention_max_distance: usize, + pub dropout_rate: f64, + pub layer_norm_epsilon: f64, + pub initializer_factor: f64, #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")] - feed_forward_proj: ActivationWithOptionalGating, + pub feed_forward_proj: ActivationWithOptionalGating, #[serde(default = "default_tie_word_embeddings")] - tie_word_embeddings: bool, + pub tie_word_embeddings: bool, #[serde(default = "default_is_decoder")] - is_decoder: bool, - is_encoder_decoder: bool, + pub is_decoder: bool, + pub is_encoder_decoder: bool, #[serde(default = "default_use_cache")] pub use_cache: bool, pub pad_token_id: usize, diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 383ae71c..1c34bfa2 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -116,6 +116,12 @@ impl QMatMul { let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); Ok(Self { inner, span }) } + + pub fn from_weights(ws: std::sync::Arc) -> Result { + let inner = candle::quantized::QMatMul::from_arc(ws)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } } impl Module for QMatMul { @@ -161,3 +167,24 @@ pub fn layer_norm>( let span = tracing::span!(tracing::Level::TRACE, "layer-norm"); Ok(LayerNorm { inner, span }) } + +#[derive(Debug, Clone)] +pub struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let inner = candle_nn::rms_norm(size, eps, vb)?; + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index 14b6feeb..99d9de1b 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -1,5 +1,5 @@ /// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py -use crate::models::with_tracing::{linear_no_bias, Linear}; +use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -50,27 +50,6 @@ impl Config { } } -#[derive(Debug, Clone)] -struct RmsNorm 
{ - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs index 99e8d45b..9298b80e 100644 --- a/candle-transformers/src/quantized_nn.rs +++ b/candle-transformers/src/quantized_nn.rs @@ -1,5 +1,6 @@ use crate::models::with_tracing::QMatMul; use crate::quantized_var_builder::VarBuilder; +use candle::quantized::QTensor; use candle::{Module, Result, Tensor}; #[derive(Debug, Clone)] @@ -35,6 +36,11 @@ pub struct Linear { } impl Linear { + pub fn from_arc(weight: std::sync::Arc, bias: Option) -> Result { + let weight = QMatMul::from_weights(weight)?; + Ok(Self { weight, bias }) + } + pub fn from_weights(weight: QMatMul, bias: Option) -> Self { Self { weight, bias } } @@ -50,6 +56,16 @@ impl Module for Linear { } } +pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result { + let bias = if bias { + Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?) + } else { + None + }; + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { weight, bias }) +} + pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result { let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; let weight = QMatMul::new(in_dim, out_dim, vb)?; @@ -77,7 +93,8 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result Result { let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let inner = candle_nn::RmsNorm::new(weight, eps); - Ok(Self { inner, span }) + Ok(Self { weight, eps, span }) + } + + pub fn from_qtensor(weight: QTensor, eps: f64) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let weight = weight.dequantize(&weight.device())?; + Ok(Self { weight, eps, span }) } } impl Module for RmsNorm { fn forward(&self, x: &Tensor) -> Result { let _enter = self.span.enter(); - self.inner.forward(x) + candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32) } } diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index bfd0629f..a963e311 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -3,6 +3,7 @@ use candle::{Device, Result, Shape}; use std::sync::Arc; // VarBuilder specialized for QTensors +#[derive(Clone)] pub struct VarBuilder { data: Arc>>, path: Vec, diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs index 50d3b707..d29995ed 100644 --- a/candle-transformers/src/utils.rs +++ b/candle-transformers/src/utils.rs @@ -2,10 +2,14 @@ use candle::{Result, Tensor}; pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result { let device = logits.device(); - let mut logits = logits.to_vec1::()?; - let context: std::collections::HashSet<_> = context.iter().collect(); - for (token_id, logit) in logits.iter_mut().enumerate() { - if context.contains(&(token_id as u32)) { + let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::()?; + let mut already_seen 
= std::collections::HashSet::new(); + for token_id in context { + if already_seen.contains(token_id) { + continue; + } + already_seen.insert(token_id); + if let Some(logit) = logits.get_mut(*token_id as usize) { if *logit >= 0. { *logit /= penalty } else { diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs index 76f994d0..cc499a44 100644 --- a/candle-transformers/tests/generation_tests.rs +++ b/candle-transformers/tests/generation_tests.rs @@ -27,3 +27,30 @@ fn sample_with_top_p() -> Result<()> { assert_eq!(token, 2); Ok(()) } + +#[test] +fn sample_with_top_k() -> Result<()> { + let mut logits_process = LogitsProcessor::from_sampling( + 42, + candle_transformers::generation::Sampling::TopK { + k: 1, + temperature: 1.0, + }, + ); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 3); + let mut logits_process = LogitsProcessor::from_sampling( + 42, + candle_transformers::generation::Sampling::TopK { + k: 2, + temperature: 1.0, + }, + ); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 3); + let token = logits_process.sample(&logits)?; + assert_eq!(token, 2); + Ok(()) +} diff --git a/candle-wasm-examples/moondream/Cargo.toml b/candle-wasm-examples/moondream/Cargo.toml new file mode 100644 index 00000000..fc1b82ca --- /dev/null +++ b/candle-wasm-examples/moondream/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "candle-wasm-example-moondream" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +tokenizers = { workspace = true, features = ["unstable_wasm"] } +num-traits = { workspace = true } + +# App crates. +anyhow = { workspace = true } +byteorder = { workspace = true } +getrandom = { version = "0.2", features = ["js"] } +image = { workspace = true } +log = { workspace = true } +safetensors = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } + +# Wasm specific crates. +console_error_panic_hook = "0.1.7" +wasm-bindgen = "0.2.87" +js-sys = "0.3.64" +serde-wasm-bindgen = "0.6.5" diff --git a/candle-wasm-examples/moondream/README.md b/candle-wasm-examples/moondream/README.md new file mode 100644 index 00000000..ca7f7ced --- /dev/null +++ b/candle-wasm-examples/moondream/README.md @@ -0,0 +1,24 @@ +## Running [Moondream 2](https://huggingface.co/vikhyatk/moondream2) Model Example + +### Vanilla JS and WebWorkers + +To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library: + +```bash +sh build-lib.sh +``` + +This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module: + +```js +import init, { Model } from "./build/m.js"; +``` + +The full example can be found under `./index.html`. All needed assets are fetched from the web, so no need to download anything. +Finally, you can preview the example by running a local HTTP server. For example: + +```bash +python -m http.server +``` + +Then open `http://localhost:8000/index.html` in your browser. 
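For orientation, a minimal sketch of the worker-side setup that `moondreamWorker.js` (added later in this diff) implements in full: initialize the wasm module, fetch the weights and tokenizer as raw bytes, and construct the `Model`. The local file paths below are placeholders; the real worker fetches the files from the Hugging Face Hub and caches them with the Cache API.

```js
import init, { Model } from "./build/m.js";

// Download a file and return it as raw bytes.
const toBytes = async (url) => new Uint8Array(await (await fetch(url)).arrayBuffer());

await init(); // instantiate the wasm module (module workers allow top-level await)
const weights = await toBytes("model-q4_0.gguf"); // placeholder path
const tokenizer = await toBytes("tokenizer.json"); // placeholder path
const model = new Model(weights, tokenizer, /* quantized */ true);
```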
diff --git a/candle-wasm-examples/moondream/build-lib.sh b/candle-wasm-examples/moondream/build-lib.sh new file mode 100644 index 00000000..b0ebb182 --- /dev/null +++ b/candle-wasm-examples/moondream/build-lib.sh @@ -0,0 +1,2 @@ +cargo build --target wasm32-unknown-unknown --release +wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web diff --git a/candle-wasm-examples/moondream/code.js b/candle-wasm-examples/moondream/code.js new file mode 100644 index 00000000..c766196d --- /dev/null +++ b/candle-wasm-examples/moondream/code.js @@ -0,0 +1,262 @@ +import snarkdown from "https://cdn.skypack.dev/snarkdown"; +import hljs from "https://cdn.skypack.dev/highlight.js"; +// models base url +const MODELS = { + moondream2_q4k: { + base_url: + "https://huggingface.co/santiagomed/candle-moondream/resolve/main/", + model: "model-q4_0.gguf", + tokenizer: "tokenizer.json", + quantized: true, + size: "1.51 GB", + }, +}; + +const moodreamWorker = new Worker("./moondreamWorker.js", { + type: "module", +}); + +async function generateSequence(controller) { + const getValue = (id) => document.querySelector(`#${id}`).value; + const modelID = getValue("model"); + const model = MODELS[modelID]; + const weightsURL = + model.model instanceof Array + ? model.model.map((m) => model.base_url + m) + : model.base_url + model.model; + const tokenizerURL = model.base_url + model.tokenizer; + + const prompt = getValue("prompt").trim(); + const temperature = getValue("temperature"); + const topP = getValue("top-p"); + const repeatPenalty = getValue("repeat_penalty"); + const seed = getValue("seed"); + const maxSeqLen = getValue("max-seq"); + + if (prompt?.value?.trim() === "") { + return; + } + + function updateStatus(data) { + const outStatus = document.querySelector("#output-status"); + const outGen = document.querySelector("#output-generation"); + const outCounter = document.querySelector("#output-counter"); + + switch (data.status) { + case "loading": + outStatus.hidden = false; + outStatus.textContent = data.message; + outGen.hidden = true; + outCounter.hidden = true; + break; + case "generating": + const { message, prompt, sentence, tokensSec, totalTime } = data; + outStatus.hidden = true; + outCounter.hidden = false; + outGen.hidden = false; + outGen.innerHTML = snarkdown(prompt + sentence); + outCounter.innerHTML = `${(totalTime / 1000).toFixed( + 2 + )}s (${tokensSec.toFixed(2)} tok/s)`; + hljs.highlightAll(); + break; + case "complete": + outStatus.hidden = true; + outGen.hidden = false; + break; + } + } + + return new Promise((resolve, reject) => { + moodreamWorker.postMessage({ + weightsURL, + modelID, + tokenizerURL, + quantized: model.quantized, + imageURL: currentImageURL, + prompt, + temp: temperature, + top_p: topP, + repeatPenalty, + seed: seed, + maxSeqLen, + verbose_prompt: false, + command: "start", + }); + + const handleAbort = () => { + moodreamWorker.postMessage({ command: "abort" }); + }; + const handleMessage = (event) => { + const { status, error, message, prompt, sentence } = event.data; + if (status) updateStatus(event.data); + if (error) { + moodreamWorker.removeEventListener("message", handleMessage); + reject(new Error(error)); + } + if (status === "aborted") { + moodreamWorker.removeEventListener("message", handleMessage); + resolve(event.data); + } + if (status === "complete") { + moodreamWorker.removeEventListener("message", handleMessage); + resolve(event.data); + } + }; + + controller.signal.addEventListener("abort", handleAbort); + 
moodreamWorker.addEventListener("message", handleMessage); + }); +} + +const form = document.querySelector("#form"); +const prompt = document.querySelector("#prompt"); +const runBtn = document.querySelector("#run"); +const modelSelect = document.querySelector("#model"); +const dropArea = document.querySelector("#drop-area"); +const canvas = document.querySelector("#canvas"); +const ctxCanvas = canvas.getContext("2d"); +const fileUpload = document.querySelector("#file-upload"); +const clearImgBtn = document.querySelector("#clear-img-btn"); +const imagesExamples = document.querySelector("#image-select"); + +let currentImageURL = null; +let runController = new AbortController(); +let isRunning = false; + +document.addEventListener("DOMContentLoaded", () => { + for (const [id, model] of Object.entries(MODELS)) { + const option = document.createElement("option"); + option.value = id; + option.innerText = `${id} (${model.size})`; + modelSelect.appendChild(option); + } + const query = new URLSearchParams(window.location.search); + const modelID = query.get("model"); + if (modelID) { + modelSelect.value = modelID; + } else { + modelSelect.value = "moondream2_q4k"; + } +}); + +imagesExamples.addEventListener("click", (e) => { + // if (isEmbedding || isSegmenting) { + // return; + // } + const target = e.target; + if (target.nodeName === "IMG") { + const href = target.src; + clearImageCanvas(); + currentImageURL = href; + drawImageCanvas(href); + } +}); +modelSelect.addEventListener("change", (e) => { + const query = new URLSearchParams(window.location.search); + query.set("model", e.target.value); + window.history.replaceState({}, "", `${window.location.pathname}?${query}`); + window.parent.postMessage({ queryString: "?" + query }, "*"); + const model = MODELS[e.target.value]; + document.querySelector("#max-seq").max = model.seq_len; + document.querySelector("#max-seq").nextElementSibling.value = 200; +}); + +clearImgBtn.addEventListener("click", () => { + clearImageCanvas(); +}); + +//add event listener to file input +fileUpload.addEventListener("input", async (e) => { + const target = e.target; + if (target.files.length > 0 && !target.files[0].type.includes("svg")) { + const href = URL.createObjectURL(target.files[0]); + clearImageCanvas(); + await drawImageCanvas(href); + } +}); +// add event listener to drop-area +dropArea.addEventListener("dragenter", (e) => { + e.preventDefault(); + dropArea.classList.add("border-blue-700"); +}); +dropArea.addEventListener("dragleave", (e) => { + e.preventDefault(); + dropArea.classList.remove("border-blue-700"); +}); +dropArea.addEventListener("dragover", (e) => { + e.preventDefault(); +}); +dropArea.addEventListener("drop", async (e) => { + e.preventDefault(); + dropArea.classList.remove("border-blue-700"); + const url = e.dataTransfer.getData("text/uri-list"); + const files = e.dataTransfer.files; + if (files.length > 0) { + const href = URL.createObjectURL(files[0]); + clearImageCanvas(); + await drawImageCanvas(href); + } else if (url) { + clearImageCanvas(); + await drawImageCanvas(url); + } +}); + +form.addEventListener("submit", async (e) => { + e.preventDefault(); + if (isRunning) { + stopRunning(); + } else { + startRunning(); + await generateSequence(runController); + stopRunning(); + } +}); + +async function drawImageCanvas(imgURL) { + if (!imgURL) { + throw new Error("No image URL provided"); + } + return new Promise((resolve, reject) => { + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + ctxCanvas.clearRect(0, 0, canvas.width, 
canvas.height); + const img = new Image(); + img.crossOrigin = "anonymous"; + img.onload = () => { + canvas.width = img.width; + canvas.height = img.height; + ctxCanvas.drawImage(img, 0, 0); + clearImgBtn.disabled = false; + resolve(img); + }; + img.src = imgURL; + currentImageURL = imgURL; + }); +} + +function clearImageCanvas() { + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + clearImgBtn.disabled = true; + canvas.parentElement.style.height = "auto"; + currentImageURL = null; + canvas.width = 0; + canvas.height = 0; +} + +function startRunning() { + isRunning = true; + runBtn.textContent = "Stop"; + prompt.disabled = true; +} + +function stopRunning() { + runController.abort(); + runController = new AbortController(); + runBtn.textContent = "Run"; + isRunning = false; + prompt.disabled = false; +} + +prompt.addEventListener("input", (e) => { + runBtn.disabled = false; +}); diff --git a/candle-wasm-examples/moondream/index.html b/candle-wasm-examples/moondream/index.html new file mode 100644 index 00000000..26bd6a40 --- /dev/null +++ b/candle-wasm-examples/moondream/index.html @@ -0,0 +1,312 @@ + + + + Candle Moondream Rust/WASM + + + + + + + + + + + + + + + + +

+ [The rest of index.html (312 added lines) did not survive extraction; only the visible page text remains. The page header reads "🕯️ Candle Moondream 2" with the subtitle "Rust/WASM Demo", credits Moondream 2 to Vik and the Candle implementation to Santiago Medina, and notes that on first run the app downloads and caches the model, after which embeddings and generation still take a few minutes to start. Below that are a model selector and prompt form, an "Advanced Options" panel (defaults: 500 max tokens, temperature 0.00, top-p 1.00, repeat penalty 1.10), an image drop/upload area with clickable example images, and a "Generation:" output panel that initially shows "No output yet".]
+ + diff --git a/candle-wasm-examples/moondream/moondreamWorker.js b/candle-wasm-examples/moondream/moondreamWorker.js new file mode 100644 index 00000000..cf85053f --- /dev/null +++ b/candle-wasm-examples/moondream/moondreamWorker.js @@ -0,0 +1,201 @@ +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url, cacheModel = true) { + if (!cacheModel) + return new Uint8Array(await (await fetch(url)).arrayBuffer()); + const cacheName = "moondream-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} + +async function concatenateArrayBuffers(urls) { + const arrayBuffers = await Promise.all( + urls.map((url) => fetchArrayBuffer(url)) + ); + + let totalLength = arrayBuffers.reduce( + (acc, arrayBuffer) => acc + arrayBuffer.byteLength, + 0 + ); + let concatenatedBuffer = new Uint8Array(totalLength); + + let offset = 0; + arrayBuffers.forEach((buffer) => { + concatenatedBuffer.set(new Uint8Array(buffer), offset); + offset += buffer.byteLength; + }); + return concatenatedBuffer; +} + +class Moondream { + static imageArrayHash = {}; + static instance = {}; + static currentModelID = null; + + static async getInstance(weightsURL, modelID, tokenizerURL, quantized) { + // load individual modelID only once + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ status: "loading", message: "Loading Model" }); + const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([ + weightsURL instanceof Array + ? concatenateArrayBuffers(weightsURL) + : fetchArrayBuffer(weightsURL), + fetchArrayBuffer(tokenizerURL), + ]); + + this.instance[modelID] = new Model( + weightsArrayU8, + tokenizerArrayU8, + quantized + ); + } + this.currentModelID = modelID; + return this.instance[modelID]; + } + + // Remove the modelID parameter from setImageEmbeddings + static setImageEmbeddings(imageArrayU8) { + // check if image embeddings are already set for this image and model + const imageArrayHash = this.getSimpleHash(imageArrayU8); + if ( + this.imageArrayHash[this.currentModelID] === imageArrayHash && + this.instance[this.currentModelID] + ) { + self.postMessage({ + status: "embedding", + message: "Embeddings Already Set", + }); + return; + } + this.imageArrayHash[this.currentModelID] = imageArrayHash; + this.instance[this.currentModelID].set_image_embeddings(imageArrayU8); + self.postMessage({ status: "embedding", message: "Embeddings Set" }); + } + + static getSimpleHash(imageArrayU8) { + // get simple hash of imageArrayU8 + let imageArrayHash = 0; + for (let i = 0; i < imageArrayU8.length; i += 100) { + imageArrayHash ^= imageArrayU8[i]; + } + return imageArrayHash.toString(16); + } +} + +let controller = null; +self.addEventListener("message", (event) => { + if (event.data.command === "start") { + controller = new AbortController(); + generate(event.data); + } else if (event.data.command === "abort") { + controller.abort(); + } +}); + +async function generate(data) { + const { + weightsURL, + modelID, + tokenizerURL, + quantized, + imageURL, + prompt, + seed, + temp, + top_p, + repeatPenalty, + maxSeqLen, + verbose_prompt, + } = data; + try { + self.postMessage({ status: "loading", message: "Starting Moondream" }); + const model = await Moondream.getInstance( + weightsURL, + 
modelID, + tokenizerURL, + quantized + ); + + self.postMessage({ status: "loading", message: "Initializing model" }); + + self.postMessage({ status: "loading", message: "Loading Image" }); + const imageArrayU8 = await fetchArrayBuffer(imageURL, false); + + self.postMessage({ status: "embedding", message: "Creating Embeddings" }); + Moondream.setImageEmbeddings(imageArrayU8); + self.postMessage({ + status: "complete-embedding", + message: "Embeddings Complete", + }); + const { token, token_id } = model.init_with_image_prompt({ + prompt, + seed: BigInt(seed), + temp: parseFloat(temp), + top_p: parseFloat(top_p), + repeat_penalty: parseFloat(repeatPenalty), + repeat_last_n: 64, + verbose_prompt, + }); + + const seq_len = 2048; + + let sentence = token; + let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1; + let startTime = performance.now(); + let tokensCount = 0; + while (tokensCount < maxTokens) { + await new Promise(async (resolve) => { + if (controller && controller.signal.aborted) { + console.log("Aborted"); + self.postMessage({ + status: "aborted", + message: "Aborted", + output: prompt + sentence, + }); + return; + } + const { token, token_id } = await model.next_token(); + if (token_id === 50256) { + // <|endoftext|> + self.postMessage({ + status: "complete", + message: "complete", + output: prompt + sentence, + }); + return; + } + const tokensSec = + ((tokensCount + 1) / (performance.now() - startTime)) * 1000; + + sentence += token; + self.postMessage({ + status: "generating", + message: "Generating token", + token: token, + sentence: sentence, + totalTime: performance.now() - startTime, + tokensSec, + prompt: prompt, + }); + setTimeout(resolve, 0); + }); + tokensCount++; + } + self.postMessage({ + status: "complete", + message: "complete", + output: prompt + sentence, + }); + } catch (e) { + self.postMessage({ error: e }); + } +} diff --git a/candle-wasm-examples/moondream/src/bin/m.rs b/candle-wasm-examples/moondream/src/bin/m.rs new file mode 100644 index 00000000..2af6c0d2 --- /dev/null +++ b/candle-wasm-examples/moondream/src/bin/m.rs @@ -0,0 +1,279 @@ +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::{ + generation::LogitsProcessor, + models::{moondream, quantized_moondream}, +}; +use candle_wasm_example_moondream::console_log; +use js_sys::Date; +use serde::{Deserialize, Serialize}; +use tokenizers::Tokenizer; +use wasm_bindgen::prelude::*; + +enum SelectedModel { + Moondream(moondream::Model), + Quantized(quantized_moondream::Model), +} + +#[wasm_bindgen] +pub struct Model { + model: SelectedModel, + tokenizer: Tokenizer, + logits_processor: LogitsProcessor, + tokens: Vec, + repeat_penalty: f32, + repeat_last_n: usize, + index: usize, + bos_token: Option, + image_embeddings: Option, +} + +#[derive(Serialize, Deserialize)] +struct Output { + token: String, + token_id: u32, +} +#[derive(Serialize, Deserialize)] +struct InitInput { + prompt: String, + seed: u64, + temp: f64, + top_p: f64, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, +} + +#[wasm_bindgen] +impl Model { + #[wasm_bindgen(constructor)] + pub fn load(weights: Vec, tokenizer: Vec, quantized: bool) -> Result { + console_error_panic_hook::set_once(); + console_log!("loading model"); + let device = Device::Cpu; + let config = moondream::Config::v2(); + + console_log!("config loaded in {:?}", Date::now()); + let tokenizer = + Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; + let start = Date::now(); + 
console_log!("weights len: {:?}", weights.len()); + let model = if quantized { + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer( + &weights, &device, + )?; + console_log!("weights loaded"); + let model = quantized_moondream::Model::new(&config, vb)?; + SelectedModel::Quantized(model) + } else { + let device = &Device::Cpu; + let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?; + let model = moondream::Model::new(&config, vb)?; + SelectedModel::Moondream(model) + }; + console_log!("model loaded in {:?}s", (Date::now() - start) / 1000.); + let logits_processor = LogitsProcessor::new(299792458, None, None); + Ok(Self { + model, + tokenizer, + tokens: vec![], + logits_processor, + repeat_penalty: 1., + repeat_last_n: 64, + bos_token: None, + image_embeddings: None, + index: 0, + }) + } + + pub fn set_image_embeddings(&mut self, image: Vec) -> Result<(), JsError> { + let device = Device::Cpu; + + console_log!("loading image as tensor"); + let start = Date::now(); + let image: Tensor = self.load_image(image)?.to_device(&device)?; + console_log!("image loaded in {:?}s", (Date::now() - start) / 1000.); + let start = Date::now(); + let image_embeds = &image.unsqueeze(0)?; + let image_embeds = match &self.model { + SelectedModel::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?, + SelectedModel::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?, + }; + console_log!( + "loaded and encoded the image {image:?} in {:?}", + (Date::now() - start) / 1000. + ); + self.image_embeddings = Some(image_embeds); + Ok(()) + } + + #[wasm_bindgen] + pub fn init_with_image_prompt(&mut self, input: JsValue) -> Result { + let InitInput { + prompt, + seed, + temp, + top_p, + repeat_penalty, + repeat_last_n, + verbose_prompt, + } = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; + + let device = Device::Cpu; + let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", prompt); + match &mut self.model { + SelectedModel::Moondream(m) => m.text_model.clear_kv_cache(), + SelectedModel::Quantized(m) => m.text_model.clear_kv_cache(), + }; + + let temp = if temp <= 0. { None } else { Some(temp) }; + let top_p = if top_p <= 0. || top_p >= 1. 
{ + None + } else { + Some(top_p) + }; + self.logits_processor = LogitsProcessor::new(seed, temp, top_p); + self.repeat_penalty = repeat_penalty; + self.repeat_last_n = repeat_last_n; + self.tokens.clear(); + self.index = 0; + + // Moondream tokenizer bos_token is "<|endoftext|>" + // https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json + let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { + Some(token) => *token, + None => return Err(JsError::new("BOS token not found in the tokenizer.")), + }; + + self.bos_token = Some(Tensor::new(&[special_token], &device)?.unsqueeze(0)?); + + let tokens = self + .tokenizer + .encode(prompt, true) + .map_err(|m| JsError::new(&m.to_string()))?; + + if tokens.is_empty() { + return Err(JsError::new( + "Empty prompts are not supported in the Moondream model.", + )); + } + + if verbose_prompt { + for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { + let token = token.replace('▁', " ").replace("<0x0A>", "\n"); + println!("{id:7} -> '{token}'"); + } + } + let tokens = tokens.get_ids().to_vec(); + let text = match self.process(&tokens) { + Ok(text) => text, + Err(_e) => { + console_log!("error decoding token"); + Output { + token: "".to_string(), + token_id: 0, + } + } + }; + Ok(serde_wasm_bindgen::to_value(&text)?) + } + #[wasm_bindgen] + pub fn next_token(&mut self) -> Result { + let last_token = *self.tokens.last().unwrap(); + let text = match self.process(&[last_token]) { + Ok(text) => text, + Err(_e) => { + console_log!("error decoding token"); + Output { + token: "".to_string(), + token_id: 0, + } + } + }; + Ok(serde_wasm_bindgen::to_value(&text)?) + } +} +impl Model { + fn load_image(&self, image: Vec) -> Result { + let img = image::io::Reader::new(std::io::Cursor::new(image)) + .with_guessed_format()? + .decode() + .map_err(|e| JsError::new(&e.to_string()))? + .resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378 + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) + .map_err(|e| JsError::new(&e.to_string())) + } +} + +impl Model { + fn process(&mut self, tokens: &[u32]) -> Result { + let image_embeddings = match &self.image_embeddings { + Some(embeddings) => embeddings, + None => return Err(JsError::new("Image embeddings are not set.")), + }; + let bos_token = match &self.bos_token { + Some(token) => token, + None => return Err(JsError::new("BOS token is not set.")), + }; + let device = Device::Cpu; + let context_size = if self.index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + let logits = if self.index > 0 { + match self.model { + SelectedModel::Moondream(ref mut model) => model.text_model.forward(&input)?, + SelectedModel::Quantized(ref mut model) => model.text_model.forward(&input)?, + } + } else { + match self.model { + SelectedModel::Moondream(ref mut model) => { + model + .text_model + .forward_with_img(bos_token, &input, image_embeddings)? + } + SelectedModel::Quantized(ref mut model) => { + model + .text_model + .forward_with_img(bos_token, &input, image_embeddings)? 
+ } + } + }; + + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + let next_token = self.logits_processor.sample(&logits)?; + self.tokens.push(next_token); + let token = match self.tokenizer.decode(&[next_token], true) { + Ok(token) => token, + Err(e) => { + console_log!("error decoding token: {:?}", e); + "".to_string() + } + }; + self.index += 1; + Ok(Output { + token, + token_id: next_token, + }) + } +} + +fn main() { + console_error_panic_hook::set_once(); +} diff --git a/candle-wasm-examples/moondream/src/lib.rs b/candle-wasm-examples/moondream/src/lib.rs new file mode 100644 index 00000000..cb15633c --- /dev/null +++ b/candle-wasm-examples/moondream/src/lib.rs @@ -0,0 +1,16 @@ +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string())) +} diff --git a/tensor-tools/Cargo.toml b/tensor-tools/Cargo.toml new file mode 100644 index 00000000..eecd7e43 --- /dev/null +++ b/tensor-tools/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "tensor-tools" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +anyhow = { workspace = true } +candle = { workspace = true } +clap = { workspace = true } +rayon = { workspace = true } +safetensors = { workspace = true } diff --git a/candle-core/examples/tensor-tools.rs b/tensor-tools/src/main.rs similarity index 69% rename from candle-core/examples/tensor-tools.rs rename to tensor-tools/src/main.rs index 1801ac58..ad351171 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/tensor-tools/src/main.rs @@ -1,5 +1,5 @@ -use candle_core::quantized::{gguf_file, GgmlDType, QTensor}; -use candle_core::{Device, Result}; +use candle::quantized::{gguf_file, GgmlDType, QTensor}; +use candle::{Device, Result}; use clap::{Parser, Subcommand, ValueEnum}; use rayon::prelude::*; @@ -117,6 +117,24 @@ enum Command { verbose: bool, }, + Print { + file: std::path::PathBuf, + + names: Vec, + + /// The file format to use, if unspecified infer from the file extension. + #[arg(long, value_enum)] + format: Option, + + /// Print the whole content of each tensor. + #[arg(long)] + full: bool, + + /// Line width for printing the tensors. + #[arg(long)] + line_width: Option, + }, + Quantize { /// The input file(s), in safetensors format. 
diff --git a/candle-core/examples/tensor-tools.rs b/tensor-tools/src/main.rs
similarity index 69%
rename from candle-core/examples/tensor-tools.rs
rename to tensor-tools/src/main.rs
index 1801ac58..ad351171 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/tensor-tools/src/main.rs
@@ -1,5 +1,5 @@
-use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
-use candle_core::{Device, Result};
+use candle::quantized::{gguf_file, GgmlDType, QTensor};
+use candle::{Device, Result};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;
 
@@ -117,6 +117,24 @@ enum Command {
         verbose: bool,
     },
 
+    Print {
+        file: std::path::PathBuf,
+
+        names: Vec<String>,
+
+        /// The file format to use, if unspecified infer from the file extension.
+        #[arg(long, value_enum)]
+        format: Option<Format>,
+
+        /// Print the whole content of each tensor.
+        #[arg(long)]
+        full: bool,
+
+        /// Line width for printing the tensors.
+        #[arg(long)]
+        line_width: Option<usize>,
+    },
+
     Quantize {
         /// The input file(s), in safetensors format.
         in_file: Vec<std::path::PathBuf>,
@@ -150,6 +168,105 @@ struct Args {
     command: Command,
 }
 
+fn run_print(
+    file: &std::path::PathBuf,
+    names: Vec<String>,
+    format: Option<Format>,
+    full: bool,
+    line_width: Option<usize>,
+    device: &Device,
+) -> Result<()> {
+    if full {
+        candle::display::set_print_options_full();
+    }
+    if let Some(line_width) = line_width {
+        candle::display::set_line_width(line_width)
+    }
+    let format = match format {
+        Some(format) => format,
+        None => match Format::infer(file) {
+            Some(format) => format,
+            None => {
+                println!(
+                    "{file:?}: cannot infer format from file extension, use the --format flag"
+                );
+                return Ok(());
+            }
+        },
+    };
+    match format {
+        Format::Npz => {
+            let tensors = candle::npy::NpzTensors::new(file)?;
+            for name in names.iter() {
+                println!("==== {name} ====");
+                match tensors.get(name)? {
+                    Some(tensor) => println!("{tensor}"),
+                    None => println!("not found"),
+                }
+            }
+        }
+        Format::Safetensors => {
+            use candle::safetensors::Load;
+            let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
+            let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
+            for name in names.iter() {
+                println!("==== {name} ====");
+                match tensors.get(name) {
+                    Some(tensor_view) => {
+                        let tensor = tensor_view.load(device)?;
+                        println!("{tensor}")
+                    }
+                    None => println!("not found"),
+                }
+            }
+        }
+        Format::Pth => {
+            let pth_file = candle::pickle::PthTensors::new(file, None)?;
+            for name in names.iter() {
+                println!("==== {name} ====");
+                match pth_file.get(name)? {
+                    Some(tensor) => {
+                        println!("{tensor}")
+                    }
+                    None => println!("not found"),
+                }
+            }
+        }
+        Format::Pickle => {
+            candle::bail!("pickle format is not supported for print")
+        }
+        Format::Ggml => {
+            let mut file = std::fs::File::open(file)?;
+            let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
+            for name in names.iter() {
+                println!("==== {name} ====");
+                match content.tensors.get(name) {
+                    Some(tensor) => {
+                        let tensor = tensor.dequantize(device)?;
+                        println!("{tensor}")
+                    }
+                    None => println!("not found"),
+                }
+            }
+        }
+        Format::Gguf => {
+            let mut file = std::fs::File::open(file)?;
+            let content = gguf_file::Content::read(&mut file)?;
+            for name in names.iter() {
+                println!("==== {name} ====");
+                match content.tensor(&mut file, name, device) {
+                    Ok(tensor) => {
+                        let tensor = tensor.dequantize(device)?;
+                        println!("{tensor}")
+                    }
+                    Err(_) => println!("not found"),
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
 fn run_ls(
     file: &std::path::PathBuf,
     format: Option<Format>,
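
`run_print` toggles candle's global display options before formatting each tensor with `Display`; given the clap definitions above, an invocation would look roughly like `tensor-tools print model.safetensors model.layers.0.weight --line-width 140`, with the file and tensor names positional and the remaining options as flags. A small sketch of the same display knobs used directly from Rust, assuming the candle version pinned by this patch:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Mirror what run_print does for --full / --line-width before printing.
        candle::display::set_print_options_full();
        candle::display::set_line_width(140);
        let t = Tensor::arange(0f32, 24., &Device::Cpu)?.reshape((4, 6))?;
        println!("{t}"); // Tensor implements Display, so this prints the full 4x6 matrix.
        Ok(())
    }
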
@@ -170,7 +287,7 @@ fn run_ls(
     };
     match format {
         Format::Npz => {
-            let tensors = candle_core::npy::NpzTensors::new(file)?;
+            let tensors = candle::npy::NpzTensors::new(file)?;
             let mut names = tensors.names();
             names.sort();
             for name in names {
@@ -182,12 +299,12 @@ fn run_ls(
             }
         }
         Format::Safetensors => {
-            let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
+            let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
             let mut tensors = tensors.tensors();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, view) in tensors.iter() {
                 let dtype = view.dtype();
-                let dtype = match candle_core::DType::try_from(dtype) {
+                let dtype = match candle::DType::try_from(dtype) {
                     Ok(dtype) => format!("{dtype:?}"),
                     Err(_) => format!("{dtype:?}"),
                 };
@@ -196,7 +313,7 @@ fn run_ls(
             }
         }
         Format::Pth => {
-            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
+            let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
             tensors.sort_by(|a, b| a.name.cmp(&b.name));
             for tensor_info in tensors.iter() {
                 println!(
@@ -213,7 +330,7 @@ fn run_ls(
         Format::Pickle => {
             let file = std::fs::File::open(file)?;
             let mut reader = std::io::BufReader::new(file);
-            let mut stack = candle_core::pickle::Stack::empty();
+            let mut stack = candle::pickle::Stack::empty();
             stack.read_loop(&mut reader)?;
             for (i, obj) in stack.stack().iter().enumerate() {
                 println!("{i} {obj:?}");
@@ -221,7 +338,7 @@ fn run_ls(
         }
         Format::Ggml => {
            let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
+            let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -257,7 +374,7 @@ fn run_quantize_safetensors(
     let mut out_file = std::fs::File::create(out_file)?;
     let mut tensors = std::collections::HashMap::new();
     for in_file in in_files.iter() {
-        let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+        let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
         tensors.extend(in_tensors)
     }
     println!("tensors: {}", tensors.len());
@@ -299,7 +416,7 @@ fn run_dequantize(
         let tensor = tensor.dequantize(device)?;
         tensors.insert(tensor_name.to_string(), tensor);
     }
-    candle_core::safetensors::save(&tensors, out_file)?;
+    candle::safetensors::save(&tensors, out_file)?;
     Ok(())
 }
@@ -311,11 +428,11 @@ fn run_quantize(
     device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
-        candle_core::bail!("no specified input files")
+        candle::bail!("no specified input files")
     }
     if let Some(extension) = out_file.extension() {
         if extension == "safetensors" {
-            candle_core::bail!("the generated file cannot use the safetensors extension")
+            candle::bail!("the generated file cannot use the safetensors extension")
         }
     }
     if let Some(extension) = in_files[0].extension() {
@@ -325,7 +442,7 @@ fn run_quantize(
     }
     if in_files.len() != 1 {
-        candle_core::bail!("only a single in-file can be used when quantizing gguf files")
+        candle::bail!("only a single in-file can be used when quantizing gguf files")
     }
     // Open the out file early so as to fail directly on missing directories etc.
@@ -377,6 +494,13 @@ fn main() -> anyhow::Result<()> {
                 run_ls(file, format.clone(), verbose, &device)?
             }
         }
+        Command::Print {
+            file,
+            names,
+            format,
+            full,
+            line_width,
+        } => run_print(&file, names, format, full, line_width, &device)?,
         Command::Quantize {
             in_file,
             out_file,