Merge pull request #41 from LaurentMazare/kv-cache

Kv cache
This commit is contained in:
Laurent Mazare
2023-06-29 19:11:52 +01:00
committed by GitHub
4 changed files with 110 additions and 51 deletions

View File

@ -1,8 +1,6 @@
[workspace] [workspace]
members = [ members = [
"candle-core", "candle-core",
"candle-hub", "candle-hub",
"candle-kernels", "candle-kernels",
] ]

View File

@ -24,6 +24,7 @@ mod var_store;
mod weights; mod weights;
const CONTEXT_SIZE: usize = 512; const CONTEXT_SIZE: usize = 512;
const USE_KV_CACHE: bool = false;
const START_PROMPT: &str = r" const START_PROMPT: &str = r"
EDWARD: EDWARD:
I wonder how our princely father 'scaped, I wonder how our princely father 'scaped,
@ -218,13 +219,16 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
#[derive(Clone)] #[derive(Clone)]
struct Cache { struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
device: Device, device: Device,
} }
impl Cache { impl Cache {
fn new(device: &Device) -> Self { fn new(config: &Config, device: &Device) -> Self {
Self { Self {
masks: Arc::new(Mutex::new(HashMap::new())), masks: Arc::new(Mutex::new(HashMap::new())),
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
device: device.clone(), device: device.clone(),
} }
} }
@ -249,7 +253,6 @@ struct CausalSelfAttention {
c_attn: Linear, c_attn: Linear,
c_proj: Linear, c_proj: Linear,
n_head: usize, n_head: usize,
// n_embd: usize,
cache: Cache, cache: Cache,
} }
@ -265,6 +268,11 @@ impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let mut dims = x.dims().to_vec(); let mut dims = x.dims().to_vec();
let freqs_cis = if dims[1] < CONTEXT_SIZE {
freqs_cis.narrow(1, CONTEXT_SIZE - dims[1], dims[1])?
} else {
freqs_cis.clone()
};
let v = dims.pop().unwrap(); let v = dims.pop().unwrap();
dims.push(v / 2); dims.push(v / 2);
dims.push(2); dims.push(2);
@ -285,7 +293,7 @@ impl CausalSelfAttention {
Ok(rope) Ok(rope)
} }
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let (t, c) = x.shape().r2()?; let (t, c) = x.shape().r2()?;
let qkv = self.c_attn.forward(x)?; let qkv = self.c_attn.forward(x)?;
let qkv = qkv.to_dtype(DType::F32)?; let qkv = qkv.to_dtype(DType::F32)?;
@ -296,9 +304,31 @@ impl CausalSelfAttention {
let target_dim = [t, self.n_head, c / self.n_head]; let target_dim = [t, self.n_head, c / self.n_head];
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?; let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?; let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?; let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
let q = self.apply_rotary_emb(&q, freqs_cis)?; let q = self.apply_rotary_emb(&q, freqs_cis)?;
let k = self.apply_rotary_emb(&k, freqs_cis)?; let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
if USE_KV_CACHE {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?;
v = Tensor::cat(&[cache_v, &v], 1)?;
let k_seq_len = k.dims()[1];
if k_seq_len > CONTEXT_SIZE {
k = k
.narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
.contiguous()?
}
let v_seq_len = v.dims()[1];
if v_seq_len > CONTEXT_SIZE {
v = v
.narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
.contiguous()?
}
}
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k_shape = k.shape(); let k_shape = k.shape();
let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?; let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
@ -330,8 +360,11 @@ impl Block {
} }
} }
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?; let x = (self
.attn
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
+ x)?;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
Ok(x) Ok(x)
} }
@ -358,8 +391,8 @@ impl Llama {
// TODO: Support for mini-batches? (i.e. r2) // TODO: Support for mini-batches? (i.e. r2)
let t = x.shape().r1()?; let t = x.shape().r1()?;
let mut x = self.wte.forward(x)?; let mut x = self.wte.forward(x)?;
for block in self.blocks.iter() { for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, freqs_cis)?; x = block.forward(&x, freqs_cis, block_idx)?;
} }
let x = self.ln_f.forward(&x)?; let x = self.ln_f.forward(&x)?;
let x = x.narrow(0, t - 1, 1)?; let x = x.narrow(0, t - 1, 1)?;
@ -400,7 +433,7 @@ struct Args {
/// Use npy instead of safetensors /// Use npy instead of safetensors
#[arg(long)] #[arg(long)]
npy: bool, npy: Option<String>,
/// The temperature used to generate samples. /// The temperature used to generate samples.
#[arg(long)] #[arg(long)]
@ -426,15 +459,16 @@ async fn main() -> Result<()> {
Device::new_cuda(0)? Device::new_cuda(0)?
}; };
let config = Config::config_7b(); let config = Config::config_7b();
let cache = Cache::new(&device); let cache = Cache::new(&config, &device);
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let (llama, tokenizer_filename) = if args.npy { let (llama, tokenizer_filename) = match args.npy {
Some(npy) => {
println!("building the model (NPY)"); println!("building the model (NPY)");
( let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?, let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
std::path::Path::new("llama-tokenizer.json").to_path_buf(), (weights, token_path)
) }
} else { None => {
let api = Api::new()?; let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
println!("building the model"); println!("building the model");
@ -453,6 +487,7 @@ async fn main() -> Result<()> {
Llama::load(&device, &filenames, &cache, &config)?, Llama::load(&device, &filenames, &cache, &config)?,
tokenizer_filename, tokenizer_filename,
) )
}
}; };
println!("Loaded in {:?}", start.elapsed()); println!("Loaded in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -470,7 +505,12 @@ async fn main() -> Result<()> {
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0..args.sample_len { for index in 0..args.sample_len {
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..]; let context_size = if USE_KV_CACHE && index > 0 {
1
} else {
CONTEXT_SIZE
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?; let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?; let logits = llama.forward(&input, &freqs_cis)?;

View File

@ -10,6 +10,14 @@ pub enum Error {
got: DType, got: DType,
}, },
#[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
NarrowInvalidArgs {
shape: Shape,
dim: usize,
start: usize,
len: usize,
},
#[error("{op} only supports contiguous tensors")] #[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str }, RequiresContiguous { op: &'static str },

View File

@ -349,22 +349,35 @@ impl Tensor {
} }
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + length`. /// ranges from `start` to `start + len`.
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> { pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
if dim >= dims.len() || start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
})?
}
if start == 0 && dims[dim] == len {
Ok(self.clone())
} else {
let op = if self.track_op() { let op = if self.track_op() {
Some(Op::Narrow(self.clone(), dim, start, length)) Some(Op::Narrow(self.clone(), dim, start, len))
} else { } else {
None None
}; };
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
storage: self.storage.clone(), storage: self.storage.clone(),
layout: self.layout().narrow(dim, start, length)?, layout: self.layout().narrow(dim, start, len)?,
op, op,
is_variable: false, is_variable: false,
}; };
Ok(Tensor(Arc::new(tensor_))) Ok(Tensor(Arc::new(tensor_)))
} }
}
pub fn softmax(&self, dim: usize) -> Result<Self> { pub fn softmax(&self, dim: usize) -> Result<Self> {
// TODO: unify the two branches. // TODO: unify the two branches.