Enable the KV cache after fixing the caching length and the rope bits.

laurent
2023-06-29 22:00:57 +01:00
parent e87a99d16e
commit 23389b1bd7


@@ -23,8 +23,8 @@ use std::sync::{Arc, Mutex};
 mod var_store;
 mod weights;
 const CONTEXT_SIZE: usize = 512;
-const USE_KV_CACHE: bool = false;
+const MAX_SEQ_LEN: usize = 4096;
+const USE_KV_CACHE: bool = true;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -268,8 +268,9 @@ impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
-        let freqs_cis = if dims[1] < CONTEXT_SIZE {
-            freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
+        let fcis_dims = freqs_cis.dims();
+        let freqs_cis = if dims[1] < fcis_dims[1] {
+            freqs_cis.narrow(1, 0, dims[1])?
         } else {
             freqs_cis.clone()
         };
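This is the "rope bits" fix from the commit message: the old branch sliced the last dims[1] rows out of the table, which only matches the token positions when the input happens to end exactly at position CONTEXT_SIZE. The new code takes a prefix instead, because the caller (the main loop further down) now narrows freqs_cis to the absolute positions being fed before the forward pass. A minimal sketch of the indexing change, with plain slices standing in for candle tensors and all names hypothetical:

// Before: slice from the back of the table, assuming the input ends at
// position context_size (wrong for a prompt that starts at position 0).
fn rope_rows_old(full_table: &[f32], context_size: usize, seq_len: usize) -> &[f32] {
    &full_table[context_size - seq_len..context_size]
}

// After: the caller pre-slices the table to the exact positions being fed,
// so the attention layer only has to take a prefix.
fn rope_rows_new(pre_sliced: &[f32], seq_len: usize) -> &[f32] {
    &pre_sliced[..seq_len]
}

fn main() {
    let table: Vec<f32> = (0..16).map(|p| p as f32).collect();
    // First pass: a 3-token prompt sits at positions 0..3.
    assert_eq!(rope_rows_new(&table[0..3], 3), &[0.0, 1.0, 2.0]);
    // The old indexing would have picked rows 13..16 of a 16-row table instead:
    assert_eq!(rope_rows_old(&table, 16, 3), &[13.0, 14.0, 15.0]);
    // Cached step at absolute position 5: the caller pre-slices row 5 only.
    let (index_pos, seq_len) = (5usize, 1usize);
    assert_eq!(rope_rows_new(&table[index_pos..index_pos + seq_len], seq_len), &[5.0]);
}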
@@ -311,18 +312,18 @@ impl CausalSelfAttention {
         if USE_KV_CACHE {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?;
-                v = Tensor::cat(&[cache_v, &v], 1)?;
+                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
-                if k_seq_len > CONTEXT_SIZE {
+                if k_seq_len > MAX_SEQ_LEN {
                     k = k
-                        .narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
+                        .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
-                if v_seq_len > CONTEXT_SIZE {
+                if v_seq_len > 2 * MAX_SEQ_LEN {
                     v = v
-                        .narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
+                        .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
             }
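This is the "caching length" fix: each step's keys and values are concatenated onto the cached ones along the sequence dimension, then trimmed so the cache is bounded by the model's maximum sequence length rather than the 512-token CONTEXT_SIZE. The added .contiguous() calls compact the storage after cat/narrow, presumably so downstream ops see a contiguous layout; note also that v is only trimmed once it exceeds 2 * MAX_SEQ_LEN, so it is trimmed less often than k. A minimal sketch of the append-and-trim policy, with a Vec standing in for a tensor's sequence dimension and all names hypothetical:

// Append new entries to the cache, then keep only the newest max_seq_len,
// mirroring Tensor::cat followed by narrow(1, len - MAX_SEQ_LEN, MAX_SEQ_LEN).
fn append_to_cache<T: Clone>(cache: &mut Vec<T>, new: &[T], max_seq_len: usize) {
    cache.extend_from_slice(new);
    if cache.len() > max_seq_len {
        let excess = cache.len() - max_seq_len;
        cache.drain(..excess);
    }
}

fn main() {
    let mut cache: Vec<u32> = (0..6).collect();
    append_to_cache(&mut cache, &[6, 7], 4);
    assert_eq!(cache, vec![4, 5, 6, 7]); // only the newest 4 positions survive
}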
@@ -405,19 +406,18 @@ impl Llama {
 }
 fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
-    let seq_len = CONTEXT_SIZE;
     let n_elem = config.n_embd / config.n_head;
     let theta: Vec<_> = (0..n_elem)
         .step_by(2)
         .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
         .collect();
-    let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
+    let arange: Vec<_> = (0..MAX_SEQ_LEN).map(|c| c as f32).collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
     let arange = Tensor::new(arange.as_slice(), device)?;
     let idx_theta = arange
         .reshape((arange.elem_count(), 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    let shape = [1, seq_len, n_elem / 2, 1];
+    let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
     let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
     let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
     let last_dim = idx_theta_cos.rank() - 1;
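The table now covers all MAX_SEQ_LEN positions rather than CONTEXT_SIZE: with the cache enabled, later steps look up rows at absolute positions that can exceed the initial context, so the precomputed angles must span everything the model may ever see. The values are the usual RoPE angles, theta_i = 1 / 10000^(i / n_elem) for even i, multiplied by the position index. A toy-sized standalone sketch of the same computation (the sizes are hypothetical stand-ins for 4096 and the per-head dimension):

fn main() {
    let (max_seq_len, n_elem) = (8usize, 4usize); // toy stand-ins
    // theta_i = 1 / 10000^(i / n_elem) for i = 0, 2, 4, ...
    let theta: Vec<f32> = (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
        .collect();
    // Outer product of positions and frequencies: one row of angles per
    // absolute position, so row p can be looked up directly by position later.
    let angles: Vec<Vec<f32>> = (0..max_seq_len)
        .map(|p| theta.iter().map(|t| p as f32 * t).collect())
        .collect();
    assert_eq!(angles.len(), max_seq_len);
    assert_eq!(angles[0], vec![0.0, 0.0]); // position 0 is rotated by nothing
    assert_eq!(angles[1], theta); // position 1 rotates by exactly theta
}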
@@ -503,16 +503,23 @@ async fn main() -> Result<()> {
     let mut new_tokens = vec![];
     let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
     let start_gen = std::time::Instant::now();
+    let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
         let context_size = if USE_KV_CACHE && index > 0 {
             1
         } else {
-            CONTEXT_SIZE
+            tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?;
+        let freqs_cis = if USE_KV_CACHE {
+            freqs_cis.narrow(1, index_pos, ctxt.len())?
+        } else {
+            freqs_cis.clone()
+        };
         let logits = llama.forward(&input, &freqs_cis)?;
+        index_pos += ctxt.len();
         let next_token = if let Some(temperature) = args.temperature {
             println!("Sampling with temperature {temperature:?}");
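Tying it together in the driver loop: the first step feeds the whole prompt (tokens.len() of them, no longer capped at CONTEXT_SIZE) so the cache gets populated, every later step feeds only the single newest token, and index_pos tracks the absolute position so the narrowed freqs_cis rows line up with the keys already in the cache. A minimal sketch of that bookkeeping, with the model and sampling stubbed out and all names hypothetical:

fn main() {
    let use_kv_cache = true;
    let mut tokens: Vec<u32> = vec![1, 2, 3]; // stand-in for the encoded prompt
    let mut index_pos = 0;
    for index in 0..4 {
        // First step: the whole prompt. Later steps: just the newest token.
        let context_size = if use_kv_cache && index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        // The real loop narrows freqs_cis to rows index_pos..index_pos + ctxt.len()
        // here and then runs the forward pass on ctxt.
        println!("step {index}: feeding {} token(s) at position {index_pos}", ctxt.len());
        index_pos += ctxt.len();
        tokens.push(0); // stand-in for the sampled next token
    }
    assert_eq!(index_pos, 6); // 3 prompt tokens, then 1 token per later step
    assert_eq!(tokens.len(), 7);
}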