From 7161002a3410ed9a5531dc470e976ef295d5360e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 2 Nov 2023 15:32:28 +0100
Subject: [PATCH] Finished scaffolding, lots of TODOs

- Most kernels just copy themselves to get the shapes correct
- Matmul works in only one case and simply allocates an empty buffer otherwise
- Logits are randomized so the demo can finish on its own.

Performance is quite bad (30ms/token), but there are lots of prints and allocs
and some actual sending to Metal. Couldn't speed it up much by removing the
obvious blockers (println + the actual running matmuls).

Allocations take between 1us and 100us and seem very stable. Maybe Metal
doesn't really have a smart allocator and we'll need to own it.
---
 candle-core/Cargo.toml                        |   2 +-
 candle-core/src/metal_backend.rs              | 173 ++++++++++++++----
 candle-core/src/op.rs                         |   2 +-
 candle-core/src/quantized/mod.rs              |  43 +++++
 candle-examples/Cargo.toml                    |   1 +
 candle-examples/examples/quantized/main.rs    |  10 +-
 candle-metal-kernels/Cargo.toml               |  17 +-
 candle-nn/Cargo.toml                          |   1 +
 candle-nn/src/ops.rs                          |  10 +
 candle-transformers/Cargo.toml                |   1 +
 .../src/models/quantized_llama.rs             |   4 +
 11 files changed, 212 insertions(+), 52 deletions(-)

diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index f840842f..69bf47cf 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -13,7 +13,7 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.0.1", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index c618cba3..982e5ee1 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1,15 +1,15 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::bail;
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
-pub use candle_metal_kernels;
+use candle_metal_kernels;
 use core::mem;
 use half::{bf16, f16};
 use metal;
-use metal::mps::matrix::{MatrixMultiplication, Matrix, MatrixDescriptor};
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 use metal::mps::{Float32, MPSDataType};
 use metal::MTLResourceOptions;
