mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Merge pull request #42 from LaurentMazare/kv-cache-enable
Enable the KV cache
This commit is contained in:
@ -23,8 +23,7 @@ use std::sync::{Arc, Mutex};
|
|||||||
mod var_store;
|
mod var_store;
|
||||||
mod weights;
|
mod weights;
|
||||||
|
|
||||||
const CONTEXT_SIZE: usize = 512;
|
const MAX_SEQ_LEN: usize = 4096;
|
||||||
const USE_KV_CACHE: bool = false;
|
|
||||||
const START_PROMPT: &str = r"
|
const START_PROMPT: &str = r"
|
||||||
EDWARD:
|
EDWARD:
|
||||||
I wonder how our princely father 'scaped,
|
I wonder how our princely father 'scaped,
|
||||||
@ -219,15 +218,17 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Cache {
|
struct Cache {
|
||||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||||
|
use_kv_cache: bool,
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||||
device: Device,
|
device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cache {
|
impl Cache {
|
||||||
fn new(config: &Config, device: &Device) -> Self {
|
fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
|
||||||
Self {
|
Self {
|
||||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
use_kv_cache,
|
||||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
}
|
}
|
||||||
@ -268,8 +269,9 @@ impl CausalSelfAttention {
|
|||||||
|
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
let mut dims = x.dims().to_vec();
|
let mut dims = x.dims().to_vec();
|
||||||
let freqs_cis = if dims[1] < CONTEXT_SIZE {
|
let fcis_dims = freqs_cis.dims();
|
||||||
freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
|
let freqs_cis = if dims[1] < fcis_dims[1] {
|
||||||
|
freqs_cis.narrow(1, 0, dims[1])?
|
||||||
} else {
|
} else {
|
||||||
freqs_cis.clone()
|
freqs_cis.clone()
|
||||||
};
|
};
|
||||||
@ -308,21 +310,21 @@ impl CausalSelfAttention {
|
|||||||
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
||||||
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
||||||
|
|
||||||
if USE_KV_CACHE {
|
if self.cache.use_kv_cache {
|
||||||
let mut cache = self.cache.kvs.lock().unwrap();
|
let mut cache = self.cache.kvs.lock().unwrap();
|
||||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||||
k = Tensor::cat(&[cache_k, &k], 1)?;
|
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||||
v = Tensor::cat(&[cache_v, &v], 1)?;
|
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
||||||
let k_seq_len = k.dims()[1];
|
let k_seq_len = k.dims()[1];
|
||||||
if k_seq_len > CONTEXT_SIZE {
|
if k_seq_len > MAX_SEQ_LEN {
|
||||||
k = k
|
k = k
|
||||||
.narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
|
.narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||||
.contiguous()?
|
.contiguous()?
|
||||||
}
|
}
|
||||||
let v_seq_len = v.dims()[1];
|
let v_seq_len = v.dims()[1];
|
||||||
if v_seq_len > CONTEXT_SIZE {
|
if v_seq_len > 2 * MAX_SEQ_LEN {
|
||||||
v = v
|
v = v
|
||||||
.narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
|
.narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||||
.contiguous()?
|
.contiguous()?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -405,19 +407,18 @@ impl Llama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
|
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
|
||||||
let seq_len = CONTEXT_SIZE;
|
|
||||||
let n_elem = config.n_embd / config.n_head;
|
let n_elem = config.n_embd / config.n_head;
|
||||||
let theta: Vec<_> = (0..n_elem)
|
let theta: Vec<_> = (0..n_elem)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
|
let arange: Vec<_> = (0..MAX_SEQ_LEN).map(|c| c as f32).collect();
|
||||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
let arange = Tensor::new(arange.as_slice(), device)?;
|
let arange = Tensor::new(arange.as_slice(), device)?;
|
||||||
let idx_theta = arange
|
let idx_theta = arange
|
||||||
.reshape((arange.elem_count(), 1))?
|
.reshape((arange.elem_count(), 1))?
|
||||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
let shape = [1, seq_len, n_elem / 2, 1];
|
let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
|
||||||
let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
|
let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
|
||||||
let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
|
let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
|
||||||
let last_dim = idx_theta_cos.rank() - 1;
|
let last_dim = idx_theta_cos.rank() - 1;
|
||||||
@ -446,6 +447,10 @@ struct Args {
|
|||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 100)]
|
#[arg(long, default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// Disable the key-value cache.
|
||||||
|
#[arg(long)]
|
||||||
|
no_kv_cache: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -459,7 +464,7 @@ async fn main() -> Result<()> {
|
|||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = Cache::new(&config, &device);
|
let cache = Cache::new(!args.no_kv_cache, &config, &device);
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let (llama, tokenizer_filename) = match args.npy {
|
let (llama, tokenizer_filename) = match args.npy {
|
||||||
Some(npy) => {
|
Some(npy) => {
|
||||||
@ -503,16 +508,23 @@ async fn main() -> Result<()> {
|
|||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
|
let mut index_pos = 0;
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
let context_size = if USE_KV_CACHE && index > 0 {
|
let context_size = if cache.use_kv_cache && index > 0 {
|
||||||
1
|
1
|
||||||
} else {
|
} else {
|
||||||
CONTEXT_SIZE
|
tokens.len()
|
||||||
};
|
};
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?;
|
let input = Tensor::new(ctxt, &device)?;
|
||||||
|
let freqs_cis = if cache.use_kv_cache {
|
||||||
|
freqs_cis.narrow(1, index_pos, ctxt.len())?
|
||||||
|
} else {
|
||||||
|
freqs_cis.clone()
|
||||||
|
};
|
||||||
let logits = llama.forward(&input, &freqs_cis)?;
|
let logits = llama.forward(&input, &freqs_cis)?;
|
||||||
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
let next_token = if let Some(temperature) = args.temperature {
|
let next_token = if let Some(temperature) = args.temperature {
|
||||||
println!("Sampling with temperature {temperature:?}");
|
println!("Sampling with temperature {temperature:?}");
|
||||||
|
Reference in New Issue
Block a user