use super::{GgmlDType, QStorage}; use crate::backend::BackendStorage; use crate::{DType, MetalDevice, MetalStorage, Result, Shape}; use metal::Buffer; use std::sync::Arc; pub struct QMetalStorage { dtype: GgmlDType, device: MetalDevice, buffer: Arc, } impl QMetalStorage { pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result { let size = elem_count * dtype.type_size() / dtype.block_size(); let buffer = device.allocate_zeros(size)?; Ok(Self { buffer, device: device.clone(), dtype, }) } pub fn dtype(&self) -> GgmlDType { self.dtype } pub fn device(&self) -> &MetalDevice { &self.device } pub fn buffer(&self) -> &Buffer { &self.buffer } pub fn dequantize(&self, elem_count: usize) -> Result { let buffer = self.device.new_buffer_managed(self.buffer.length())?; let command_buffer = self.device.command_buffer()?; command_buffer.set_label("to_cpu"); let blit = command_buffer.new_blit_command_encoder(); blit.set_label("blit_to_cpu"); blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); blit.end_encoding(); self.device.wait_until_completed()?; let mut out = vec![0.0; elem_count]; match self.dtype { GgmlDType::F32 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; f32::to_float(&vec, &mut out)?; } GgmlDType::F16 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; half::f16::to_float(&vec, &mut out)?; } GgmlDType::Q4_0 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; } GgmlDType::Q4_1 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; } GgmlDType::Q5_0 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; } GgmlDType::Q5_1 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; } GgmlDType::Q8_0 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; } GgmlDType::Q8_1 => { let vec: Vec = read_to_vec(&buffer, elem_count); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; } GgmlDType::Q2K => { let vec: Vec = read_to_vec(&buffer, elem_count / self.dtype.block_size()); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; } GgmlDType::Q3K => { let vec: Vec = read_to_vec(&buffer, elem_count / self.dtype.block_size()); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; } GgmlDType::Q4K => { let vec: Vec = read_to_vec(&buffer, elem_count / self.dtype.block_size()); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; } GgmlDType::Q5K => { let vec: Vec = read_to_vec(&buffer, elem_count / self.dtype.block_size()); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; } GgmlDType::Q6K => { let vec: Vec = read_to_vec(&buffer, elem_count / self.dtype.block_size()); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; } GgmlDType::Q8K => { let vec: Vec = read_to_vec(&buffer, elem_count / self.dtype.block_size()); use crate::quantized::k_quants::GgmlType; crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } } let buffer = self.device.new_buffer_with_data(&out)?; Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32)) } pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> { // Quantization only happens on CPU for now. let src = src.to_cpu::()?; let elem_count = src.len(); let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; qcpu_storage.quantize(&src)?; let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; self.buffer = buffer; Ok(()) } pub fn storage_size_in_bytes(&self) -> usize { self.buffer.length() as usize } pub fn fwd( &self, self_shape: &Shape, storage: &MetalStorage, layout: &crate::Layout, ) -> Result<(MetalStorage, Shape)> { use crate::MetalError; if !layout.is_contiguous() { crate::bail!("input tensor is not contiguous {layout:?}") } let src_shape = layout.shape(); // self is transposed so n is first then k. if src_shape.rank() < 2 { crate::bail!("input tensor has only one dimension {layout:?}") } let (n, k) = self_shape.dims2()?; let mut dst_shape = src_shape.dims().to_vec(); let (b, m) = match dst_shape.len() { 3 => (dst_shape[0], dst_shape[1]), 2 => (1, dst_shape[0]), n => crate::bail!("Invalid rank {n} for quantized matmul metal"), }; let last_k = dst_shape.pop().unwrap(); if last_k != k { crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape) } dst_shape.push(n); let dst_shape = Shape::from(dst_shape); let device = storage.device().clone(); let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; let command_buffer = device.command_buffer()?; candle_metal_kernels::call_quantized_matmul_t( device.device(), &command_buffer, device.kernels(), self.dtype.into(), (b, m, n, k), storage.buffer(), layout.start_offset() * storage.dtype().size_in_bytes(), &self.buffer, &dst, ) .map_err(MetalError::from)?; let dst_storage = crate::MetalStorage::new(dst, device, DType::F32); Ok((dst_storage, dst_shape)) } } pub fn load_quantized_metal( device: &MetalDevice, data: &[T], ) -> Result { let buffer = device.new_buffer_with_data(data)?; let device = device.clone(); Ok(QStorage::Metal(QMetalStorage { dtype: T::DTYPE, device, buffer, })) } fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; assert!(!ptr.is_null()); let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; slice.to_vec() } impl From for candle_metal_kernels::GgmlDType { fn from(value: GgmlDType) -> Self { match value { GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0, GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1, GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0, GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1, GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0, GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1, GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K, GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K, GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K, GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K, GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K, GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, } } }