mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
235 lines
9.4 KiB
Rust
235 lines
9.4 KiB
Rust
use super::{GgmlDType, QStorage};
|
|
use crate::backend::BackendStorage;
|
|
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
|
use metal::Buffer;
|
|
use std::sync::Arc;
|
|
|
|
pub struct QMetalStorage {
|
|
dtype: GgmlDType,
|
|
device: MetalDevice,
|
|
buffer: Arc<Buffer>,
|
|
}
|
|
|
|
impl QMetalStorage {
|
|
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
|
let size = elem_count * dtype.type_size() / dtype.block_size();
|
|
let buffer = device.allocate_zeros(size)?;
|
|
Ok(Self {
|
|
buffer,
|
|
device: device.clone(),
|
|
dtype,
|
|
})
|
|
}
|
|
|
|
pub fn dtype(&self) -> GgmlDType {
|
|
self.dtype
|
|
}
|
|
|
|
pub fn device(&self) -> &MetalDevice {
|
|
&self.device
|
|
}
|
|
|
|
pub fn buffer(&self) -> &Buffer {
|
|
&self.buffer
|
|
}
|
|
|
|
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
|
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
|
let command_buffer = self.device.command_buffer()?;
|
|
command_buffer.set_label("to_cpu");
|
|
let blit = command_buffer.new_blit_command_encoder();
|
|
blit.set_label("blit_to_cpu");
|
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
|
blit.end_encoding();
|
|
self.device.wait_until_completed()?;
|
|
let mut out = vec![0.0; elem_count];
|
|
match self.dtype {
|
|
GgmlDType::F32 => {
|
|
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
f32::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::F16 => {
|
|
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
half::f16::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q4_0 => {
|
|
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q4_1 => {
|
|
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q5_0 => {
|
|
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q5_1 => {
|
|
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q8_0 => {
|
|
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q8_1 => {
|
|
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q2K => {
|
|
let vec: Vec<crate::quantized::BlockQ2K> =
|
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q3K => {
|
|
let vec: Vec<crate::quantized::BlockQ3K> =
|
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q4K => {
|
|
let vec: Vec<crate::quantized::BlockQ4K> =
|
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q5K => {
|
|
let vec: Vec<crate::quantized::BlockQ5K> =
|
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q6K => {
|
|
let vec: Vec<crate::quantized::BlockQ6K> =
|
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
|
}
|
|
GgmlDType::Q8K => {
|
|
let vec: Vec<crate::quantized::BlockQ8K> =
|
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
use crate::quantized::k_quants::GgmlType;
|
|
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
|
}
|
|
}
|
|
|
|
let buffer = self.device.new_buffer_with_data(&out)?;
|
|
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
|
}
|
|
|
|
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
|
// Quantization only happens on CPU for now.
|
|
let src = src.to_cpu::<f32>()?;
|
|
let elem_count = src.len();
|
|
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
|
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
|
qcpu_storage.quantize(&src)?;
|
|
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
|
self.buffer = buffer;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn storage_size_in_bytes(&self) -> usize {
|
|
self.buffer.length() as usize
|
|
}
|
|
|
|
pub fn fwd(
|
|
&self,
|
|
self_shape: &Shape,
|
|
storage: &MetalStorage,
|
|
layout: &crate::Layout,
|
|
) -> Result<(MetalStorage, Shape)> {
|
|
use crate::MetalError;
|
|
|
|
if !layout.is_contiguous() {
|
|
crate::bail!("input tensor is not contiguous {layout:?}")
|
|
}
|
|
let src_shape = layout.shape();
|
|
// self is transposed so n is first then k.
|
|
if src_shape.rank() < 2 {
|
|
crate::bail!("input tensor has only one dimension {layout:?}")
|
|
}
|
|
let (n, k) = self_shape.dims2()?;
|
|
let mut dst_shape = src_shape.dims().to_vec();
|
|
|
|
let (b, m) = match dst_shape.len() {
|
|
3 => (dst_shape[0], dst_shape[1]),
|
|
2 => (1, dst_shape[0]),
|
|
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
|
};
|
|
let last_k = dst_shape.pop().unwrap();
|
|
if last_k != k {
|
|
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
|
}
|
|
dst_shape.push(n);
|
|
let dst_shape = Shape::from(dst_shape);
|
|
let device = storage.device().clone();
|
|
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
|
let command_buffer = device.command_buffer()?;
|
|
candle_metal_kernels::call_quantized_matmul_t(
|
|
device.device(),
|
|
&command_buffer,
|
|
device.kernels(),
|
|
self.dtype.into(),
|
|
(b, m, n, k),
|
|
storage.buffer(),
|
|
layout.start_offset() * storage.dtype().size_in_bytes(),
|
|
&self.buffer,
|
|
&dst,
|
|
)
|
|
.map_err(MetalError::from)?;
|
|
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
|
Ok((dst_storage, dst_shape))
|
|
}
|
|
}
|
|
|
|
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
|
device: &MetalDevice,
|
|
data: &[T],
|
|
) -> Result<QStorage> {
|
|
let buffer = device.new_buffer_with_data(data)?;
|
|
let device = device.clone();
|
|
Ok(QStorage::Metal(QMetalStorage {
|
|
dtype: T::DTYPE,
|
|
device,
|
|
buffer,
|
|
}))
|
|
}
|
|
|
|
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|
let ptr = buffer.contents() as *const T;
|
|
assert!(!ptr.is_null());
|
|
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
|
slice.to_vec()
|
|
}
|
|
|
|
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
|
fn from(value: GgmlDType) -> Self {
|
|
match value {
|
|
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
|
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
|
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
|
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
|
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
|
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
|
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
|
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
|
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
|
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
|
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
|
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
|
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
|
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
|
}
|
|
}
|
|
}
|