mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Qmetal tweaks (#1704)
* Add the dummy qmetal backend. * Fix the metal compilation.
This commit is contained in:
43
candle-core/src/quantized/dummy_metal.rs
Normal file
43
candle-core/src/quantized/dummy_metal.rs
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use super::GgmlDType;
|
||||||
|
use crate::{Error, MetalDevice, MetalStorage, Result};
|
||||||
|
|
||||||
|
pub struct QMetalStorage {
|
||||||
|
dtype: GgmlDType,
|
||||||
|
device: MetalDevice,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QMetalStorage {
|
||||||
|
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn device(&self) -> &MetalDevice {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fwd(
|
||||||
|
&self,
|
||||||
|
_self_shape: &crate::Shape,
|
||||||
|
_storage: &MetalStorage,
|
||||||
|
_layout: &crate::Layout,
|
||||||
|
) -> Result<(MetalStorage, crate::Shape)> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,6 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
use super::{GgmlDType, QStorage};
|
||||||
use crate::{DType, MetalDevice, MetalStorage, Result};
|
use crate::backend::BackendStorage;
|
||||||
|
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
||||||
use metal::Buffer;
|
use metal::Buffer;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -10,6 +11,16 @@ pub struct QMetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QMetalStorage {
|
impl QMetalStorage {
|
||||||
|
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||||
|
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||||
|
let buffer = device.allocate_zeros(size)?;
|
||||||
|
Ok(Self {
|
||||||
|
buffer,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
self.dtype
|
self.dtype
|
||||||
}
|
}
|
||||||
@ -22,14 +33,6 @@ impl QMetalStorage {
|
|||||||
&self.buffer
|
&self.buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
|
||||||
Self {
|
|
||||||
device,
|
|
||||||
buffer,
|
|
||||||
dtype,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
@ -134,6 +137,59 @@ impl QMetalStorage {
|
|||||||
self.buffer = buffer;
|
self.buffer = buffer;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
|
self.buffer.length() as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fwd(
|
||||||
|
&self,
|
||||||
|
self_shape: &Shape,
|
||||||
|
storage: &MetalStorage,
|
||||||
|
layout: &crate::Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
use crate::MetalError;
|
||||||
|
|
||||||
|
if !layout.is_contiguous() {
|
||||||
|
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||||
|
}
|
||||||
|
let src_shape = layout.shape();
|
||||||
|
// self is transposed so n is first then k.
|
||||||
|
if src_shape.rank() < 2 {
|
||||||
|
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||||
|
}
|
||||||
|
let (n, k) = self_shape.dims2()?;
|
||||||
|
let mut dst_shape = src_shape.dims().to_vec();
|
||||||
|
|
||||||
|
let (b, m) = match dst_shape.len() {
|
||||||
|
3 => (dst_shape[0], dst_shape[1]),
|
||||||
|
2 => (1, dst_shape[0]),
|
||||||
|
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||||
|
};
|
||||||
|
let last_k = dst_shape.pop().unwrap();
|
||||||
|
if last_k != k {
|
||||||
|
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
||||||
|
}
|
||||||
|
dst_shape.push(n);
|
||||||
|
let dst_shape = Shape::from(dst_shape);
|
||||||
|
let device = storage.device().clone();
|
||||||
|
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_quantized_matmul_t(
|
||||||
|
device.device(),
|
||||||
|
&command_buffer,
|
||||||
|
device.kernels(),
|
||||||
|
self.dtype.into(),
|
||||||
|
(b, m, n, k),
|
||||||
|
storage.buffer(),
|
||||||
|
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||||
|
&self.buffer,
|
||||||
|
&dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||||
|
Ok((dst_storage, dst_shape))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
@ -155,3 +211,24 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|||||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||||
slice.to_vec()
|
slice.to_vec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||||
|
fn from(value: GgmlDType) -> Self {
|
||||||
|
match value {
|
||||||
|
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||||
|
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||||
|
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||||
|
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||||
|
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||||
|
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||||
|
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||||
|
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||||
|
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||||
|
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||||
|
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||||
|
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||||
|
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||||
|
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
#[cfg(feature = "metal")]
|
|
||||||
use crate::{backend::BackendStorage, DType};
|
|
||||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||||
use k_quants::*;
|
use k_quants::*;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
#[cfg(target_feature = "avx")]
|
#[cfg(target_feature = "avx")]
|
||||||
pub mod avx;
|
pub mod avx;
|
||||||
|
mod dummy_metal;
|
||||||
pub mod ggml_file;
|
pub mod ggml_file;
|
||||||
pub mod gguf_file;
|
pub mod gguf_file;
|
||||||
pub mod k_quants;
|
pub mod k_quants;
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
pub mod metal;
|
pub mod metal;
|
||||||
|
#[cfg(not(feature = "metal"))]
|
||||||
|
mod metal {
|
||||||
|
pub use super::dummy_metal::*;
|
||||||
|
}
|
||||||
#[cfg(target_feature = "neon")]
|
#[cfg(target_feature = "neon")]
|
||||||
pub mod neon;
|
pub mod neon;
|
||||||
#[cfg(target_feature = "simd128")]
|
#[cfg(target_feature = "simd128")]
|
||||||
@ -32,19 +35,9 @@ impl Device {
|
|||||||
let storage = dtype.cpu_zeros(elem_count);
|
let storage = dtype.cpu_zeros(elem_count);
|
||||||
Ok(QStorage::Cpu(storage))
|
Ok(QStorage::Cpu(storage))
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Device::Metal(metal) => {
|
Device::Metal(metal) => {
|
||||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||||
let buffer = metal.allocate_zeros(size)?;
|
Ok(QStorage::Metal(storage))
|
||||||
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
|
||||||
buffer,
|
|
||||||
metal.clone(),
|
|
||||||
dtype,
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
#[cfg(not(feature = "metal"))]
|
|
||||||
Device::Metal(_metal) => {
|
|
||||||
crate::bail!("Metal feature not activated");
|
|
||||||
}
|
}
|
||||||
Device::Cuda(_cuda) => {
|
Device::Cuda(_cuda) => {
|
||||||
crate::bail!("Cuda ggml quantization not supported");
|
crate::bail!("Cuda ggml quantization not supported");
|
||||||
@ -55,7 +48,6 @@ impl Device {
|
|||||||
|
|
||||||
pub enum QStorage {
|
pub enum QStorage {
|
||||||
Cpu(Box<dyn QuantizedType>),
|
Cpu(Box<dyn QuantizedType>),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Metal(metal::QMetalStorage),
|
Metal(metal::QMetalStorage),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +55,6 @@ impl QStorage {
|
|||||||
fn block_size(&self) -> usize {
|
fn block_size(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.block_size(),
|
QStorage::Cpu(storage) => storage.block_size(),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -71,7 +62,6 @@ impl QStorage {
|
|||||||
fn dtype(&self) -> GgmlDType {
|
fn dtype(&self) -> GgmlDType {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.dtype(),
|
QStorage::Cpu(storage) => storage.dtype(),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.dtype(),
|
QStorage::Metal(storage) => storage.dtype(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -79,7 +69,6 @@ impl QStorage {
|
|||||||
fn device(&self) -> Device {
|
fn device(&self) -> Device {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(_storage) => Device::Cpu,
|
QStorage::Cpu(_storage) => Device::Cpu,
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,8 +76,7 @@ impl QStorage {
|
|||||||
fn size_in_bytes(&self) -> usize {
|
fn size_in_bytes(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||||
#[cfg(feature = "metal")]
|
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||||
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +85,6 @@ impl QStorage {
|
|||||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||||
storage.from_float(src.as_slice::<f32>()?)?;
|
storage.from_float(src.as_slice::<f32>()?)?;
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||||
}
|
}
|
||||||
@ -107,7 +94,6 @@ impl QStorage {
|
|||||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -120,7 +106,6 @@ impl QStorage {
|
|||||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||||
Ok(Cow::from(data))
|
Ok(Cow::from(data))
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(_storage) => {
|
QStorage::Metal(_storage) => {
|
||||||
crate::bail!("not implemented");
|
crate::bail!("not implemented");
|
||||||
}
|
}
|
||||||
@ -439,8 +424,7 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
#[allow(clippy::infallible_destructuring_match)]
|
#[allow(clippy::infallible_destructuring_match)]
|
||||||
let self_storage = match &self.storage {
|
let self_storage = match &self.storage {
|
||||||
QStorage::Cpu(storage) => storage,
|
QStorage::Cpu(storage) => storage,
|
||||||
#[cfg(feature = "metal")]
|
QStorage::Metal(_) => crate::bail!("Invalid storage"),
|
||||||
_ => crate::bail!("Invalid storage"),
|
|
||||||
};
|
};
|
||||||
let slice = storage.as_slice::<f32>()?;
|
let slice = storage.as_slice::<f32>()?;
|
||||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||||
@ -449,79 +433,16 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
fn metal_fwd(
|
fn metal_fwd(
|
||||||
&self,
|
&self,
|
||||||
storage: &crate::MetalStorage,
|
storage: &crate::MetalStorage,
|
||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
) -> Result<(crate::MetalStorage, Shape)> {
|
) -> Result<(crate::MetalStorage, Shape)> {
|
||||||
use crate::MetalError;
|
let self_storage = match &self.storage {
|
||||||
|
QStorage::Metal(metal) => metal,
|
||||||
if !layout.is_contiguous() {
|
|
||||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
|
||||||
}
|
|
||||||
let src_shape = layout.shape();
|
|
||||||
// self is transposed so n is first then k.
|
|
||||||
if src_shape.rank() < 2 {
|
|
||||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
|
||||||
}
|
|
||||||
let (n, k) = self.shape.dims2()?;
|
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
|
||||||
|
|
||||||
let (b, m) = match dst_shape.len() {
|
|
||||||
3 => (dst_shape[0], dst_shape[1]),
|
|
||||||
2 => (1, dst_shape[0]),
|
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
|
||||||
};
|
|
||||||
let last_k = dst_shape.pop().unwrap();
|
|
||||||
if last_k != k {
|
|
||||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
|
||||||
}
|
|
||||||
dst_shape.push(n);
|
|
||||||
let dst_shape = Shape::from(dst_shape);
|
|
||||||
let device = storage.device().clone();
|
|
||||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
|
||||||
let (buffer, dtype) = match &self.storage {
|
|
||||||
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
|
||||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||||
};
|
};
|
||||||
let command_buffer = device.command_buffer()?;
|
self_storage.fwd(&self.shape, storage, layout)
|
||||||
candle_metal_kernels::call_quantized_matmul_t(
|
|
||||||
device.device(),
|
|
||||||
&command_buffer,
|
|
||||||
device.kernels(),
|
|
||||||
dtype.into(),
|
|
||||||
(b, m, n, k),
|
|
||||||
storage.buffer(),
|
|
||||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
|
||||||
buffer,
|
|
||||||
&dst,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
|
||||||
Ok((dst_storage, dst_shape))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
|
||||||
fn from(value: GgmlDType) -> Self {
|
|
||||||
match value {
|
|
||||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
|
||||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
|
||||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
|
||||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
|
||||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
|
||||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
|
||||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
|
||||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
|
||||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
|
||||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
|
||||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
|
||||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
|
||||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
|
||||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user