mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add a flag.
This commit is contained in:
@ -24,7 +24,6 @@ mod var_store;
|
||||
mod weights;
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const USE_KV_CACHE: bool = true;
|
||||
const START_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
@ -219,15 +218,17 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
#[derive(Clone)]
|
||||
struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
use_kv_cache: bool,
|
||||
#[allow(clippy::type_complexity)]
|
||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
fn new(config: &Config, device: &Device) -> Self {
|
||||
fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
|
||||
Self {
|
||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||
use_kv_cache,
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||
device: device.clone(),
|
||||
}
|
||||
@ -309,7 +310,7 @@ impl CausalSelfAttention {
|
||||
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
||||
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
||||
|
||||
if USE_KV_CACHE {
|
||||
if self.cache.use_kv_cache {
|
||||
let mut cache = self.cache.kvs.lock().unwrap();
|
||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||
@ -446,6 +447,10 @@ struct Args {
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Enable the key-value cache.
|
||||
#[arg(long, default_value_t = true)]
|
||||
use_kv_cache: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -459,7 +464,7 @@ async fn main() -> Result<()> {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
let config = Config::config_7b();
|
||||
let cache = Cache::new(&config, &device);
|
||||
let cache = Cache::new(args.use_kv_cache, &config, &device);
|
||||
let start = std::time::Instant::now();
|
||||
let (llama, tokenizer_filename) = match args.npy {
|
||||
Some(npy) => {
|
||||
@ -506,14 +511,14 @@ async fn main() -> Result<()> {
|
||||
let mut index_pos = 0;
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let context_size = if USE_KV_CACHE && index > 0 {
|
||||
let context_size = if cache.use_kv_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let freqs_cis = if USE_KV_CACHE {
|
||||
let freqs_cis = if cache.use_kv_cache {
|
||||
freqs_cis.narrow(1, index_pos, ctxt.len())?
|
||||
} else {
|
||||
freqs_cis.clone()
|
||||
|
Reference in New Issue
Block a user