Add a flag to enable or disable the key-value cache.

laurent
2023-06-29 22:12:15 +01:00
parent 23389b1bd7
commit ae3f202f3b


@@ -24,7 +24,6 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
-const USE_KV_CACHE: bool = true;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -219,15 +218,17 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 #[derive(Clone)]
 struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     device: Device,
 }
 
 impl Cache {
-    fn new(config: &Config, device: &Device) -> Self {
+    fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
         Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
+            use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
             device: device.clone(),
         }
@@ -309,7 +310,7 @@ impl CausalSelfAttention {
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
 
-        if USE_KV_CACHE {
+        if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
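For context (not part of the commit): the per-layer cache update above follows a simple pattern: take whatever keys/values were stored for this block, concatenate the current step onto them, and write the result back. A minimal standalone sketch of that pattern, with Vec<f32> standing in for the key/value tensors and all names illustrative:

use std::sync::{Arc, Mutex};

// Illustrative stand-in for the per-layer key/value tensors.
type Kv = (Vec<f32>, Vec<f32>);

// Append the keys/values of the current step to the cache for `block_idx`
// and return the full sequence seen so far (mirrors the cat-and-store above).
fn append_to_cache(kvs: &Arc<Mutex<Vec<Option<Kv>>>>, block_idx: usize, k: Vec<f32>, v: Vec<f32>) -> Kv {
    let mut cache = kvs.lock().unwrap();
    let (k_all, v_all) = match cache[block_idx].take() {
        // Later steps: extend the cached keys/values with the new token.
        Some((mut ck, mut cv)) => {
            ck.extend_from_slice(&k);
            cv.extend_from_slice(&v);
            (ck, cv)
        }
        // First step: nothing cached yet for this layer.
        None => (k, v),
    };
    cache[block_idx] = Some((k_all.clone(), v_all.clone()));
    (k_all, v_all)
}

fn main() {
    let kvs: Arc<Mutex<Vec<Option<Kv>>>> = Arc::new(Mutex::new(vec![None; 2]));
    append_to_cache(&kvs, 0, vec![1.0], vec![10.0]);
    let (k_all, _) = append_to_cache(&kvs, 0, vec![2.0], vec![20.0]);
    assert_eq!(k_all, vec![1.0, 2.0]); // both steps are now in the cache
}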
@@ -446,6 +447,10 @@ struct Args {
     /// The length of the sample to generate (in tokens).
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
+
+    /// Enable the key-value cache.
+    #[arg(long, default_value_t = true)]
+    use_kv_cache: bool,
 }
 
 #[tokio::main]
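As a side note (not part of the diff): a trimmed-down, hypothetical standalone program showing how the new flag is declared and parsed with clap's derive API; only the field from the diff is kept, everything else is omitted:

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Enable the key-value cache.
    #[arg(long, default_value_t = true)]
    use_kv_cache: bool,
}

fn main() {
    // Parse the command-line arguments and print the resulting flag value.
    let args = Args::parse();
    println!("use_kv_cache = {}", args.use_kv_cache);
}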
@@ -459,7 +464,7 @@ async fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
     let config = Config::config_7b();
-    let cache = Cache::new(&config, &device);
+    let cache = Cache::new(args.use_kv_cache, &config, &device);
     let start = std::time::Instant::now();
     let (llama, tokenizer_filename) = match args.npy {
         Some(npy) => {
@@ -506,14 +511,14 @@ async fn main() -> Result<()> {
     let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let context_size = if USE_KV_CACHE && index > 0 {
+        let context_size = if cache.use_kv_cache && index > 0 {
             1
         } else {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?;
-        let freqs_cis = if USE_KV_CACHE {
+        let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
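For context (not part of the commit): the branch above is where the cache pays off: once it is enabled, every step after the first feeds only the single newly sampled token, while with the cache disabled the full token history is re-fed each step. A minimal standalone sketch of that per-step context-size rule (function name illustrative):

// Per-step context size: with the KV cache, only the newest token is fed
// after the first step; without it, the whole sequence is re-fed.
fn context_size(use_kv_cache: bool, index: usize, total_tokens: usize) -> usize {
    if use_kv_cache && index > 0 {
        1
    } else {
        total_tokens
    }
}

fn main() {
    assert_eq!(context_size(true, 0, 12), 12); // first step: full prompt
    assert_eq!(context_size(true, 5, 17), 1); // later steps: one new token
    assert_eq!(context_size(false, 5, 17), 17); // cache disabled: full history
}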