mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Tensor -> QTensor conversion (#496)
* Sketch some qmatmul test. * Add the quantization function. * More testing. * Make the test smaller and faster. * Add some shape checking.
This commit is contained in:
@ -90,7 +90,7 @@ impl Benchmark for QMatMul {
|
|||||||
type RunResult = Tensor;
|
type RunResult = Tensor;
|
||||||
fn preprocess() -> Result<Self::PreProcessData> {
|
fn preprocess() -> Result<Self::PreProcessData> {
|
||||||
let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
let zeros = vec![candle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
||||||
let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008));
|
let mm = candle_core::quantized::QTensor::new(zeros, (4096, 11008))?;
|
||||||
let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
|
let mm = candle_core::quantized::QMatMul::from_qtensor(mm);
|
||||||
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
||||||
Ok((mm, arg))
|
Ok((mm, arg))
|
||||||
|
@ -125,7 +125,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
let raw_data_ptr = raw_data.as_ptr();
|
let raw_data_ptr = raw_data.as_ptr();
|
||||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||||
Ok(super::QTensor::new(data.to_vec(), dims))
|
super::QTensor::new(data.to_vec(), dims)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a [Tensor] from a raw GGML tensor.
|
/// Creates a [Tensor] from a raw GGML tensor.
|
||||||
|
@ -117,15 +117,52 @@ impl std::fmt::Debug for QTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||||
|
let dims = shape.dims();
|
||||||
|
if dims.is_empty() {
|
||||||
|
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||||
|
}
|
||||||
|
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||||
|
crate::bail!(
|
||||||
|
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||||
|
T::BLCK_SIZE
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
impl QTensor {
|
impl QTensor {
|
||||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
shape: S,
|
shape: S,
|
||||||
) -> Self {
|
) -> Result<Self> {
|
||||||
Self {
|
let shape = shape.into();
|
||||||
|
check_shape::<T>(&shape)?;
|
||||||
|
Ok(Self {
|
||||||
data: Box::new(data),
|
data: Box::new(data),
|
||||||
shape: shape.into(),
|
shape,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||||
|
let shape = src.shape();
|
||||||
|
check_shape::<T>(shape)?;
|
||||||
|
let src = src
|
||||||
|
.to_dtype(crate::DType::F32)?
|
||||||
|
.flatten_all()?
|
||||||
|
.to_vec1::<f32>()?;
|
||||||
|
if src.len() % T::BLCK_SIZE != 0 {
|
||||||
|
crate::bail!(
|
||||||
|
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||||
|
T::BLCK_SIZE
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||||
|
T::from_float(&src, &mut data)?;
|
||||||
|
Ok(Self {
|
||||||
|
data: Box::new(data),
|
||||||
|
shape: shape.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
|
@ -32,7 +32,7 @@ fn quantized_matmul() -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -80,7 +80,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -171,3 +171,46 @@ fn quantize_q6k() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn quantized_matmul_q6k() -> Result<()> {
|
||||||
|
use k_quants::BlockQ6K;
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
let mut rng = StdRng::seed_from_u64(314159265358979);
|
||||||
|
|
||||||
|
let cpu = &Device::Cpu;
|
||||||
|
let (m, k, n) = (11, 512, 21);
|
||||||
|
let lhs = (0..m * k)
|
||||||
|
.map(|_| rng.gen::<f32>() - 0.5)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let rhs = (0..n * k)
|
||||||
|
.map(|_| rng.gen::<f32>() - 0.5)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let lhs = Tensor::from_vec(lhs, (m, k), cpu)?;
|
||||||
|
let rhs = Tensor::from_vec(rhs, (n, k), cpu)?;
|
||||||
|
|
||||||
|
let mm = lhs.matmul(&rhs.t()?)?;
|
||||||
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
|
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||||
|
let dst = [dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]
|
||||||
|
.iter()
|
||||||
|
.map(|x| (1000. * x).round() / 1000.)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
|
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||||
|
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||||
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
|
assert_eq!(mm.dims(), [m, n]);
|
||||||
|
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||||
|
let dst = [dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]
|
||||||
|
.iter()
|
||||||
|
.map(|x| (1000. * x).round() / 1000.)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(dst, [1.324, 1.49, -0.164, 1.741]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user