Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Get the ggml based llama to generate some text. (#464)
* Add more stats to the ggml example.
* Build a quantized model from the file content.
* Move the tensor retrieval into the main crate.
* Start adding the forward pass.
* Add more to the forward pass of the quantized llama.
* Apply the attention layers.
* Add the sampling loop.
* Get the sampling loop to work (sketched below).
* Minor tweak.
* Add a quantize/dequantize test.
* Bugfix.
* Add a comment + swap the order.
* Bugfixes.
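The sampling-loop bullets are where the example starts producing text. A hypothetical sketch of such a loop with greedy (argmax) decoding follows; the `forward` closure signature and the rank-1 logits assumption are stand-ins for illustration, not the actual API from this commit:

use candle::{Device, Result, Tensor};

// `forward` is a hypothetical stand-in for the quantized llama's forward
// pass; it is assumed to return rank-1 logits for the last position.
fn sample_greedy(
    mut forward: impl FnMut(&Tensor, usize) -> Result<Tensor>,
    mut tokens: Vec<u32>,
    sample_len: usize,
) -> Result<Vec<u32>> {
    for index in 0..sample_len {
        // Feed the full token history; a real loop would use a KV cache.
        let input = Tensor::new(tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
        let logits = forward(&input, index)?.to_vec1::<f32>()?;
        // Greedy decoding: pick the highest-scoring token id.
        let next = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
        tokens.push(next);
    }
    Ok(tokens)
}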
@@ -30,9 +30,9 @@ fn quantized_matmul() -> Result<()> {
         ]
     );

-    let qtensor = quantized::QTensor::new(rhs_t, (64, 4));
-    let op = quantized::QMatMul::new(std::sync::Arc::new(qtensor));
-    let res = tensor_lhs.custom_op1(op)?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
+    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         res.to_vec2::<f32>()?,
         &[
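This hunk moves the test onto the new QMatMul surface: the quantized weight is now stored as (4, 64) rather than (64, 4), construction goes through QMatMul::from_qtensor instead of wrapping the QTensor in an Arc, and the multiply runs as an ordinary forward call rather than through custom_op1. A minimal sketch of the new call pattern; only the QTensor/QMatMul calls are taken from the diff, while the import paths and the Tensor::ones setup are assumptions:

use candle::quantized::{self, k_quants::BlockQ4_0};
use candle::{Device, Result, Tensor};

fn qmatmul_sketch() -> Result<()> {
    // A (4, 64) weight: 256 f32 values, i.e. 4 * 64 / 32 = 8 Q4_0 blocks.
    let weights = (0..4 * 64).map(|v| v as f32 / 256.0).collect::<Vec<_>>();
    let mut blocks = vec![BlockQ4_0::zeros(); 8];
    BlockQ4_0::from_float(&weights, &mut blocks)?;

    // Build the quantized matmul from a QTensor; no Arc, no custom_op1.
    let qtensor = quantized::QTensor::new(blocks, (4, 64));
    let matmul = quantized::QMatMul::from_qtensor(qtensor);

    // Contract an f32 activation of shape (1, 64) against the (4, 64) weight.
    let lhs = Tensor::ones((1, 64), candle::DType::F32, &Device::Cpu)?;
    let res = matmul.forward(&lhs)?; // expected shape (1, 4)
    println!("{:?}", res.shape());
    Ok(())
}

Hiding the Arc and the custom-op plumbing inside from_qtensor keeps call sites looking like any other layer's forward.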
@@ -44,3 +44,32 @@ fn quantized_matmul() -> Result<()> {

     Ok(())
 }
+
+#[test]
+fn quantize_q4_0() -> Result<()> {
+    use k_quants::BlockQ4_0;
+
+    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
+    let mut dst = vec![0f32; 32 * 4];
+    let mut quant = vec![BlockQ4_0::zeros(); 4];
+    BlockQ4_0::from_float(&src, &mut quant)?;
+    BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
+    assert_eq!(
+        dst,
+        &[
+            -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
+            11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
+            23.25, 27.125, 27.125, 27.125, 27.125, 31.0, 31.0, 31.5, 31.5, 31.5, 31.5, 39.375,
+            39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 47.25, 47.25, 47.25, 47.25,
+            47.25, 47.25, 47.25, 47.25, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125,
+            55.125, 63.0, 63.0, 63.0, 63.0, 59.375, 59.375, 71.25, 71.25, 71.25, 71.25, 71.25,
+            71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 83.125, 83.125, 83.125, 83.125,
+            83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 95.0, 95.0, 95.0, 95.0,
+            95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 111.125, 111.125,
+            111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125,
+            111.125, 111.125, 111.125, 111.125, 111.125, 127.0, 127.0, 127.0, 127.0, 127.0, 127.0,
+            127.0, 127.0
+        ]
+    );
+    Ok(())
+}
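The expected dst values make Q4_0's lossiness concrete: each block of 32 floats keeps one f32 scale d plus a 4-bit code per value, so the round trip snaps everything to multiples of d. For the first block (src 0..=31) the signed max is 31, giving d = 31 / -8 = -3.875 and the plateaus -0.0, 3.875, 7.75, ...; the dip from 63.0 to 59.375 marks a block boundary where the scale changes. A scalar sketch of the scheme for one block, assuming the standard ggml Q4_0 convention (d = max / -8, code = trunc(x/d + 8.5) clamped to 15); this is illustration, not candle's kernel:

fn q4_0_roundtrip(block: &[f32; 32]) -> [f32; 32] {
    // Scale from the element with the largest magnitude, keeping its sign.
    let max = block
        .iter()
        .copied()
        .fold(0f32, |m, v| if v.abs() > m.abs() { v } else { m });
    let d = max / -8.0; // one f32 scale per 32-value block
    let id = if d != 0.0 { 1.0 / d } else { 0.0 };
    let mut out = [0f32; 32];
    for (o, &x) in out.iter_mut().zip(block.iter()) {
        // 4-bit code: trunc(x/d + 8.5), clamped to the top code 15.
        let q = ((x * id + 8.5) as u8).min(15);
        // Dequantize: 16 levels centered on code 8.
        *o = (q as i32 - 8) as f32 * d;
    }
    out
}

Feeding the test's first block (0.0..=31.0) through this should reproduce the plateau pattern at the top of the expected array, including the -0.0 entries from multiplying a zero code offset by the negative scale.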