mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Interactive mode for the quantized model. (#690)
This commit is contained in:
@@ -5,7 +5,7 @@ use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::LayerNorm,
|
||||
@@ -126,9 +126,13 @@ impl LayerWeights {
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((k_cache, v_cache)) => {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
|
||||
(k, v)
|
||||
if index_pos == 0 {
|
||||
(k, v)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
Reference in New Issue
Block a user