From 0de079522023002b74127a7767a13d6c1e8b007c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 13 Feb 2024 18:11:17 +0100 Subject: [PATCH] Qmetal tweaks (#1704) * Add the dummy qmetal backend. * Fix the metal compilation. --- candle-core/src/quantized/dummy_metal.rs | 43 ++++++++++ candle-core/src/quantized/metal.rs | 95 +++++++++++++++++++-- candle-core/src/quantized/mod.rs | 103 +++-------------------- 3 files changed, 141 insertions(+), 100 deletions(-) create mode 100644 candle-core/src/quantized/dummy_metal.rs diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs new file mode 100644 index 00000000..96f91c50 --- /dev/null +++ b/candle-core/src/quantized/dummy_metal.rs @@ -0,0 +1,43 @@ +#![allow(unused)] +use super::GgmlDType; +use crate::{Error, MetalDevice, MetalStorage, Result}; + +pub struct QMetalStorage { + dtype: GgmlDType, + device: MetalDevice, +} + +impl QMetalStorage { + pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> { + Err(Error::NotCompiledWithMetalSupport) + } + + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn device(&self) -> &MetalDevice { + &self.device + } + + pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> { + Err(Error::NotCompiledWithMetalSupport) + } + + pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + pub fn storage_size_in_bytes(&self) -> usize { + 0 + } + + pub fn fwd( + &self, + _self_shape: &crate::Shape, + _storage: &MetalStorage, + _layout: &crate::Layout, + ) -> Result<(MetalStorage, crate::Shape)> { + Err(Error::NotCompiledWithMetalSupport) + } +} diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 94105327..5cdfe6ab 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -1,5 +1,6 @@ use super::{GgmlDType, QStorage}; -use crate::{DType, MetalDevice, MetalStorage, Result}; +use
crate::backend::BackendStorage; +use crate::{DType, MetalDevice, MetalStorage, Result, Shape}; use metal::Buffer; use std::sync::Arc; @@ -10,6 +11,16 @@ pub struct QMetalStorage { } impl QMetalStorage { + pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> { + let size = elem_count * dtype.type_size() / dtype.block_size(); + let buffer = device.allocate_zeros(size)?; + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + pub fn dtype(&self) -> GgmlDType { self.dtype } @@ -22,14 +33,6 @@ impl QMetalStorage { &self.buffer } - pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self { - Self { - device, - buffer, - dtype, - } - } - pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> { let buffer = self.device.new_buffer_managed(self.buffer.length())?; let command_buffer = self.device.command_buffer()?; @@ -134,6 +137,59 @@ impl QMetalStorage { self.buffer = buffer; Ok(()) } + + pub fn storage_size_in_bytes(&self) -> usize { + self.buffer.length() as usize + } + + pub fn fwd( + &self, + self_shape: &Shape, + storage: &MetalStorage, + layout: &crate::Layout, + ) -> Result<(MetalStorage, Shape)> { + use crate::MetalError; + + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k.
+ if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let (n, k) = self_shape.dims2()?; + let mut dst_shape = src_shape.dims().to_vec(); + + let (b, m) = match dst_shape.len() { + 3 => (dst_shape[0], dst_shape[1]), + 2 => (1, dst_shape[0]), + n => crate::bail!("Invalid rank {n} for quantized matmul metal"), + }; + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + let device = storage.device().clone(); + let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; + let command_buffer = device.command_buffer()?; + candle_metal_kernels::call_quantized_matmul_t( + device.device(), + &command_buffer, + device.kernels(), + self.dtype.into(), + (b, m, n, k), + storage.buffer(), + layout.start_offset() * storage.dtype().size_in_bytes(), + &self.buffer, + &dst, + ) + .map_err(MetalError::from)?; + let dst_storage = crate::MetalStorage::new(dst, device, DType::F32); + Ok((dst_storage, dst_shape)) + } } pub fn load_quantized_metal( @@ -155,3 +211,24 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; slice.to_vec() } + +impl From<GgmlDType> for candle_metal_kernels::GgmlDType { + fn from(value: GgmlDType) -> Self { + match value { + GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0, + GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1, + GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0, + GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1, + GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0, + GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1, + GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K, + GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K, + GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K, + GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, +
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, + GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, + GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, + GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + } + } +} diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 366552d9..d14b2dc2 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,16 +1,19 @@ -#[cfg(feature = "metal")] -use crate::{backend::BackendStorage, DType}; use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; use k_quants::*; use std::borrow::Cow; #[cfg(target_feature = "avx")] pub mod avx; +mod dummy_metal; pub mod ggml_file; pub mod gguf_file; pub mod k_quants; #[cfg(feature = "metal")] pub mod metal; +#[cfg(not(feature = "metal"))] +mod metal { + pub use super::dummy_metal::*; +} #[cfg(target_feature = "neon")] pub mod neon; #[cfg(target_feature = "simd128")] @@ -32,19 +35,9 @@ impl Device { let storage = dtype.cpu_zeros(elem_count); Ok(QStorage::Cpu(storage)) } - #[cfg(feature = "metal")] Device::Metal(metal) => { - let size = elem_count * dtype.type_size() / dtype.block_size(); - let buffer = metal.allocate_zeros(size)?; - Ok(QStorage::Metal(metal::QMetalStorage::new( - buffer, - metal.clone(), - dtype, - ))) - } - #[cfg(not(feature = "metal"))] - Device::Metal(_metal) => { - crate::bail!("Metal feature not activated"); + let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?; + Ok(QStorage::Metal(storage)) } Device::Cuda(_cuda) => { crate::bail!("Cuda ggml quantization not supported"); @@ -55,7 +48,6 @@ impl Device { pub enum QStorage { Cpu(Box<dyn QuantizedType>), - #[cfg(feature = "metal")] Metal(metal::QMetalStorage), } @@ -63,7 +55,6 @@ impl QStorage { fn block_size(&self) -> usize { match self { QStorage::Cpu(storage) => storage.block_size(), - #[cfg(feature = "metal")] QStorage::Metal(storage) => storage.dtype().block_size(), } } @@ -71,7 +62,6 @@ impl QStorage { fn dtype(&self) ->
GgmlDType { match self { QStorage::Cpu(storage) => storage.dtype(), - #[cfg(feature = "metal")] QStorage::Metal(storage) => storage.dtype(), } } @@ -79,7 +69,6 @@ impl QStorage { fn device(&self) -> Device { match self { QStorage::Cpu(_storage) => Device::Cpu, - #[cfg(feature = "metal")] QStorage::Metal(storage) => Device::Metal(storage.device().clone()), } } @@ -87,8 +76,7 @@ impl QStorage { fn size_in_bytes(&self) -> usize { match self { QStorage::Cpu(storage) => storage.storage_size_in_bytes(), - #[cfg(feature = "metal")] - QStorage::Metal(storage) => storage.buffer().length() as usize, + QStorage::Metal(storage) => storage.storage_size_in_bytes(), } } @@ -97,7 +85,6 @@ impl QStorage { (QStorage::Cpu(storage), Storage::Cpu(src)) => { storage.from_float(src.as_slice::<f32>()?)?; } - #[cfg(feature = "metal")] (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, _ => crate::bail!("Invalid dequantize storage locations do not match"), } @@ -107,7 +94,6 @@ impl QStorage { fn dequantize(&self, elem_count: usize) -> Result<Storage> { match self { QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), - #[cfg(feature = "metal")] QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), } } @@ -120,7 +106,6 @@ impl QStorage { let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; Ok(Cow::from(data)) } - #[cfg(feature = "metal")] QStorage::Metal(_storage) => { crate::bail!("not implemented"); } @@ -439,8 +424,7 @@ impl crate::CustomOp1 for QTensor { #[allow(clippy::infallible_destructuring_match)] let self_storage = match &self.storage { QStorage::Cpu(storage) => storage, - #[cfg(feature = "metal")] - _ => crate::bail!("Invalid storage"), + QStorage::Metal(_) => crate::bail!("Invalid storage"), }; let slice = storage.as_slice::<f32>()?; let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; @@ -449,79 +433,16 @@ impl crate::CustomOp1 for QTensor {
Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) } - #[cfg(feature = "metal")] fn metal_fwd( &self, storage: &crate::MetalStorage, layout: &crate::Layout, ) -> Result<(crate::MetalStorage, Shape)> { - use crate::MetalError; - - if !layout.is_contiguous() { - crate::bail!("input tensor is not contiguous {layout:?}") - } - let src_shape = layout.shape(); - // self is transposed so n is first then k. - if src_shape.rank() < 2 { - crate::bail!("input tensor has only one dimension {layout:?}") - } - let (n, k) = self.shape.dims2()?; - let mut dst_shape = src_shape.dims().to_vec(); - - let (b, m) = match dst_shape.len() { - 3 => (dst_shape[0], dst_shape[1]), - 2 => (1, dst_shape[0]), - n => crate::bail!("Invalid rank {n} for quantized matmul metal"), - }; - let last_k = dst_shape.pop().unwrap(); - if last_k != k { - crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) - } - dst_shape.push(n); - let dst_shape = Shape::from(dst_shape); - let device = storage.device().clone(); - let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; - let (buffer, dtype) = match &self.storage { - QStorage::Metal(metal) => (metal.buffer(), metal.dtype()), + let self_storage = match &self.storage { + QStorage::Metal(metal) => metal, _ => unreachable!("Cannot call metal matmul on non metal QTensor"), }; - let command_buffer = device.command_buffer()?; - candle_metal_kernels::call_quantized_matmul_t( - device.device(), - &command_buffer, - device.kernels(), - dtype.into(), - (b, m, n, k), - storage.buffer(), - layout.start_offset() * storage.dtype().size_in_bytes(), - buffer, - &dst, - ) - .map_err(MetalError::from)?; - let dst_storage = crate::MetalStorage::new(dst, device, DType::F32); - Ok((dst_storage, dst_shape)) - } -} - -#[cfg(feature = "metal")] -impl From<GgmlDType> for candle_metal_kernels::GgmlDType { - fn from(value: GgmlDType) -> Self { - match value { - GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0, - GgmlDType::Q4_1 =>
candle_metal_kernels::GgmlDType::Q4_1, - GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0, - GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1, - GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0, - GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1, - GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K, - GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K, - GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K, - GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, - GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, - GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, - GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, - GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, - } + self_storage.fwd(&self.shape, storage, layout) } }