From e68b2accb4f680c7a0f21be2523400a46e088a85 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 15 Aug 2023 20:26:27 +0100
Subject: [PATCH] Split out the quantized file. (#456)

---
 candle-core/src/lib.rs                        |   2 +-
 candle-core/src/quantized/ggml_file.rs        | 294 ++++++++++++++
 .../src/{ggml.rs => quantized/k_quants.rs}    | 370 +-----------------
 candle-core/src/quantized/mod.rs              |  82 ++++
 .../{ggml_tests.rs => quantized_tests.rs}     |  12 +-
 candle-examples/examples/ggml/main.rs         |   2 +-
 6 files changed, 386 insertions(+), 376 deletions(-)
 create mode 100644 candle-core/src/quantized/ggml_file.rs
 rename candle-core/src/{ggml.rs => quantized/k_quants.rs} (66%)
 create mode 100644 candle-core/src/quantized/mod.rs
 rename candle-core/tests/{ggml_tests.rs => quantized_tests.rs} (74%)

diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 3fc34302..62ad55d1 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -50,13 +50,13 @@ pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 pub mod error;
-pub mod ggml;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
 mod op;
+pub mod quantized;
 pub mod safetensors;
 pub mod shape;
 mod storage;
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
new file mode 100644
index 00000000..2824f075
--- /dev/null
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -0,0 +1,294 @@
+//! Support for the GGML file format.
+
+use super::{k_quants, GgmlDType};
+use crate::{DType, Device, Result, Tensor};
+use byteorder::{LittleEndian, ReadBytesExt};
+
+// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Magic {
+    Ggjt,
+    Ggla,
+    Ggmf,
+    Ggml,
+    Ggsn,
+}
+
+impl TryFrom<u32> for Magic {
+    type Error = crate::Error;
+    fn try_from(value: u32) -> Result<Self> {
+        let magic = match value {
+            0x67676a74 => Self::Ggjt,
+            0x67676c61 => Self::Ggla,
+            0x67676d66 => Self::Ggmf,
+            0x67676d6c => Self::Ggml,
+            0x6767736e => Self::Ggsn,
+            _ => crate::bail!("unknown magic {value:08x}"),
+        };
+        Ok(magic)
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum VersionedMagic {
+    GgmlUnversioned,
+    GgmfV1,
+    GgjtV1,
+    GgjtV2,
+    GgjtV3,
+}
+
+impl VersionedMagic {
+    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
+        let magic = reader.read_u32::<LittleEndian>()?;
+        let magic = Magic::try_from(magic)?;
+        if magic == Magic::Ggml {
+            return Ok(Self::GgmlUnversioned);
+        }
+        let version = reader.read_u32::<LittleEndian>()?;
+        let versioned_magic = match (magic, version) {
+            (Magic::Ggmf, 1) => Self::GgmfV1,
+            (Magic::Ggjt, 1) => Self::GgjtV1,
+            (Magic::Ggjt, 2) => Self::GgjtV2,
+            (Magic::Ggjt, 3) => Self::GgjtV3,
+            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
+        };
+        Ok(versioned_magic)
+    }
+
+    fn align32(&self) -> bool {
+        match self {
+            Self::GgmlUnversioned | Self::GgmfV1 => false,
+            Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct HParams {
+    pub n_vocab: u32,
+    pub n_embd: u32,
+    pub n_mult: u32,
+    pub n_head: u32,
+    pub n_layer: u32,
+    pub n_rot: u32,
+    pub ftype: u32,
+}
+
+impl HParams {
+    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
+        let n_vocab = reader.read_u32::<LittleEndian>()?;
+        let n_embd = reader.read_u32::<LittleEndian>()?;
+        let n_mult = reader.read_u32::<LittleEndian>()?;
+        let n_head = reader.read_u32::<LittleEndian>()?;
+        let n_layer = reader.read_u32::<LittleEndian>()?;
+        let n_rot = reader.read_u32::<LittleEndian>()?;
+        let ftype = reader.read_u32::<LittleEndian>()?;
+        Ok(Self {
+            n_vocab,
+            n_embd,
+            n_mult,
+            n_head,
+            n_layer,
+            n_rot,
+            ftype,
+        })
+ } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Vocab { + pub token_score_pairs: Vec<(Vec, f32)>, +} + +impl Vocab { + fn read(reader: &mut R, n_vocab: usize) -> Result { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556 + let mut token_score_pairs = Vec::with_capacity(n_vocab); + for _index in 0..n_vocab { + let len = reader.read_u32::()? as usize; + let mut word = vec![0u8; len]; + reader.read_exact(&mut word)?; + let score = reader.read_f32::()?; + token_score_pairs.push((word, score)) + } + Ok(Self { token_score_pairs }) + } +} + +fn dequantize_and_create_tensor( + raw_data: &[u8], + tensor_elems: usize, + size_in_bytes: usize, + dims: Vec, + device: &Device, +) -> Result { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::(); + let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; + T::to_float(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device) +} + +/// Creates a [Tensor] from a raw GGML tensor. +pub fn tensor_from_ggml( + ggml_dtype: GgmlDType, + raw_data: &[u8], + dims: Vec, + dtype: DType, + device: &Device, +) -> Result { + let tensor_elems = dims.iter().product::(); + let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size(); + + let tensor = match ggml_dtype { + GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device), + GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device), + GgmlDType::Q4_0 => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q4_1 => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q5_0 => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q5_1 => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q8_0 => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q2K => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q3K => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q4K => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q5K => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + GgmlDType::Q6K => dequantize_and_create_tensor::( + raw_data, + tensor_elems, + size_in_bytes, + dims, + device, + ), + _ => crate::bail!("quantized type {dtype:?} is not supported yet"), + }?; + //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type + if tensor.dtype() != dtype { + tensor.to_dtype(dtype) + } else { + Ok(tensor) + } +} + +fn read_one_tensor( + reader: &mut R, + magic: VersionedMagic, + dtype: DType, + device: &Device, +) -> Result<(String, Tensor)> { + let n_dims = reader.read_u32::()?; + let name_len = reader.read_u32::()?; + let ggml_dtype = reader.read_u32::()?; + let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?; + let mut dims = vec![0u32; n_dims as usize]; + reader.read_u32_into::(&mut dims)?; + let mut name = vec![0u8; name_len as usize]; + reader.read_exact(&mut name)?; + let name = 
+
+    if magic.align32() {
+        let pos = reader.stream_position()?;
+        reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
+    }
+    let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
+    let tensor_elems = dims.iter().product::<usize>();
+    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+    println!("{name} {ggml_dtype:?} {dims:?}");
+    // TODO: Mmap version to avoid copying the data around?
+    let mut raw_data = vec![0u8; size_in_bytes];
+    reader.read_exact(&mut raw_data)?;
+    match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) {
+        Ok(tensor) => Ok((name, tensor)),
+        Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
+    }
+}
+
+#[derive(Debug)]
+pub struct Content {
+    pub magic: VersionedMagic,
+    pub hparams: HParams,
+    pub vocab: Vocab,
+    pub tensors: Vec<(String, Tensor)>,
+}
+
+impl Content {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Content> {
+        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
+        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
+        reader.seek(std::io::SeekFrom::Start(0))?;
+        let magic = VersionedMagic::read(reader)?;
+        let hparams = HParams::read(reader)?;
+        let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
+        let mut tensors = vec![];
+
+        while reader.stream_position()? != last_position {
+            let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?;
+            tensors.push((name, tensor))
+        }
+        Ok(Self {
+            magic,
+            hparams,
+            vocab,
+            tensors,
+        })
+    }
+}
diff --git a/candle-core/src/ggml.rs b/candle-core/src/quantized/k_quants.rs
similarity index 66%
rename from candle-core/src/ggml.rs
rename to candle-core/src/quantized/k_quants.rs
index 3a41eeec..2b88d3f1 100644
--- a/candle-core/src/ggml.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1,7 +1,5 @@
-//! Support for the GGML file format.
-
-use crate::{DType, Device, Result, Tensor};
-use byteorder::{LittleEndian, ReadBytesExt};
+use super::GgmlDType;
+use crate::Result;
 use half::f16;
 
 // Default to QK_K 256 rather than 64.
@@ -728,367 +726,3 @@ pub fn matmul<T: GgmlType>(
     }
     Ok(())
 }
-
-// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum Magic {
-    Ggjt,
-    Ggla,
-    Ggmf,
-    Ggml,
-    Ggsn,
-}
-
-impl TryFrom<u32> for Magic {
-    type Error = crate::Error;
-    fn try_from(value: u32) -> Result<Self> {
-        let magic = match value {
-            0x67676a74 => Self::Ggjt,
-            0x67676c61 => Self::Ggla,
-            0x67676d66 => Self::Ggmf,
-            0x67676d6c => Self::Ggml,
-            0x6767736e => Self::Ggsn,
-            _ => crate::bail!("unknown magic {value:08x}"),
-        };
-        Ok(magic)
-    }
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum VersionedMagic {
-    GgmlUnversioned,
-    GgmfV1,
-    GgjtV1,
-    GgjtV2,
-    GgjtV3,
-}
-
-impl VersionedMagic {
-    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
-        let magic = reader.read_u32::<LittleEndian>()?;
-        let magic = Magic::try_from(magic)?;
-        if magic == Magic::Ggml {
-            return Ok(Self::GgmlUnversioned);
-        }
-        let version = reader.read_u32::<LittleEndian>()?;
-        let versioned_magic = match (magic, version) {
-            (Magic::Ggmf, 1) => Self::GgmfV1,
-            (Magic::Ggjt, 1) => Self::GgjtV1,
-            (Magic::Ggjt, 2) => Self::GgjtV2,
-            (Magic::Ggjt, 3) => Self::GgjtV3,
-            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
-        };
-        Ok(versioned_magic)
-    }
-
-    fn align32(&self) -> bool {
-        match self {
-            Self::GgmlUnversioned | Self::GgmfV1 => false,
-            Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
-        }
-    }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct HParams {
-    pub n_vocab: u32,
-    pub n_embd: u32,
-    pub n_mult: u32,
-    pub n_head: u32,
-    pub n_layer: u32,
-    pub n_rot: u32,
-    pub ftype: u32,
-}
-
-impl HParams {
-    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
-        let n_vocab = reader.read_u32::<LittleEndian>()?;
-        let n_embd = reader.read_u32::<LittleEndian>()?;
-        let n_mult = reader.read_u32::<LittleEndian>()?;
-        let n_head = reader.read_u32::<LittleEndian>()?;
-        let n_layer = reader.read_u32::<LittleEndian>()?;
-        let n_rot = reader.read_u32::<LittleEndian>()?;
-        let ftype = reader.read_u32::<LittleEndian>()?;
-        Ok(Self {
-            n_vocab,
-            n_embd,
-            n_mult,
-            n_head,
-            n_layer,
-            n_rot,
-            ftype,
-        })
-    }
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct Vocab {
-    pub token_score_pairs: Vec<(Vec<u8>, f32)>,
-}
-
-impl Vocab {
-    fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
-        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
-        let mut token_score_pairs = Vec::with_capacity(n_vocab);
-        for _index in 0..n_vocab {
-            let len = reader.read_u32::<LittleEndian>()? as usize;
-            let mut word = vec![0u8; len];
-            reader.read_exact(&mut word)?;
-            let score = reader.read_f32::<LittleEndian>()?;
-            token_score_pairs.push((word, score))
-        }
-        Ok(Self { token_score_pairs })
-    }
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum GgmlDType {
-    F32,
-    F16,
-    Q4_0,
-    Q4_1,
-    Q5_0,
-    Q5_1,
-    Q8_0,
-    Q8_1,
-    Q2K,
-    Q3K,
-    Q4K,
-    Q5K,
-    Q6K,
-    Q8K,
-}
-
-impl GgmlDType {
-    fn from_u32(u: u32) -> Result<Self> {
-        let dtype = match u {
-            0 => Self::F32,
-            1 => Self::F16,
-            2 => Self::Q4_0,
-            3 => Self::Q4_1,
-            6 => Self::Q5_0,
-            7 => Self::Q5_1,
-            8 => Self::Q8_0,
-            9 => Self::Q8_1,
-            10 => Self::Q2K,
-            11 => Self::Q3K,
-            12 => Self::Q4K,
-            13 => Self::Q5K,
-            14 => Self::Q6K,
-            15 => Self::Q8K,
-            _ => crate::bail!("unknown dtype for tensor {u}"),
-        };
-        Ok(dtype)
-    }
-
-    fn type_size(&self) -> usize {
-        match self {
-            Self::F32 => 4,
-            Self::F16 => 2,
-            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
-            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
-            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
-            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
-            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
-            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
-            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
-            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
-            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
-            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
-            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
-            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
-            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
-        }
-    }
-
-    fn blck_size(&self) -> usize {
-        match self {
-            Self::F32 => 1,
-            Self::F16 => 1,
-            Self::Q4_0 => QK4_0,
-            Self::Q4_1 => QK4_1,
-            Self::Q5_0 => QK5_0,
-            Self::Q5_1 => QK5_1,
-            Self::Q8_0 => QK8_0,
-            Self::Q8_1 => QK8_1,
-            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => QK_K,
-        }
-    }
-}
-
-fn dequantize_and_create_tensor<T: GgmlType>(
-    raw_data: &[u8],
-    tensor_elems: usize,
-    size_in_bytes: usize,
-    dims: Vec<usize>,
-    device: &Device,
-) -> Result<Tensor> {
-    let mut f32_data = vec![0f32; tensor_elems];
-    let raw_data_ptr = raw_data.as_ptr();
-    let n_blocks = size_in_bytes / std::mem::size_of::<T>();
-    let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    T::to_float(raw_data, &mut f32_data)?;
-    Tensor::from_vec(f32_data, dims, device)
-}
-
-/// Creates a [Tensor] from a raw GGML tensor.
-pub fn tensor_from_ggml(
-    ggml_dtype: GgmlDType,
-    raw_data: &[u8],
-    dims: Vec<usize>,
-    dtype: DType,
-    device: &Device,
-) -> Result<Tensor> {
-    let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
-
-    let tensor = match ggml_dtype {
-        GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
-        GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
-        GgmlDType::Q4_0 => dequantize_and_create_tensor::<BlockQ4_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4_1 => dequantize_and_create_tensor::<BlockQ4_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_0 => dequantize_and_create_tensor::<BlockQ5_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5_1 => dequantize_and_create_tensor::<BlockQ5_1>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q8_0 => dequantize_and_create_tensor::<BlockQ8_0>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q2K => dequantize_and_create_tensor::<BlockQ2K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q3K => dequantize_and_create_tensor::<BlockQ3K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q4K => dequantize_and_create_tensor::<BlockQ4K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q5K => dequantize_and_create_tensor::<BlockQ5K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        GgmlDType::Q6K => dequantize_and_create_tensor::<BlockQ6K>(
-            raw_data,
-            tensor_elems,
-            size_in_bytes,
-            dims,
-            device,
-        ),
-        _ => crate::bail!("quantized type {dtype:?} is not supported yet"),
-    }?;
-    //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type
-    if tensor.dtype() != dtype {
-        tensor.to_dtype(dtype)
-    } else {
-        Ok(tensor)
-    }
-}
-
-fn read_one_tensor<R: std::io::Seek + std::io::Read>(
-    reader: &mut R,
-    magic: VersionedMagic,
-    dtype: DType,
-    device: &Device,
-) -> Result<(String, Tensor)> {
-    let n_dims = reader.read_u32::<LittleEndian>()?;
-    let name_len = reader.read_u32::<LittleEndian>()?;
-    let ggml_dtype = reader.read_u32::<LittleEndian>()?;
-    let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
-    let mut dims = vec![0u32; n_dims as usize];
-    reader.read_u32_into::<LittleEndian>(&mut dims)?;
-    let mut name = vec![0u8; name_len as usize];
-    reader.read_exact(&mut name)?;
-    let name = String::from_utf8_lossy(&name).into_owned();
-
-    if magic.align32() {
-        let pos = reader.stream_position()?;
-        reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
-    }
-    let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
-    let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
-    println!("{name} {ggml_dtype:?} {dims:?}");
-    // TODO: Mmap version to avoid copying the data around?
-    let mut raw_data = vec![0u8; size_in_bytes];
-    reader.read_exact(&mut raw_data)?;
-    match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) {
-        Ok(tensor) => Ok((name, tensor)),
-        Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
-    }
-}
-
-#[derive(Debug)]
-pub struct Content {
-    pub magic: VersionedMagic,
-    pub hparams: HParams,
-    pub vocab: Vocab,
-    pub tensors: Vec<(String, Tensor)>,
-}
-
-impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(
-        reader: &mut R,
-        dtype: DType,
-        device: &Device,
-    ) -> Result<Content> {
-        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
-        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
-        reader.seek(std::io::SeekFrom::Start(0))?;
-        let magic = VersionedMagic::read(reader)?;
-        let hparams = HParams::read(reader)?;
-        let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
-        let mut tensors = vec![];
-
-        while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?;
-            tensors.push((name, tensor))
-        }
-        Ok(Self {
-            magic,
-            hparams,
-            vocab,
-            tensors,
-        })
-    }
-}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
new file mode 100644
index 00000000..c7e24592
--- /dev/null
+++ b/candle-core/src/quantized/mod.rs
@@ -0,0 +1,82 @@
+use crate::Result;
+
+pub mod ggml_file;
+pub mod k_quants;
+
+pub use k_quants::GgmlType;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum GgmlDType {
+    F32,
+    F16,
+    Q4_0,
+    Q4_1,
+    Q5_0,
+    Q5_1,
+    Q8_0,
+    Q8_1,
+    Q2K,
+    Q3K,
+    Q4K,
+    Q5K,
+    Q6K,
+    Q8K,
+}
+
+impl GgmlDType {
+    pub(crate) fn from_u32(u: u32) -> Result<Self> {
+        let dtype = match u {
+            0 => Self::F32,
+            1 => Self::F16,
+            2 => Self::Q4_0,
+            3 => Self::Q4_1,
+            6 => Self::Q5_0,
+            7 => Self::Q5_1,
+            8 => Self::Q8_0,
+            9 => Self::Q8_1,
+            10 => Self::Q2K,
+            11 => Self::Q3K,
+            12 => Self::Q4K,
+            13 => Self::Q5K,
+            14 => Self::Q6K,
+            15 => Self::Q8K,
+            _ => crate::bail!("unknown dtype for tensor {u}"),
+        };
+        Ok(dtype)
+    }
+
+    fn type_size(&self) -> usize {
+        use k_quants::*;
+        match self {
+            Self::F32 => 4,
+            Self::F16 => 2,
+            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
+            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
+            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
+            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
+            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
+            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
+            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
+            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
+            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
+            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
+            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
+            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
+            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
+        }
+    }
+
+    fn blck_size(&self) -> usize {
+        match self {
+            Self::F32 => 1,
+            Self::F16 => 1,
+            Self::Q4_0 => k_quants::QK4_0,
+            Self::Q4_1 => k_quants::QK4_1,
+            Self::Q5_0 => k_quants::QK5_0,
+            Self::Q5_1 => k_quants::QK5_1,
+            Self::Q8_0 => k_quants::QK8_0,
+            Self::Q8_1 => k_quants::QK8_1,
+            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
+        }
+    }
+}
diff --git a/candle-core/tests/ggml_tests.rs b/candle-core/tests/quantized_tests.rs
similarity index 74%
rename from candle-core/tests/ggml_tests.rs
rename to candle-core/tests/quantized_tests.rs
index d976ad99..8b4ddcd9 100644
--- a/candle-core/tests/ggml_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,18 +1,18 @@
-use candle_core::{ggml, Device, Result, Tensor};
-use ggml::GgmlType;
+use candle_core::{quantized, Device, Result, Tensor};
+use quantized::{k_quants, GgmlType};
 
 #[test]
-fn ggml_matmul() -> Result<()> {
+fn quantized_matmul() -> Result<()> {
     let cpu = &Device::Cpu;
     let (m, k, n) = (3, 64, 4);
     let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
     let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
     let mut dst = vec![42.; 3 * 4];
-    let mut rhs_t = vec![ggml::BlockQ4_0::zeros(); 8];
+    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
     let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
-    ggml::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
-    ggml::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
+    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
     assert_eq!(
         dst,
         &[
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index c3fc6b9e..75a7334c 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use clap::Parser;
 use std::fs::File;
 
-use candle::ggml::Content;
+use candle::quantized::ggml_file::Content;
 use candle::{DType, Device};
 
 #[derive(Parser, Debug)]
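
After this patch, the GGML loader is reached through candle_core::quantized rather than the old candle_core::ggml module. A minimal usage sketch of the relocated API (hypothetical standalone snippet; the Content::read signature and module paths are the ones introduced by the diff above):

    use candle_core::quantized::ggml_file::Content;
    use candle_core::{DType, Device, Result};

    // Read a GGML file, dequantizing every tensor to f32 on the CPU.
    fn inspect_ggml(file: &mut std::fs::File) -> Result<()> {
        let content = Content::read(file, DType::F32, &Device::Cpu)?;
        println!("hparams: {:?}", content.hparams);
        println!("loaded {} tensors", content.tensors.len());
        Ok(())
    }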