Add quantized tensors. (#458)

* Add quantized tensors.
* Implement the debug trait for QTensor.
* Add the QMatMul custom op.
@@ -1,7 +1,7 @@
 //! Support for the GGML file format.

 use super::{k_quants, GgmlDType};
-use crate::{DType, Device, Result, Tensor};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};

 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
@@ -116,121 +116,47 @@ impl Vocab {
     }
 }

-fn dequantize_and_create_tensor<T: super::GgmlType>(
+fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
-    tensor_elems: usize,
     size_in_bytes: usize,
     dims: Vec<usize>,
-    device: &Device,
-) -> Result<Tensor> {
-    let mut f32_data = vec![0f32; tensor_elems];
+) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
-    let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    T::to_float(raw_data, &mut f32_data)?;
-    Tensor::from_vec(f32_data, dims, device)
+    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
+    Ok(super::QTensor::new(data.to_vec(), dims))
 }

 /// Creates a [Tensor] from a raw GGML tensor.
-pub fn tensor_from_ggml(
+pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
-    dtype: DType,
-    device: &Device,
-) -> Result<Tensor> {
+) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
     let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();

-    let tensor = match ggml_dtype {
-        GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
-        GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
-        GgmlDType::Q4_0 => dequantize_and_create_tensor::<k_quants::BlockQ4_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4_1 => dequantize_and_create_tensor::<k_quants::BlockQ4_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_0 => dequantize_and_create_tensor::<k_quants::BlockQ5_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_1 => dequantize_and_create_tensor::<k_quants::BlockQ5_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q8_0 => dequantize_and_create_tensor::<k_quants::BlockQ8_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q2K => dequantize_and_create_tensor::<k_quants::BlockQ2K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q3K => dequantize_and_create_tensor::<k_quants::BlockQ3K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4K => dequantize_and_create_tensor::<k_quants::BlockQ4K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5K => dequantize_and_create_tensor::<k_quants::BlockQ5K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q6K => dequantize_and_create_tensor::<k_quants::BlockQ6K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        _ => crate::bail!("quantized type {dtype:?} is not supported yet"),
-    }?;
-    //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type
-    if tensor.dtype() != dtype {
-        tensor.to_dtype(dtype)
-    } else {
-        Ok(tensor)
+    match ggml_dtype {
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }

 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
-    dtype: DType,
-    device: &Device,
-) -> Result<(String, Tensor)> {
+) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
     let ggml_dtype = reader.read_u32::<LittleEndian>()?;
@@ -252,26 +178,21 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
 }

 #[derive(Debug)]
 pub struct Content {
     pub magic: VersionedMagic,
     pub hparams: HParams,
     pub vocab: Vocab,
-    pub tensors: Vec<(String, Tensor)>,
+    pub tensors: Vec<(String, super::QTensor)>,
 }

 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(
-        reader: &mut R,
-        dtype: DType,
-        device: &Device,
-    ) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -281,7 +202,7 @@ impl Content {
         let mut tensors = vec![];

         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?;
+            let (name, tensor) = read_one_tensor(reader, magic)?;
             tensors.push((name, tensor))
         }
         Ok(Self {

@@ -1,10 +1,15 @@
-use crate::Result;
+use crate::{Device, Result, Shape, Tensor};

 pub mod ggml_file;
 pub mod k_quants;

 pub use k_quants::GgmlType;

+pub struct QTensor {
+    data: Box<dyn QuantizedType>,
+    shape: Shape,
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum GgmlDType {
     F32,
@@ -80,3 +85,110 @@ impl GgmlDType {
         }
     }
 }
+
+// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
+pub trait QuantizedType: Send + Sync {
+    fn dtype(&self) -> GgmlDType;
+    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
+    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+}
+
+impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
+    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
+    }
+
+    fn dtype(&self) -> GgmlDType {
+        T::DTYPE
+    }
+
+    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
+        T::to_float(self.as_slice(), ys)
+    }
+}
+
+impl std::fmt::Debug for QTensor {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
+    }
+}
+
+impl QTensor {
+    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
+        data: Vec<T>,
+        shape: S,
+    ) -> Self {
+        Self {
+            data: Box::new(data),
+            shape: shape.into(),
+        }
+    }
+
+    pub fn dtype(&self) -> GgmlDType {
+        self.data.dtype()
+    }
+
+    pub fn shape(&self) -> &Shape {
+        &self.shape
+    }
+
+    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
+        let mut f32_data = vec![0f32; self.shape.elem_count()];
+        self.data.to_float(&mut f32_data)?;
+        Tensor::from_vec(f32_data, &self.shape, device)
+    }
+
+    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+        self.data.matmul_t(mkn, lhs, dst)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct QMatMul(std::sync::Arc<QTensor>);
+
+impl QMatMul {
+    pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
+        Self(qtensor)
+    }
+}
+
+impl crate::CustomOp1 for QMatMul {
+    fn name(&self) -> &'static str {
+        "qmatmul"
+    }
+
+    fn cpu_fwd(
+        &self,
+        storage: &crate::CpuStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::CpuStorage, Shape)> {
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        let (k, n) = self.0.shape.dims2()?;
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let mut dst_shape = src_shape.dims().to_vec();
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!(
+                "input tensor {layout:?} incompatible with {:?}",
+                self.0.shape
+            )
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        let storage = storage.as_slice::<f32>()?;
+        let storage =
+            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
+        self.0.matmul_t(
+            (dst_shape.elem_count() / n, k, n),
+            storage,
+            &mut dst_storage,
+        )?;
+        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
+    }
+}
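
To illustrate the new API surface, a small usage sketch (the `apply_op1` call is an assumption about how a `CustomOp1` is invoked; the shapes are arbitrary):

```rust
use candle::quantized::{QMatMul, QTensor};
use candle::{DType, Device, Result, Tensor};
use std::sync::Arc;

fn main() -> Result<()> {
    // f32 also implements GgmlType, so it can back a QTensor directly;
    // real quantized weights would come from `qtensor_from_ggml` instead.
    let weight = QTensor::new(vec![0f32; 8], (4, 2));
    println!("{weight:?}"); // the new Debug impl: QTensor[[4, 2]; F32]

    // cpu_fwd reads the weight shape as (k, n), so a (3, 4) input yields (3, 2).
    let mm = QMatMul::new(Arc::new(weight));
    let xs = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
    let ys = xs.apply_op1(mm)?; // assumes Tensor::apply_op1 dispatches to cpu_fwd
    assert_eq!(ys.shape().dims(), &[3, 2]);
    Ok(())
}
```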

@@ -3,7 +3,6 @@ use clap::Parser;
 use std::fs::File;

 use candle::quantized::ggml_file::Content;
-use candle::{DType, Device};

 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -18,7 +17,7 @@ fn main() -> Result<()> {

     let mut file = File::open(args.model)?;
     let start = std::time::Instant::now();
-    let model = Content::read(&mut file, DType::F16, &Device::Cpu)?;
+    let model = Content::read(&mut file)?;

     println!(
         "Loaded {:?} tensors in {:?}",