Mirror of https://github.com/huggingface/candle.git
Add a kv-cache to the quantized llama example. (#466)
* Add a kv-cache to the quantized llama example.
* Also print the prompt.
* Bugfix in q6k dequantizing.
* Another bugfix.
@@ -70,7 +70,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
 #[repr(C)]
 pub struct BlockQ8_0 {
     d: f16,
-    qs: [u8; QK8_0],
+    qs: [i8; QK8_0],
 }
 const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
 
@@ -476,14 +476,14 @@ impl GgmlType for BlockQ6K {
         if k % QK_K != 0 {
             crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
         }
-        for x in xs.iter() {
+        for (idx_x, x) in xs.iter().enumerate() {
             let d = x.d.to_f32();
             let ql = &x.ql;
             let qh = &x.qh;
             let sc = &x.scales;
             for n in (0..QK_K).step_by(128) {
                 let idx = n / 128;
-                let ys = &mut ys[n..];
+                let ys = &mut ys[idx_x * QK_K + n..];
                 let sc = &sc[8 * idx..];
                 let ql = &ql[64 * idx..];
                 let qh = &qh[32 * idx..];
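The only change in the q6k hunk is the output offset: block number idx_x has to write into its own QK_K-wide window of ys, otherwise every block keeps overwriting the first window. A minimal standalone sketch of that indexing (the dequantize_blocks helper and the pre-dequantized f32 blocks are made up for illustration; this is not candle's actual routine):

const QK_K: usize = 256;

// Copy each block into its own QK_K-sized window of the output slice.
fn dequantize_blocks(blocks: &[[f32; QK_K]], ys: &mut [f32]) {
    assert_eq!(ys.len(), blocks.len() * QK_K);
    for (idx_x, block) in blocks.iter().enumerate() {
        // Without the idx_x * QK_K offset every block would overwrite window 0.
        let ys = &mut ys[idx_x * QK_K..(idx_x + 1) * QK_K];
        ys.copy_from_slice(block);
    }
}

fn main() {
    let blocks = vec![[1.0f32; QK_K], [2.0f32; QK_K]];
    let mut ys = vec![0.0f32; 2 * QK_K];
    dequantize_blocks(&blocks, &mut ys);
    assert_eq!(ys[0], 1.0);
    assert_eq!(ys[QK_K], 2.0); // the second block landed in the second window
}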
@@ -663,7 +663,7 @@ impl GgmlType for BlockQ8_0 {
             let id = if d != 0f32 { 1. / d } else { 0. };
             ys.d = f16::from_f32(d);
             for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
-                *y = f32::round(x * id) as u8
+                *y = f32::round(x * id) as i8
             }
         }
         Ok(())
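Both u8-to-i8 fixes above address the same issue: Q8_0 quants are signed, and Rust's float-to-integer casts saturate, so a negative value cast `as u8` collapses to 0 and its sign is lost, while `as i8` keeps it. A small standalone illustration (the scale and weight values are picked arbitrarily; this is not candle code):

fn main() {
    let d = 0.5f32; // per-block scale
    let id = 1.0 / d; // inverse scale, as in the quantize loop
    let x = -3.2f32; // some negative weight

    let wrong = f32::round(x * id) as u8; // saturating cast: becomes 0
    let right = f32::round(x * id) as i8; // keeps the sign: -6

    assert_eq!(wrong, 0);
    assert_eq!(right, -6);
    // Dequantizing recovers roughly the original value only in the i8 case.
    assert_eq!(right as f32 * d, -3.0);
    assert_eq!(wrong as f32 * d, 0.0);
}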
@@ -52,6 +52,7 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -75,7 +76,7 @@ impl LayerWeights {
         Ok(rope)
     }
 
-    fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.attention_wq.forward(x)?;
         let k = self.attention_wk.forward(x)?;
@@ -94,7 +95,15 @@ impl LayerWeights {
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
 
-        // TODO: KV cache.
+        let (k, v) = match &self.kv_cache {
+            None => (k, v),
+            Some((k_cache, v_cache)) => {
+                let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+                let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((k.clone(), v.clone()));
 
         // If we start supporting MQA, we need to repeat the k and v tensors here.
 
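The block added to forward_attn is the usual concat-and-store kv-cache: on the first call the cache is empty and the freshly projected k/v pass through; on later calls the single-token k/v are appended to the cached ones along the sequence dimension (dim 2 for a [batch, n_head, seq, head_dim] layout) and the concatenation is stored back. A sketch of that pattern in isolation, assuming the candle-core crate is available as a dependency (imported as candle_core) and using made-up tensor sizes; the KvCache type here is hypothetical, in the example the cache lives directly on LayerWeights:

use candle_core::{DType, Device, Result, Tensor};

struct KvCache {
    kv: Option<(Tensor, Tensor)>,
}

impl KvCache {
    // Append new k/v to whatever is cached and remember the concatenation.
    fn append(&mut self, k: Tensor, v: Tensor) -> Result<(Tensor, Tensor)> {
        let (k, v) = match &self.kv {
            None => (k, v),
            Some((k_cache, v_cache)) => {
                let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
                let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
                (k, v)
            }
        };
        self.kv = Some((k.clone(), v.clone()));
        Ok((k, v))
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut cache = KvCache { kv: None };
    // Prompt step: 5 tokens go through at once.
    let k = Tensor::zeros((1, 4, 5, 8), DType::F32, &dev)?;
    let v = Tensor::zeros((1, 4, 5, 8), DType::F32, &dev)?;
    let (k, _v) = cache.append(k, v)?;
    assert_eq!(k.dims()[2], 5);
    // Decode step: 1 new token, attention now sees all 6 positions.
    let k = Tensor::zeros((1, 4, 1, 8), DType::F32, &dev)?;
    let v = Tensor::zeros((1, 4, 1, 8), DType::F32, &dev)?;
    let (k, _v) = cache.append(k, v)?;
    assert_eq!(k.dims()[2], 6);
    Ok(())
}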
@@ -181,6 +190,7 @@ impl ModelWeights {
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                kv_cache: None,
             })
         }
         Ok(Self {
@@ -209,7 +219,7 @@ impl ModelWeights {
         let (_b_sz, seq_len) = x.dims2()?;
         let mask = self.mask(seq_len)?;
         let mut layer_in = self.tok_embeddings.forward(x)?;
-        for (_layer_idx, layer) in self.layers.iter().enumerate() {
+        for layer in self.layers.iter_mut() {
             let x = layer_in;
             let residual = &x;
             let x = layer.attention_norm.forward(&x)?;
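The loop switch from iter() to iter_mut() follows from the signature change above: forward_attn now takes &mut self because it writes to kv_cache, so each layer has to be borrowed mutably. A toy sketch of the same constraint (the Layer type and u32 "cache" here are made up):

struct Layer {
    cache: Option<u32>,
}

impl Layer {
    fn forward_attn(&mut self, x: u32) -> u32 {
        self.cache = Some(x); // the mutation is what forces &mut self
        x + 1
    }
}

fn main() {
    let mut layers = vec![Layer { cache: None }, Layer { cache: None }];
    let mut x = 0;
    // `layers.iter()` would only yield &Layer and fail to compile here.
    for layer in layers.iter_mut() {
        x = layer.forward_attn(x);
    }
    assert_eq!(x, 2);
    assert_eq!(layers[1].cache, Some(1));
}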
@@ -302,8 +312,9 @@ fn main() -> anyhow::Result<()> {
         .to_vec();
     let mut index_pos = 0;
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    for _index in 0..args.sample_len {
-        let context_size = tokens.len();
+    print!("{prompt}");
+    for index in 0..args.sample_len {
+        let context_size = if index == 0 { tokens.len() } else { 1 };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, index_pos)?;
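With the per-layer cache in place, the sampling loop only needs to feed the whole prompt once; every later step feeds just the newest token and lets index_pos record how many positions are already cached. A standalone sketch of that bookkeeping with the model replaced by a dummy closure (the token values and the index_pos update are illustrative, not the exact example code):

fn main() {
    let mut tokens: Vec<u32> = vec![1, 15043, 29892]; // pretend prompt tokens
    let sample_len = 4;
    let mut index_pos = 0;

    // Dummy "model": prints what it is fed and returns a fake next token.
    let forward = |ctxt: &[u32], index_pos: usize| -> u32 {
        println!("feeding {} token(s) at position {index_pos}", ctxt.len());
        42
    };

    for index in 0..sample_len {
        // Full prompt on the first step, a single token afterwards.
        let context_size = if index == 0 { tokens.len() } else { 1 };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let next_token = forward(ctxt, index_pos);
        index_pos += ctxt.len();
        tokens.push(next_token);
    }
    assert_eq!(tokens.len(), 3 + sample_len);
}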