-use crate::bail;
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -72,11 +72,16 @@ impl BackendStorage for MetalStorage {
     }
 
     fn to_cpu_storage(&self) -> Result {
-        todo!()
+        match self.dtype{
+            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
+            dtype => todo!("Unsupported dtype {dtype:?}")
+        }
     }
 
     fn affine(&self, _: &Layout, _: f64, _: f64) -> Result {
-        todo!()
+        println!("TODO Affine");
+        Ok(self.clone())
+        // todo!()
     }
 
     fn powf(&self, _: &Layout, _: f64) -> Result {
@@ -88,7 +93,9 @@ impl BackendStorage for MetalStorage {
     }
 
     fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result {
-        todo!()
+        println!("TODO reduce_op");
+        Ok(self.clone())
+        // todo!()
     }
 
     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result {
@@ -100,15 +107,22 @@ impl BackendStorage for MetalStorage {
     }
 
     fn unary_impl(&self, _: &Layout) -> 
Result { - todo!() + // todo!() + // TODO + println!("TODO {:?}", B::NAME); + Ok(self.clone()) } fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { - todo!() + println!("TODO Binary {:?}", B::NAME); + Ok(self.clone()) + // todo!() } - fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { - todo!() + fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { + println!("TODO where_cond"); + Ok(rhs.clone()) + // todo!() } fn conv1d( @@ -174,7 +188,9 @@ impl BackendStorage for MetalStorage { } fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { - todo!() + println!("TODO Index select"); + Ok(self.clone()) + // todo!() } fn index_add( @@ -195,21 +211,96 @@ impl BackendStorage for MetalStorage { (b, m, n, k): (usize, usize, usize, usize), lhs_l: &Layout, rhs_l: &Layout, + ) -> Result { + let transpose_left = false; + let transpose_right = false; + let alpha = 1.0; + let beta = 0.0; + self.matmul_generic( + rhs, + (b, m, n, k), + lhs_l, + rhs_l, + transpose_left, + transpose_right, + alpha, + beta, + ) + } + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { + println!("TODO Copy strided"); + Ok(()) + } +} + +impl MetalStorage { + pub(crate) fn matmul_t( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let transpose_left = false; + let transpose_right = true; + let alpha = 1.0; + let beta = 0.0; + self.matmul_generic( + rhs, + (b, m, n, k), + lhs_l, + rhs_l, + transpose_left, + transpose_right, + alpha, + beta, + ) + } + pub(crate) fn matmul_generic( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + transpose_left: bool, + transpose_right: bool, + alpha: f64, + beta: f64, ) -> Result { let elem_count = b * m * n; match (self.dtype, rhs.dtype) { (DType::F32, DType::F32) => { if b != 1 { - bail!("Didn't implemented strided matmul yet"); + println!("TODO implement batched matmul for B={b}"); + // bail!("Didn't implemented strided matmul yet"); + let out_buffer = self.device.new_buffer( + (elem_count * mem::size_of::()) as u64, + MTLResourceOptions::empty(), + ); + return Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }); } if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { - bail!("Didn't implemented non contiguous matmul yet"); + println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous()); + let out_buffer = self.device.new_buffer( + (elem_count * mem::size_of::()) as u64, + MTLResourceOptions::empty(), + ); + return Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }); } let out_buffer = self.device.new_buffer( (elem_count * mem::size_of::()) as u64, MTLResourceOptions::empty(), ); - let m : u64 = m.try_into().expect("usize should fit u64"); + let m: u64 = m.try_into().expect("usize should fit u64"); let n: u64 = n.try_into().expect("usize should fit u64"); let k: u64 = k.try_into().expect("usize should fit u64"); // Create descriptors @@ -220,6 +311,9 @@ impl BackendStorage for MetalStorage { let result_descriptor = MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID); + println!("lhs {:?} {m} {k}", self.buffer.length()); + println!("rhs {:?} {k} {n}", rhs.buffer.length()); + println!("out {:?} {m} {n}", out_buffer.length()); // Create matrix objects let left_matrix = 
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor) @@ -231,12 +325,8 @@ impl BackendStorage for MetalStorage { let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor) .expect("Failed to create left matrix"); - - let transpose_left = false; - let transpose_right = false; - let alpha = 1.0; - let beta = 0.0; + println!("lhs {:?}", lhs_l.shape()); // Create kernel let matrix_multiplication = MatrixMultiplication::init( @@ -258,20 +348,15 @@ impl BackendStorage for MetalStorage { &right_matrix, &result_matrix, ); - Ok(Self{ - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }) - + Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }) } _ => todo!("Unimplemented matmul for this pair"), } } - - fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { - todo!() - } } impl BackendDevice for MetalDevice { @@ -281,7 +366,11 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let _command_queue = device.new_command_queue(); let command_buffer = _command_queue.new_owned_command_buffer(); - Ok(Self { device, _command_queue, command_buffer }) + Ok(Self { + device, + _command_queue, + command_buffer, + }) } fn set_seed(&self, _seed: u64) -> Result<()> { @@ -296,12 +385,16 @@ impl BackendDevice for MetalDevice { self.device.registry_id() == rhs.device.registry_id() } - fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { - todo!() + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + // TODO Is there a faster way ? + let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; + self.storage_from_cpu_storage(&cpu_storage) } - fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { - todo!() + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + // TODO Is there a faster way ? + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + self.storage_from_cpu_storage(&cpu_storage) } fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { @@ -350,11 +443,15 @@ impl BackendDevice for MetalDevice { }) } - fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { - todo!() + fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result { + // TODO is there a better way ? + let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?; + self.storage_from_cpu_storage(&cpu_storage) } - fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { - todo!() + fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result { + // TODO is there a better way ? 
+ let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?; + self.storage_from_cpu_storage(&cpu_storage) } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index bb2c0ee7..f25a60a2 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -182,7 +182,7 @@ pub trait CustomOp1 { _layout: &Layout, ) -> Result<(MetalStorage, Shape)> { Err(crate::Error::Metal( - format!("no cuda implementation for {}", self.name()).into(), + format!("no metal implementation for {}", self.name()).into(), )) } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 7c51e778..c680d677 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -315,6 +315,49 @@ impl crate::CustomOp1 for QTensor { )?; Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) } + + fn metal_fwd( + &self, + storage: &crate::MetalStorage, + layout: &crate::Layout, + ) -> Result<(crate::MetalStorage, Shape)> { + println!("TODO qmatmul"); + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k. + let (n, k) = self.shape.dims2()?; + if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let mut dst_shape = src_shape.dims().to_vec(); + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + // let storage = storage.as_slice::()?; + // let storage = + // &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; + let dst_storage = vec![0f32; dst_shape.elem_count()]; + // self.matmul_t( + // (dst_shape.elem_count() / n, k, n), + // storage, + // &mut dst_storage, + // )?; + let cpu_storage = crate::CpuStorage::F32(dst_storage); + use crate::backend::{BackendDevice, BackendStorage}; + if let Device::Metal(device) = &self.device{ + Ok(( + device.storage_from_cpu_storage(&cpu_storage)?, + dst_shape, + )) + }else{ + crate::bail!("qtensor not on metal device") + } + } } impl QMatMul { diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index b1913541..26d8db49 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -51,6 +51,7 @@ anyhow = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] cudnn = ["candle/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 1fb32701..347e87a8 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -9,7 +9,7 @@ use std::io::Write; use tokenizers::Tokenizer; use candle::quantized::{ggml_file, gguf_file}; -use candle::{Device, Tensor}; +use candle::{Tensor}; use candle_transformers::generation::LogitsProcessor; use candle_transformers::models::quantized_llama as model; @@ -367,9 +367,11 @@ fn main() -> anyhow::Result<()> { let start_prompt_processing = std::time::Instant::now(); let mut next_token = { - let input = 
Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; let logits = model.forward(&input, 0)?; let logits = logits.squeeze(0)?; + // TODO Remove this once implementation is finished. + let logits = logits.ones_like()?; logits_processor.sample(&logits)? }; let prompt_dt = start_prompt_processing.elapsed(); @@ -380,7 +382,7 @@ fn main() -> anyhow::Result<()> { let start_post_prompt = std::time::Instant::now(); for index in 0..to_sample { - let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?; + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; let logits = model.forward(&input, prompt_tokens.len() + index)?; let logits = logits.squeeze(0)?; let logits = if args.repeat_penalty == 1. { @@ -393,6 +395,8 @@ fn main() -> anyhow::Result<()> { &all_tokens[start_at..], )? }; + // TODO Remove this once implementation is finished. + let logits = logits.ones_like()?; next_token = logits_processor.sample(&logits)?; all_tokens.push(next_token); print_token(next_token, &tokenizer); diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 92bc12ff..b0238a1b 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,13 +1,12 @@ [package] name = "candle-metal-kernels" -version = "0.0.1" -edition = "2021" - -description = "Metal kernels for Candle" -repository = "https://github.com/huggingface/candle" -keywords = ["blas", "tensor", "machine-learning"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true [dependencies] -metal = { workspace = true, optional = true} +metal = { workspace = true } diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 4b1f7917..d4324e65 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -28,4 +28,5 @@ clap = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] +metal = ["candle/metal"] mkl = ["dep:intel-mkl-src", "candle/mkl"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index e9812108..3a6fdd39 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -190,6 +190,16 @@ impl candle::CustomOp1 for SoftmaxLastDim { device: dev.clone(), }; Ok((dst, layout.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &candle::MetalStorage, + layout: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + println!("TODO softmax-last-dim"); + Ok((storage.clone(), layout.shape().clone())) } } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index e7290be6..53dcabef 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -28,5 +28,6 @@ wav = { workspace = true } default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] +metal = ["candle/metal", "candle-nn/metal"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 7b6480c0..3685d3de 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -112,6 +112,7 @@ impl LayerWeights { let q 
= self.attention_wq.forward(x)?; let k = self.attention_wk.forward(x)?; let v = self.attention_wv.forward(x)?; + // println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype()); let q = q .reshape((b_sz, seq_len, self.n_head, self.head_dim))? @@ -145,9 +146,12 @@ impl LayerWeights { let v = self.repeat_kv(v)?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + // println!("att {:?}", att.dtype()); let mask = mask.broadcast_as(att.shape())?; + // println!("mask {:?}", mask.dtype()); let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = candle_nn::ops::softmax_last_dim(&att)?; + // println!("att {:?} v {:?}", att.dtype(), v.dtype()); // Convert to contiguous as matmul doesn't support strided vs for now. let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;