mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Tensor -> QTensor conversion (#496)
* Sketch some qmatmul test. * Add the quantization function. * More testing. * Make the test smaller and faster. * Add some shape checking.
This commit is contained in:
@ -32,7 +32,7 @@ fn quantized_matmul() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
@ -80,7 +80,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
@ -171,3 +171,46 @@ fn quantize_q6k() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
use rand::prelude::*;
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(314159265358979);
|
||||
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (11, 512, 21);
|
||||
let lhs = (0..m * k)
|
||||
.map(|_| rng.gen::<f32>() - 0.5)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..n * k)
|
||||
.map(|_| rng.gen::<f32>() - 0.5)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_vec(lhs, (m, k), cpu)?;
|
||||
let rhs = Tensor::from_vec(rhs, (n, k), cpu)?;
|
||||
|
||||
let mm = lhs.matmul(&rhs.t()?)?;
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = [dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]
|
||||
.iter()
|
||||
.map(|x| (1000. * x).round() / 1000.)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs);
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
assert_eq!(mm.dims(), [m, n]);
|
||||
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
|
||||
let dst = [dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]
|
||||
.iter()
|
||||
.map(|x| (1000. * x).round() / 1000.)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(dst, [1.324, 1.49, -0.164, 1.741]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user