Interactive mode for the quantized model. (#690)

This commit is contained in:
Laurent Mazare
2023-08-31 11:52:42 +02:00
committed by GitHub
parent 94aa234dfd
commit 7509c98970
2 changed files with 113 additions and 59 deletions

View File

@@ -5,7 +5,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
const MAX_SEQ_LEN: usize = 4096;
pub const MAX_SEQ_LEN: usize = 4096;
struct RmsNorm {
inner: candle_nn::LayerNorm,
@@ -126,9 +126,13 @@ impl LayerWeights {
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
(k, v)
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));