mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Quantized GGUF style (#1523)
* Metal quantized modifications proposal. - Add a device param, wherever needed. - Create new QMetal storage thing that implements QuantizedType. - Update everywhere needed. Fix Python. Fixing examples. Fix: fmt + clippy + stub. Moving everything around. Only missing the actual implems. Fixing everything + adding dequantized kernels. More work. Fixing matmul. Fmt + Clippy Some clippy fixes. Working state. Q2K Metal -> Bugged (also present in GGML). Q4K CPU -> Bugged (present previously, new test catch it). Q5K CPU -> Bugged (present previously). Q8_1 Both -> Never really implemented it seems Q8K metal -> Never implemented in metal Fixing Q2K bug (present in ggml). * Cleanup. * Fix the rebase. * Removing the fences speeds everything up and *is* correct this time... * Cleanup the fence. * After rebase. * Bad code removal. * Rebase after phi2 merge + fix replit default to CPU. * Making the CI happy. * More happy tests. --------- Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use candle_core::quantized::{gguf_file, k_quants, QTensor};
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||
use candle_core::{Device, Result};
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use rayon::prelude::*;
|
||||
|
||||
@ -11,12 +11,7 @@ enum QuantizationMode {
|
||||
}
|
||||
|
||||
impl QuantizationMode {
|
||||
fn quantize(
|
||||
&self,
|
||||
name: &str,
|
||||
tensor: QTensor,
|
||||
default: fn(&Tensor) -> Result<QTensor>,
|
||||
) -> Result<QTensor> {
|
||||
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
|
||||
match self {
|
||||
Self::Llama => {
|
||||
// Same behavior as the llama.cpp quantization.
|
||||
@ -24,9 +19,9 @@ impl QuantizationMode {
|
||||
if should_quantize {
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
if name == "output.weight" {
|
||||
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
|
||||
QTensor::quantize(&tensor, GgmlDType::Q6K)
|
||||
} else {
|
||||
default(&tensor)
|
||||
QTensor::quantize(&tensor, dtype)
|
||||
}
|
||||
} else {
|
||||
Ok(tensor)
|
||||
@ -60,6 +55,27 @@ enum Quantization {
|
||||
F32,
|
||||
}
|
||||
|
||||
impl Quantization {
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
Quantization::Q4_0 => GgmlDType::Q4_0,
|
||||
Quantization::Q4_1 => GgmlDType::Q4_1,
|
||||
Quantization::Q5_0 => GgmlDType::Q5_0,
|
||||
Quantization::Q5_1 => GgmlDType::Q5_1,
|
||||
Quantization::Q8_0 => GgmlDType::Q8_0,
|
||||
Quantization::Q8_1 => GgmlDType::Q8_1,
|
||||
Quantization::Q2k => GgmlDType::Q2K,
|
||||
Quantization::Q3k => GgmlDType::Q3K,
|
||||
Quantization::Q4k => GgmlDType::Q4K,
|
||||
Quantization::Q5k => GgmlDType::Q5K,
|
||||
Quantization::Q6k => GgmlDType::Q6K,
|
||||
Quantization::Q8k => GgmlDType::Q8K,
|
||||
Quantization::F16 => GgmlDType::F16,
|
||||
Quantization::F32 => GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Format {
|
||||
Safetensors,
|
||||
@ -134,7 +150,12 @@ struct Args {
|
||||
command: Command,
|
||||
}
|
||||
|
||||
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
||||
fn run_ls(
|
||||
file: &std::path::PathBuf,
|
||||
format: Option<Format>,
|
||||
verbose: bool,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
None => match Format::infer(file) {
|
||||
@ -200,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, qtensor) in tensors.iter() {
|
||||
@ -241,37 +262,8 @@ fn run_quantize_safetensors(
|
||||
}
|
||||
println!("tensors: {}", tensors.len());
|
||||
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
let block_size = match q {
|
||||
Quantization::Q4_0 => k_quants::QK4_0,
|
||||
Quantization::Q4_1 => k_quants::QK4_1,
|
||||
Quantization::Q5_0 => k_quants::QK5_0,
|
||||
Quantization::Q5_1 => k_quants::QK5_1,
|
||||
Quantization::Q8_0 => k_quants::QK8_0,
|
||||
Quantization::Q8_1 => k_quants::QK8_1,
|
||||
Quantization::Q2k
|
||||
| Quantization::Q3k
|
||||
| Quantization::Q4k
|
||||
| Quantization::Q5k
|
||||
| Quantization::Q6k
|
||||
| Quantization::Q8k => k_quants::QK_K,
|
||||
Quantization::F16 | Quantization::F32 => 1,
|
||||
};
|
||||
let dtype = q.dtype();
|
||||
let block_size = dtype.block_size();
|
||||
|
||||
let qtensors = tensors
|
||||
.into_par_iter()
|
||||
@ -279,9 +271,9 @@ fn run_quantize_safetensors(
|
||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||
let tensor = if should_quantize {
|
||||
quantize_fn(&tensor)?
|
||||
QTensor::quantize(&tensor, dtype)?
|
||||
} else {
|
||||
QTensor::quantize::<f32>(&tensor)?
|
||||
QTensor::quantize(&tensor, GgmlDType::F32)?
|
||||
};
|
||||
Ok((name, tensor))
|
||||
})
|
||||
@ -294,13 +286,17 @@ fn run_quantize_safetensors(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
|
||||
fn run_dequantize(
|
||||
in_file: std::path::PathBuf,
|
||||
out_file: std::path::PathBuf,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let mut in_file = std::fs::File::open(in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||
let tensor = content.tensor(&mut in_file, tensor_name)?;
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
@ -312,6 +308,7 @@ fn run_quantize(
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
qmode: QuantizationMode,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if in_files.is_empty() {
|
||||
candle_core::bail!("no specified input files")
|
||||
@ -337,31 +334,15 @@ fn run_quantize(
|
||||
let content = gguf_file::Content::read(&mut in_)?;
|
||||
println!("tensors: {}", content.tensor_infos.len());
|
||||
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
|
||||
let dtype = q.dtype();
|
||||
let qtensors = content
|
||||
.tensor_infos
|
||||
.par_iter()
|
||||
.map(|(name, _)| {
|
||||
println!(" quantizing {name}");
|
||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||
let tensor = content.tensor(&mut in_file, name)?;
|
||||
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
||||
let tensor = content.tensor(&mut in_file, name, device)?;
|
||||
let tensor = qmode.quantize(name, tensor, dtype)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
@ -381,6 +362,7 @@ fn run_quantize(
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = Device::Cpu;
|
||||
match args.command {
|
||||
Command::Ls {
|
||||
files,
|
||||
@ -392,7 +374,7 @@ fn main() -> anyhow::Result<()> {
|
||||
if multiple_files {
|
||||
println!("--- {file:?} ---");
|
||||
}
|
||||
run_ls(file, format.clone(), verbose)?
|
||||
run_ls(file, format.clone(), verbose, &device)?
|
||||
}
|
||||
}
|
||||
Command::Quantize {
|
||||
@ -400,8 +382,8 @@ fn main() -> anyhow::Result<()> {
|
||||
out_file,
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -84,13 +84,8 @@ pub struct MetalDevice {
|
||||
command_buffer_index: Arc<RwLock<usize>>,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
/// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
|
||||
/// execution order to be linear.
|
||||
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
|
||||
/// compute graph.
|
||||
fence: metal::Fence,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`], both fences need to match
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
@ -221,10 +216,8 @@ impl MetalDevice {
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
|
||||
// This is necessary, for mmaped safetensors
|
||||
@ -238,6 +231,27 @@ impl MetalDevice {
|
||||
Ok(real)
|
||||
}
|
||||
|
||||
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
||||
let buffer = self.allocate_buffer(
|
||||
size_in_bytes as NSUInteger,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
"allocate_zeros",
|
||||
)?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
location: 0,
|
||||
length: buffer.length(),
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.end_encoding();
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// The critical allocator algorithm
|
||||
fn allocate_buffer(
|
||||
&self,
|
||||
@ -308,35 +322,14 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
let length = self.buffer.length() as usize;
|
||||
let size = self.dtype.size_in_bytes();
|
||||
if length % size != 0 {
|
||||
crate::bail!(
|
||||
"The Metal buffer length is not aligned with dtype {:?}",
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
{
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.wait_for_fence(&self.device.fence);
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.update_fence(&self.device.fence);
|
||||
blit.end_encoding();
|
||||
}
|
||||
self.device.wait_until_completed()?;
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
||||
DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
||||
DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
|
||||
DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
|
||||
DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
|
||||
DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
|
||||
DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
|
||||
DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1264,7 +1257,7 @@ impl BackendStorage for MetalStorage {
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let src_shape = src_l.shape();
|
||||
@ -1521,6 +1514,28 @@ impl MetalStorage {
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
|
||||
let length = self.buffer.length() as usize;
|
||||
let size = self.dtype.size_in_bytes();
|
||||
if length % size != 0 {
|
||||
crate::bail!(
|
||||
"The Metal buffer length is not aligned with dtype {:?}",
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
{
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
}
|
||||
self.device.wait_until_completed()?;
|
||||
Ok(read_to_vec(&buffer, length / size))
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
@ -1533,16 +1548,14 @@ impl BackendDevice for MetalDevice {
|
||||
command_buffer.enqueue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||
let fence = device.new_fence();
|
||||
let kernels = Arc::new(Kernels::new(fence.clone()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 20,
|
||||
_ => 10,
|
||||
};
|
||||
Ok(Self {
|
||||
device,
|
||||
fence,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
command_buffer_index,
|
||||
@ -1567,21 +1580,8 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
location: 0,
|
||||
length: buffer.length(),
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
let size = shape.elem_count() * dtype.size_in_bytes();
|
||||
let buffer = self.allocate_zeros(size)?;
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
#[cfg(feature = "metal")]
|
||||
use super::metal::load_quantized_metal;
|
||||
use super::{k_quants, GgmlDType, QStorage};
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -121,11 +123,22 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
let data: QStorage = match device {
|
||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal backend requires `metal` feature")
|
||||
}
|
||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
||||
};
|
||||
super::QTensor::new(data, dims)
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
@ -133,29 +146,50 @@ pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let blck_size = ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
let block_size = ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::Q4_0 => {
|
||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
@ -163,6 +197,7 @@ pub fn qtensor_from_ggml(
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
device: &Device,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
@ -183,11 +218,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
}
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
@ -201,7 +236,10 @@ pub struct Content {
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
@ -211,7 +249,7 @@ impl Content {
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
tensors.insert(name, tensor);
|
||||
}
|
||||
Ok(Self {
|
||||
|
@ -3,7 +3,7 @@
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::Result;
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -59,19 +59,25 @@ impl TensorInfo {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let blck_size = self.ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
let block_size = self.ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||
super::ggml_file::qtensor_from_ggml(
|
||||
self.ggml_dtype,
|
||||
&raw_data,
|
||||
self.shape.dims().to_vec(),
|
||||
device,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -460,12 +466,13 @@ impl Content {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor info for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
||||
}
|
||||
}
|
||||
|
||||
@ -517,10 +524,9 @@ pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data_ptr = tensor.as_ptr();
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
w.write_all(data)?;
|
||||
let data = tensor.data()?;
|
||||
let size_in_bytes = data.len();
|
||||
w.write_all(&data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
|
153
candle-core/src/quantized/metal.rs
Normal file
153
candle-core/src/quantized/metal.rs
Normal file
@ -0,0 +1,153 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result};
|
||||
use metal::Buffer;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
buffer: Arc<Buffer>,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
||||
Self {
|
||||
device,
|
||||
buffer,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
f32::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::F16 => {
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
half::f16::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
let vec: Vec<crate::quantized::BlockQ2K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
let vec: Vec<crate::quantized::BlockQ3K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let vec: Vec<crate::quantized::BlockQ4K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
let vec: Vec<crate::quantized::BlockQ5K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let vec: Vec<crate::quantized::BlockQ6K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8K => {
|
||||
let vec: Vec<crate::quantized::BlockQ8K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||
// Quantization only happens on CPU for now.
|
||||
let src = src.to_cpu::<f32>()?;
|
||||
let elem_count = src.len();
|
||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
||||
qcpu_storage.quantize(&src)?;
|
||||
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
||||
self.buffer = buffer;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &MetalDevice,
|
||||
data: &[T],
|
||||
) -> Result<QStorage> {
|
||||
let buffer = device.new_buffer_with_data(data)?;
|
||||
let device = device.clone();
|
||||
Ok(QStorage::Metal(QMetalStorage {
|
||||
dtype: T::DTYPE,
|
||||
device,
|
||||
buffer,
|
||||
}))
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
@ -1,23 +1,125 @@
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
#[cfg(feature = "metal")]
|
||||
use crate::{backend::BackendStorage, DType};
|
||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal;
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
pub mod utils;
|
||||
use half::f16;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
data: Box<dyn QuantizedType>,
|
||||
storage: QStorage,
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = dtype.cpu_zeros(elem_count);
|
||||
Ok(QStorage::Cpu(storage))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = metal.allocate_zeros(size)?;
|
||||
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
||||
buffer,
|
||||
metal.clone(),
|
||||
dtype,
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal feature not activated");
|
||||
}
|
||||
Device::Cuda(_cuda) => {
|
||||
crate::bail!("Cuda ggml quantization not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
#[cfg(feature = "metal")]
|
||||
Metal(metal::QMetalStorage),
|
||||
}
|
||||
|
||||
impl QStorage {
|
||||
fn block_size(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(&mut self, src: &Storage) -> Result<()> {
|
||||
match (self, src) {
|
||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> Result<Cow<[u8]>> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => {
|
||||
let data_ptr = storage.as_ptr();
|
||||
let size_in_bytes = storage.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(_storage) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
@ -77,6 +179,25 @@ impl GgmlDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// The block dtype
|
||||
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
|
||||
match self {
|
||||
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
|
||||
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
|
||||
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
|
||||
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
|
||||
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
|
||||
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
|
||||
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
|
||||
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
|
||||
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
|
||||
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
|
||||
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
|
||||
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
|
||||
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
|
||||
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
|
||||
}
|
||||
}
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
@ -100,7 +221,7 @@ impl GgmlDType {
|
||||
}
|
||||
|
||||
/// The block size, i.e. the number of elements stored in each block.
|
||||
pub fn blck_size(&self) -> usize {
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
@ -119,9 +240,13 @@ impl GgmlDType {
|
||||
pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
fn block_size(&self) -> usize;
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
|
||||
fn size(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -129,12 +254,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.len() * core::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
|
||||
T::from_float(xs, self)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
T::DTYPE
|
||||
}
|
||||
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
fn block_size(&self) -> usize {
|
||||
T::BLCK_SIZE
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
|
||||
let mut ys = vec![0.0f32; elem_count];
|
||||
T::to_float(self.as_slice(), &mut ys)?;
|
||||
Ok(CpuStorage::F32(ys))
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
|
||||
let dims = shape.dims();
|
||||
if dims.is_empty() {
|
||||
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||
}
|
||||
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||
if dims[dims.len() - 1] % block_size != 0 {
|
||||
crate::bail!(
|
||||
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||
T::BLCK_SIZE
|
||||
block_size
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Result<Self> {
|
||||
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape,
|
||||
})
|
||||
check_shape(&shape, storage.block_size())?;
|
||||
Ok(Self { storage, shape })
|
||||
}
|
||||
|
||||
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
|
||||
let shape = src.shape();
|
||||
check_shape::<T>(shape)?;
|
||||
let src = src
|
||||
.to_dtype(crate::DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
if src.len() % T::BLCK_SIZE != 0 {
|
||||
let block_size = dtype.block_size();
|
||||
check_shape(shape, block_size)?;
|
||||
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
|
||||
let elem_count = shape.elem_count();
|
||||
if elem_count % block_size != 0 {
|
||||
crate::bail!(
|
||||
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||
T::BLCK_SIZE
|
||||
block_size
|
||||
)
|
||||
}
|
||||
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(&src, &mut data)?;
|
||||
let mut storage = src.device().qzeros(elem_count, dtype)?;
|
||||
storage.quantize(&src.storage())?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.data.dtype()
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
@ -213,21 +345,19 @@ impl QTensor {
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||
self.data.to_float(&mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, &self.shape, device)
|
||||
}
|
||||
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
let is_variable = false;
|
||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
||||
.to_device(device)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.storage_size_in_bytes()
|
||||
self.storage.size_in_bytes()
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.data.as_ptr()
|
||||
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
|
||||
self.storage.data()
|
||||
}
|
||||
}
|
||||
|
||||
@ -294,17 +424,93 @@ impl crate::CustomOp1 for QTensor {
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let storage = storage.as_slice::<f32>()?;
|
||||
let storage =
|
||||
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cpu(storage) => storage,
|
||||
#[cfg(feature = "metal")]
|
||||
_ => crate::bail!("Invalid storage"),
|
||||
};
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
self.matmul_t(
|
||||
(dst_shape.elem_count() / n, k, n),
|
||||
storage,
|
||||
&mut dst_storage,
|
||||
)?;
|
||||
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
use crate::MetalError;
|
||||
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let (n, k) = self.shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (dst_shape[0], dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let (buffer, dtype) = match &self.storage {
|
||||
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||
};
|
||||
let command_buffer = device.command_buffer()?;
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||
fn from(value: GgmlDType) -> Self {
|
||||
match value {
|
||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for QMatMul {
|
||||
|
@ -1,6 +1,7 @@
|
||||
use candle_core::{
|
||||
bail,
|
||||
quantized::{self, GgmlDType},
|
||||
test_device,
|
||||
test_utils::to_vec2_round,
|
||||
Device, Module, Result, Tensor,
|
||||
};
|
||||
@ -14,16 +15,48 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
|
||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
|
||||
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
fn test_matmul(
|
||||
device: &Device,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
dtype: GgmlDType,
|
||||
) -> Result<()> {
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 / (m * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..(k * n))
|
||||
.map(|v| v as f32 / (n * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
|
||||
let mm = lhs.matmul(&rhs)?;
|
||||
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&lhs)?;
|
||||
|
||||
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
let error = error / (b * m * n) as f32;
|
||||
assert!(
|
||||
error <= 0.02,
|
||||
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
@ -33,6 +66,7 @@ fn quantized_matmul() -> Result<()> {
|
||||
341876.0, 994283.0, 1655709.0, 2301518.0
|
||||
]
|
||||
);
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||
assert_eq!(
|
||||
mm.to_vec2::<f32>()?,
|
||||
@ -43,35 +77,49 @@ fn quantized_matmul() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
]
|
||||
);
|
||||
match device {
|
||||
Device::Metal(_) => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[84946.0, 214126.0, 344757.0, 473798.0],
|
||||
[213458.0, 604350.0, 1000469.0, 1387990.0],
|
||||
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
||||
]
|
||||
),
|
||||
_ => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_neg() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..k * n)
|
||||
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
@ -91,32 +139,56 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
);
|
||||
match device {
|
||||
Device::Metal(_) => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243666.0, -19714.0, -285433.0, -550453.0],
|
||||
[23782.0, 21654.0, 19400.0, 18369.0],
|
||||
[-196102.0, 63022.0, 324233.0, 587191.0]
|
||||
]
|
||||
),
|
||||
_ => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4_0() -> Result<()> {
|
||||
use k_quants::BlockQ4_0;
|
||||
test_device!(
|
||||
quantized_matmul,
|
||||
quantized_matmul_cpu,
|
||||
quantized_matmul_cuda,
|
||||
quantized_matmul_metal
|
||||
);
|
||||
test_device!(
|
||||
quantized_matmul_neg,
|
||||
quantized_matmul_neg_cpu,
|
||||
quantized_matmul_neg_cuda,
|
||||
quantized_matmul_neg_metal
|
||||
);
|
||||
|
||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_0::zeros(); 4];
|
||||
BlockQ4_0::from_float(&src, &mut quant)?;
|
||||
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
dst.to_vec1::<f32>()?,
|
||||
&[
|
||||
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
|
||||
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
|
||||
@ -132,21 +204,21 @@ fn quantize_q4_0() -> Result<()> {
|
||||
127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4_1() -> Result<()> {
|
||||
use k_quants::BlockQ4_1;
|
||||
|
||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_1::zeros(); 4];
|
||||
BlockQ4_1::from_float(&src, &mut quant)?;
|
||||
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
round_vector(&dst),
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
|
||||
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
|
||||
@ -162,21 +234,21 @@ fn quantize_q4_1() -> Result<()> {
|
||||
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5_0() -> Result<()> {
|
||||
use k_quants::BlockQ5_0;
|
||||
|
||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_0::zeros(); 4];
|
||||
BlockQ5_0::from_float(&src, &mut quant)?;
|
||||
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
round_vector(&dst),
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
|
||||
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
|
||||
@ -192,21 +264,21 @@ fn quantize_q5_0() -> Result<()> {
|
||||
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5_1() -> Result<()> {
|
||||
use k_quants::BlockQ5_1;
|
||||
|
||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_1::zeros(); 4];
|
||||
BlockQ5_1::from_float(&src, &mut quant)?;
|
||||
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
||||
@ -220,13 +292,11 @@ fn quantize_q5_1() -> Result<()> {
|
||||
124.0, 125.0, 126.0, 127.0
|
||||
]
|
||||
);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
|
||||
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
|
||||
assert!(
|
||||
size % crate::quantized::k_quants::QK_K == 0,
|
||||
"size must be a multiple of {}",
|
||||
@ -236,10 +306,8 @@ fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
let src = (0..size)
|
||||
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let dst = vec![0f32; size];
|
||||
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
|
||||
(src, dst)
|
||||
Tensor::from_vec(src, (size,), device)
|
||||
}
|
||||
|
||||
/// Round a vector
|
||||
@ -288,11 +356,12 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
|
||||
/// Similar to the GGML quantization unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||
if error > max_error {
|
||||
bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
@ -303,19 +372,19 @@ fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
|
||||
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(src, &mut quant)?;
|
||||
T::to_float(&quant, dst)?;
|
||||
Ok(quant)
|
||||
}
|
||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q2K;
|
||||
|
||||
#[test]
|
||||
fn quantize_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
|
||||
|
||||
// Test some specific values
|
||||
@ -329,20 +398,30 @@ fn quantize_q2k() -> Result<()> {
|
||||
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q3k() -> Result<()> {
|
||||
use k_quants::BlockQ3K;
|
||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q3K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
|
||||
|
||||
// Test some specific values
|
||||
@ -356,20 +435,30 @@ fn quantize_q3k() -> Result<()> {
|
||||
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4k() -> Result<()> {
|
||||
use k_quants::BlockQ4K;
|
||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q4K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
|
||||
|
||||
// Test some specific values
|
||||
@ -383,21 +472,31 @@ fn quantize_q4k() -> Result<()> {
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5k() -> Result<()> {
|
||||
use k_quants::BlockQ5K;
|
||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q5K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
@ -410,21 +509,30 @@ fn quantize_q5k() -> Result<()> {
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q6K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
@ -438,22 +546,31 @@ fn quantize_q6k() -> Result<()> {
|
||||
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q8k() -> Result<()> {
|
||||
use k_quants::BlockQ8K;
|
||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q8K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
@ -466,15 +583,79 @@ fn quantize_q8k() -> Result<()> {
|
||||
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
quantize_q4_0,
|
||||
quantize_q4_0_cpu,
|
||||
quantize_q4_0_cuda,
|
||||
quantize_q4_0_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q4_1,
|
||||
quantize_q4_1_cpu,
|
||||
quantize_q4_1_cuda,
|
||||
quantize_q4_1_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5_0,
|
||||
quantize_q5_0_cpu,
|
||||
quantize_q5_0_cuda,
|
||||
quantize_q5_0_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5_1,
|
||||
quantize_q5_1_cpu,
|
||||
quantize_q5_1_cuda,
|
||||
quantize_q5_1_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q2k,
|
||||
quantize_q2k_cpu,
|
||||
quantize_q2k_cuda,
|
||||
quantize_q2k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q3k,
|
||||
quantize_q3k_cpu,
|
||||
quantize_q3k_cuda,
|
||||
quantize_q3k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q4k,
|
||||
quantize_q4k_cpu,
|
||||
quantize_q4k_cuda,
|
||||
quantize_q4k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5k,
|
||||
quantize_q5k_cpu,
|
||||
quantize_q5k_cuda,
|
||||
quantize_q5k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q6k,
|
||||
quantize_q6k_cpu,
|
||||
quantize_q6k_cuda,
|
||||
quantize_q6k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q8k,
|
||||
quantize_q8k_cpu,
|
||||
quantize_q8k_cuda,
|
||||
quantize_q8k_metal
|
||||
);
|
||||
|
||||
/// Very simple dot product implementation
|
||||
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||
@ -591,6 +772,112 @@ fn get_random_tensors(
|
||||
Ok((lhs, rhs, mm))
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! quantized_matmul {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||
fn $fn_name(device: &Device) -> Result<()> {
|
||||
if device.is_cuda() {
|
||||
// TODO Enable Cuda GGML sometime maybe.
|
||||
return Ok(());
|
||||
}
|
||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
|
||||
};
|
||||
}
|
||||
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4_0_bis,
|
||||
quantized_matmul_q4_0_cpu,
|
||||
quantized_matmul_q4_0_cuda,
|
||||
quantized_matmul_q4_0_metal,
|
||||
GgmlDType::Q4_0
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4_1_bis,
|
||||
quantized_matmul_q4_1_cpu,
|
||||
quantized_matmul_q4_1_cuda,
|
||||
quantized_matmul_q4_1_metal,
|
||||
GgmlDType::Q4_1
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5_0_bis,
|
||||
quantized_matmul_q5_0_cpu,
|
||||
quantized_matmul_q5_0_cuda,
|
||||
quantized_matmul_q5_0_metal,
|
||||
GgmlDType::Q5_0
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5_1_bis,
|
||||
quantized_matmul_q5_1_cpu,
|
||||
quantized_matmul_q5_1_cuda,
|
||||
quantized_matmul_q5_1_metal,
|
||||
GgmlDType::Q5_1
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q8_0_bis,
|
||||
quantized_matmul_q8_0_cpu,
|
||||
quantized_matmul_q8_0_cuda,
|
||||
quantized_matmul_q8_0_metal,
|
||||
GgmlDType::Q8_0
|
||||
);
|
||||
// Not implemented in Ggml
|
||||
// quantized_matmul!(
|
||||
// quantized_matmul_q8_1_bis,
|
||||
// quantized_matmul_q8_1_cpu,
|
||||
// quantized_matmul_q8_1_cuda,
|
||||
// quantized_matmul_q8_1_metal,
|
||||
// GgmlDType::Q8_1
|
||||
// );
|
||||
// TODO This is bugged (also bugged in GGML
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q2k_bis,
|
||||
quantized_matmul_q2k_cpu,
|
||||
quantized_matmul_q2k_cuda,
|
||||
quantized_matmul_q2k_metal,
|
||||
GgmlDType::Q2K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q3k_bis,
|
||||
quantized_matmul_q3k_cpu,
|
||||
quantized_matmul_q3k_cuda,
|
||||
quantized_matmul_q3k_metal,
|
||||
GgmlDType::Q3K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4k_bis,
|
||||
quantized_matmul_q4k_cpu,
|
||||
quantized_matmul_q4k_cuda,
|
||||
quantized_matmul_q4k_metal,
|
||||
GgmlDType::Q4K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5k_bis,
|
||||
quantized_matmul_q5k_cpu,
|
||||
quantized_matmul_q5k_cuda,
|
||||
quantized_matmul_q5k_metal,
|
||||
GgmlDType::Q5K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q6k_bis,
|
||||
quantized_matmul_q6k_cpu,
|
||||
quantized_matmul_q6k_cuda,
|
||||
quantized_matmul_q6k_metal,
|
||||
GgmlDType::Q6K
|
||||
);
|
||||
// Not implemented on metal
|
||||
// quantized_matmul!(
|
||||
// quantized_matmul_q8k_bis,
|
||||
// quantized_matmul_q8k_cpu,
|
||||
// quantized_matmul_q8k_cuda,
|
||||
// quantized_matmul_q8k_metal,
|
||||
// GgmlDType::Q8K
|
||||
// );
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
@ -603,7 +890,7 @@ fn quantized_matmul_q2k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -629,7 +916,7 @@ fn quantized_matmul_q3k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -655,7 +942,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -681,7 +968,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -708,7 +995,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -733,7 +1020,7 @@ fn quantized_matmul_q8k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
|
@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let config = blip::Config::image_captioning_large();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (image_embeds, device, mut model) = if args.quantized {
|
||||
let device = Device::Cpu;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
|
||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||
(image_embeds, device, Model::Q(model))
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
|
@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (model, config) = if is_gguf {
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||
let (_vocab_size, dim) = vb
|
||||
.get_no_shape("model.embed_tokens.weight")?
|
||||
.shape()
|
||||
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_real",
|
||||
)?
|
||||
.dequantize(&candle::Device::Cpu)?;
|
||||
.dequantize(&device)?;
|
||||
let freq_cis_imag = vb
|
||||
.get(
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_imag",
|
||||
)?
|
||||
.dequantize(&candle::Device::Cpu)?;
|
||||
.dequantize(&device)?;
|
||||
|
||||
let fake_vb = candle_nn::VarBuilder::from_tensors(
|
||||
[
|
||||
@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
.into_iter()
|
||||
.collect(),
|
||||
candle::DType::F32,
|
||||
&candle::Device::Cpu,
|
||||
&device,
|
||||
);
|
||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||
|
@ -244,13 +244,14 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
let model = QMistral::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
(Model::Quantized(model), device)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
|
@ -307,18 +307,21 @@ fn main() -> Result<()> {
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||
};
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = if args.quantized {
|
||||
let config = config();
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&filenames[0],
|
||||
&device,
|
||||
)?;
|
||||
let model = match args.model {
|
||||
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
||||
_ => QMixFormer::new(&config, vb)?,
|
||||
};
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = match args.model {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
@ -334,8 +337,7 @@ fn main() -> Result<()> {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new(&config, vb)?)
|
||||
}
|
||||
};
|
||||
(model, device)
|
||||
}
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
@ -132,7 +132,8 @@ impl T5ModelBuilder {
|
||||
}
|
||||
|
||||
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
||||
let device = Device::Cpu;
|
||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{Device, Tensor};
|
||||
use candle::Tensor;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(false)?;
|
||||
|
||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||
Some("gguf") => {
|
||||
@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file)?
|
||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||
}
|
||||
Some("ggml" | "bin") | Some(_) | None => {
|
||||
let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
||||
let model = ggml_file::Content::read(&mut file, &device)
|
||||
.map_err(|e| e.with_path(model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensors.iter() {
|
||||
let elem_count = tensor.shape().elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
||||
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
|
@ -236,16 +236,15 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config = Config::replit_code_v1_5_3b();
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
|
||||
(model, Device::Cpu)
|
||||
let model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
|
||||
Model::Q(Q::new(&config, vb.pp("transformer"))?)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
|
||||
(model, device)
|
||||
Model::M(M::new(&config, vb.pp("transformer"))?)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
@ -234,13 +234,14 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
let model = QStableLM::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
|
@ -557,8 +557,10 @@ fn main() -> Result<()> {
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let mut model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&weights_filename,
|
||||
&device,
|
||||
)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
let vb =
|
||||
|
@ -15,6 +15,7 @@ const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
@ -62,6 +63,8 @@ macro_rules! primitive {
|
||||
};
|
||||
}
|
||||
primitive!(usize);
|
||||
primitive!(i64);
|
||||
primitive!(i32);
|
||||
primitive!(u32);
|
||||
primitive!(f32);
|
||||
|
||||
@ -117,6 +120,7 @@ pub enum Source {
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
Quantized,
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
@ -215,17 +219,15 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
|
||||
pub struct Kernels {
|
||||
libraries: RwLock<Libraries>,
|
||||
pipelines: RwLock<Pipelines>,
|
||||
fence: metal::Fence,
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
pub fn new(fence: metal::Fence) -> Self {
|
||||
pub fn new() -> Self {
|
||||
let libraries = RwLock::new(Libraries::new());
|
||||
let pipelines = RwLock::new(Pipelines::new());
|
||||
Self {
|
||||
libraries,
|
||||
pipelines,
|
||||
fence,
|
||||
}
|
||||
}
|
||||
|
||||
@ -239,6 +241,7 @@ impl Kernels {
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -345,7 +348,6 @@ pub fn call_unary_contiguous(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, output));
|
||||
@ -354,7 +356,6 @@ pub fn call_unary_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -376,7 +377,6 @@ pub fn call_unary_strided(
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -398,7 +398,6 @@ pub fn call_unary_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -417,7 +416,6 @@ pub fn call_binary_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, left, right, output));
|
||||
@ -428,7 +426,6 @@ pub fn call_binary_contiguous(
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -453,7 +450,6 @@ pub fn call_binary_strided(
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let width: usize = shape.iter().product();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -478,7 +474,6 @@ pub fn call_binary_strided(
|
||||
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -497,7 +492,6 @@ pub fn call_cast_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, (input, input_offset), output));
|
||||
@ -506,7 +500,6 @@ pub fn call_cast_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -526,7 +519,6 @@ pub fn call_cast_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -548,7 +540,6 @@ pub fn call_cast_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -568,7 +559,6 @@ pub fn call_reduce_contiguous(
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -597,7 +587,6 @@ pub fn call_reduce_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -619,7 +608,6 @@ pub fn call_reduce_strided(
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -655,7 +643,6 @@ pub fn call_reduce_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -674,7 +661,6 @@ pub fn call_last_softmax(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -705,7 +691,6 @@ pub fn call_last_softmax(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -725,7 +710,6 @@ pub fn call_affine(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, add, input, output));
|
||||
@ -734,7 +718,6 @@ pub fn call_affine(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -757,7 +740,6 @@ pub fn call_affine_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -778,7 +760,6 @@ pub fn call_affine_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -797,7 +778,6 @@ pub fn call_powf(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
@ -806,7 +786,6 @@ pub fn call_powf(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -828,7 +807,6 @@ pub fn call_powf_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -848,7 +826,6 @@ pub fn call_powf_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -867,7 +844,6 @@ pub fn call_elu(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
@ -876,7 +852,6 @@ pub fn call_elu(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -898,7 +873,6 @@ pub fn call_elu_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -918,7 +892,6 @@ pub fn call_elu_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -940,7 +913,6 @@ pub fn call_where_cond_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let size: usize = shape.iter().product();
|
||||
@ -969,7 +941,6 @@ pub fn call_where_cond_strided(
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -996,7 +967,6 @@ pub fn call_index_select(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1019,7 +989,6 @@ pub fn call_index_select(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1048,7 +1017,6 @@ pub fn call_gather(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1071,7 +1039,6 @@ pub fn call_gather(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1100,7 +1067,6 @@ pub fn call_scatter_add(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1123,7 +1089,6 @@ pub fn call_scatter_add(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1153,7 +1118,6 @@ pub fn call_index_add(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1177,7 +1141,6 @@ pub fn call_index_add(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1381,7 +1344,6 @@ pub fn call_gemm(
|
||||
let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
@ -1421,12 +1383,10 @@ pub fn call_gemm(
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1451,7 +1411,6 @@ pub fn call_im2col1d_strided(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1471,7 +1430,6 @@ pub fn call_im2col1d_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1501,7 +1459,6 @@ pub fn call_im2col_strided(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1523,7 +1480,6 @@ pub fn call_im2col_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1549,7 +1505,6 @@ pub fn call_upsample_nearest_2d(
|
||||
let scale_h = shape[3] as f32 / out_h as f32;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1567,7 +1522,176 @@ pub fn call_upsample_nearest_2d(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GgmlDType {
|
||||
Q4_0,
|
||||
Q4_1,
|
||||
Q5_0,
|
||||
Q5_1,
|
||||
Q8_0,
|
||||
Q8_1,
|
||||
Q2K,
|
||||
Q3K,
|
||||
Q4K,
|
||||
Q5K,
|
||||
Q6K,
|
||||
Q8K,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
pub fn call_quantized_matmul_t(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
dtype: GgmlDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &Buffer,
|
||||
lhs_offset: usize,
|
||||
rhs: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// Everything is in reverse
|
||||
let ne00 = k as i64;
|
||||
let ne01 = n as i64;
|
||||
let ne02 = b as i64;
|
||||
let ne03 = 1 as i64;
|
||||
|
||||
let nb00 = 0i64;
|
||||
let nb01 = 0 as i64;
|
||||
let nb02 = 0 as i64;
|
||||
|
||||
let ne10 = k as i64;
|
||||
let ne11 = m as i64;
|
||||
let ne12 = b as i64;
|
||||
let ne13 = 1 as i64;
|
||||
|
||||
let nb10 = 0i64;
|
||||
let nb11 = 0i64;
|
||||
let nb12 = 0i64;
|
||||
|
||||
let ne0 = n as i64;
|
||||
let ne1 = m as i64;
|
||||
let r2: u32 = (ne12 / ne02) as u32;
|
||||
let r3: u32 = (ne13 / ne03) as u32;
|
||||
|
||||
let (nth0, nth1, align) = match dtype {
|
||||
GgmlDType::Q4_0
|
||||
| GgmlDType::Q4_1
|
||||
| GgmlDType::Q5_0
|
||||
| GgmlDType::Q5_1
|
||||
| GgmlDType::Q8_0
|
||||
| GgmlDType::Q8_1 => {
|
||||
let nth0 = 8;
|
||||
let nth1 = 8;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
// Fixing a bug in Metal for GGML
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q3K | GgmlDType::Q5K => {
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let align = 2;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::F16 | GgmlDType::Q8K => {
|
||||
// Original implem uses rows
|
||||
let nth0 = 32;
|
||||
let nth1 = 1;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::F32 => {
|
||||
let nth0 = 32;
|
||||
let nth1 = 1;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
};
|
||||
let thread_groups_count = MTLSize {
|
||||
width: divide(ne01 as usize, align),
|
||||
height: ne11 as u64,
|
||||
depth: (ne12 * ne13) as u64,
|
||||
};
|
||||
let threads_per_threadgroup = MTLSize {
|
||||
width: nth0,
|
||||
height: nth1,
|
||||
depth: 1,
|
||||
};
|
||||
let name = match dtype {
|
||||
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
|
||||
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
|
||||
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
|
||||
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
|
||||
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
|
||||
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
|
||||
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
|
||||
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
|
||||
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
|
||||
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
|
||||
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
|
||||
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
|
||||
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
|
||||
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
|
||||
};
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
rhs,
|
||||
(lhs, lhs_offset),
|
||||
output,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3
|
||||
)
|
||||
);
|
||||
encoder.set_threadgroup_memory_length(0, 8192);
|
||||
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
|
5107
candle-metal-kernels/src/quantized.metal
Normal file
5107
candle-metal-kernels/src/quantized.metal
Normal file
File diff suppressed because it is too large
Load Diff
@ -37,8 +37,7 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
||||
|
||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -60,8 +59,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
|
||||
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -96,8 +94,7 @@ fn run_strided<T: Clone>(
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
call_unary_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -278,8 +275,7 @@ fn binary_ops_bf16() {
|
||||
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -409,8 +405,7 @@ fn it_cast_f16_bf16() {
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
@ -445,8 +440,7 @@ fn run_affine_strided<T: Clone>(
|
||||
add: f64,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
@ -595,8 +589,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
call_index_select(
|
||||
&device,
|
||||
&command_buffer,
|
||||
@ -631,8 +624,7 @@ fn cos_f16() {
|
||||
|
||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -662,8 +654,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
|
||||
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -782,8 +773,7 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -859,8 +849,7 @@ fn run_gemm<T: Clone>(
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
@ -117,7 +117,6 @@ UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -136,6 +135,7 @@ BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(abs)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
|
@ -222,7 +222,10 @@ impl Benchmark for QMatMul {
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
||||
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
|
||||
let mm = candle::quantized::QTensor::new(
|
||||
candle::quantized::QStorage::Cpu(Box::new(zeros)),
|
||||
(4096, 11008),
|
||||
)?;
|
||||
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
|
||||
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
||||
Ok((mm, arg))
|
||||
|
@ -33,7 +33,9 @@ def has_mkl() -> bool:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
|
||||
def load_ggml(
|
||||
path: Union[str, PathLike], device: Optional[Device] = None
|
||||
) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
|
||||
"""
|
||||
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||
@ -41,7 +43,9 @@ def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str,
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
|
||||
def load_gguf(
|
||||
path: Union[str, PathLike], device: Optional[Device] = None
|
||||
) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
|
||||
"""
|
||||
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||
and the second maps metadata keys to metadata values.
|
||||
|
@ -1074,20 +1074,20 @@ impl PyTensor {
|
||||
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
|
||||
use ::candle::quantized;
|
||||
let res = match quantized_dtype.to_lowercase().as_str() {
|
||||
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
|
||||
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
|
||||
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
|
||||
"q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
|
||||
"q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
|
||||
"q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
|
||||
"q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
|
||||
"q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
|
||||
"q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
|
||||
"q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
|
||||
"q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
|
||||
"q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
|
||||
"f16" => quantized::QTensor::quantize::<f16>(self),
|
||||
"f32" => quantized::QTensor::quantize::<f32>(self),
|
||||
"q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
|
||||
"q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
|
||||
"q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
|
||||
"q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
|
||||
"q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
|
||||
"q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
|
||||
"q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
|
||||
"q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
|
||||
"q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
|
||||
"q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
|
||||
"q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
|
||||
"q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
|
||||
"f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
|
||||
"f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
|
||||
dt => {
|
||||
return Err(PyErr::new::<PyValueError, _>(format!(
|
||||
"unknown quantized-dtype {dt}"
|
||||
@ -1278,13 +1278,19 @@ fn save_safetensors(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
|
||||
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
|
||||
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||
fn load_ggml(
|
||||
path: &str,
|
||||
device: Option<PyDevice>,
|
||||
py: Python<'_>,
|
||||
) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let ggml =
|
||||
::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
|
||||
let tensors = ggml
|
||||
.tensors
|
||||
.into_iter()
|
||||
@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
|
||||
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||
/// and the second maps metadata keys to metadata values.
|
||||
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
|
||||
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||
fn load_gguf(
|
||||
path: &str,
|
||||
device: Option<PyDevice>,
|
||||
py: Python<'_>,
|
||||
) -> PyResult<(PyObject, PyObject)> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
use ::candle::quantized::gguf_file;
|
||||
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
|
||||
let v: PyObject = match v {
|
||||
@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||
.tensor_infos
|
||||
.keys()
|
||||
.map(|key| {
|
||||
let qtensor = gguf.tensor(&mut file, key)?;
|
||||
let qtensor = gguf.tensor(&mut file, key, &device)?;
|
||||
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
|
||||
})
|
||||
.collect::<::candle::Result<Vec<_>>>()
|
||||
|
@ -356,6 +356,7 @@ impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
@ -383,21 +384,28 @@ impl ModelWeights {
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
|
||||
let output = ct.tensor(reader, "output.weight")?;
|
||||
let norm = RmsNorm::new(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
let output = ct.tensor(reader, "output.weight", device)?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
|
||||
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
let mlp_or_moe = if n_expert <= 1 {
|
||||
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
|
||||
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
|
||||
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
MlpOrMoe::Mlp(Mlp {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
@ -405,15 +413,15 @@ impl ModelWeights {
|
||||
})
|
||||
} else {
|
||||
let feed_forward_gate_inp =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
|
||||
let mut experts = Vec::with_capacity(n_expert);
|
||||
for i in 0..n_expert {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
|
||||
experts.push(Mlp {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
@ -426,8 +434,9 @@ impl ModelWeights {
|
||||
experts,
|
||||
}
|
||||
};
|
||||
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
|
||||
let attention_norm =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
|
@ -311,7 +311,7 @@ impl MixFormerSequentialForCausalLM {
|
||||
let mut blocks = Vec::new();
|
||||
for i in 0..cfg.n_layer {
|
||||
let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
|
||||
blocks.push(block)
|
||||
blocks.push(block);
|
||||
}
|
||||
let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
|
||||
Ok(Self {
|
||||
@ -332,7 +332,7 @@ impl MixFormerSequentialForCausalLM {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
xs = block.forward(&xs, mask.as_ref())?;
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
|
@ -10,33 +10,33 @@ pub struct VarBuilder {
|
||||
}
|
||||
|
||||
impl VarBuilder {
|
||||
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
|
||||
let mut file = std::fs::File::open(p)?;
|
||||
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
|
||||
let mut data = std::collections::HashMap::new();
|
||||
for tensor_name in content.tensor_infos.keys() {
|
||||
let tensor = content.tensor(&mut file, tensor_name)?;
|
||||
let tensor = content.tensor(&mut file, tensor_name, device)?;
|
||||
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
||||
}
|
||||
Ok(Self {
|
||||
data: Arc::new(data),
|
||||
path: Vec::new(),
|
||||
device: Device::Cpu,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
|
||||
pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
|
||||
let mut cursor = std::io::Cursor::new(buffer);
|
||||
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
|
||||
let mut data = std::collections::HashMap::new();
|
||||
for tensor_name in content.tensor_infos.keys() {
|
||||
let tensor = content.tensor(&mut cursor, tensor_name)?;
|
||||
let tensor = content.tensor(&mut cursor, tensor_name, device)?;
|
||||
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
||||
}
|
||||
Ok(Self {
|
||||
data: Arc::new(data),
|
||||
path: Vec::new(),
|
||||
device: Device::Cpu,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,7 @@ impl Model {
|
||||
|
||||
let start = Date::now();
|
||||
let model: SelectedModel = if quantized {
|
||||
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
|
||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
SelectedModel::Q(model)
|
||||
} else {
|
||||
|
@ -41,6 +41,7 @@ impl Model {
|
||||
) -> Result<Model, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let device = Device::Cpu;
|
||||
let name: ModelName = serde_json::from_slice(&config)?;
|
||||
let config: Config = serde_json::from_slice(&config)?;
|
||||
|
||||
@ -50,8 +51,9 @@ impl Model {
|
||||
let start = Date::now();
|
||||
console_log!("weights len: {:?}", weights.len());
|
||||
let model = if quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
|
||||
&weights, &device,
|
||||
)?;
|
||||
console_log!("weights loaded");
|
||||
if name._name_or_path == "microsoft/phi-2" {
|
||||
let model = QMixFormer::new_v2(&config, vb)?;
|
||||
|
@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
|
||||
use candle_wasm_example_t5::console_log;
|
||||
use tokenizers::Tokenizer;
|
||||
use wasm_bindgen::prelude::*;
|
||||
const DEVICE: Device = Device::Cpu;
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct ModelEncoder {
|
||||
@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
|
||||
) -> Result<ModelConditionalGeneration, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
|
||||
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let input: ConditionalGenerationParams =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let device = &Device::Cpu;
|
||||
let device = &DEVICE;
|
||||
self.model.clear_kv_cache();
|
||||
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
|
||||
let prompt = input.prompt;
|
||||
@ -128,7 +129,7 @@ impl ModelEncoder {
|
||||
) -> Result<ModelEncoder, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
config.use_cache = false;
|
||||
let tokenizer =
|
||||
@ -138,7 +139,7 @@ impl ModelEncoder {
|
||||
}
|
||||
|
||||
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let device = &Device::Cpu;
|
||||
let device = &DEVICE;
|
||||
let input: DecoderParams =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
|
||||
|
@ -315,6 +315,7 @@ impl Decoder {
|
||||
let model = if md.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
|
||||
&md.weights,
|
||||
&device,
|
||||
)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
|
@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
|
Reference in New Issue
Block a user