Add some caching to the causal mask. (#103)

This commit is contained in:
Laurent Mazare
2023-07-07 12:56:44 +01:00
committed by GitHub
parent 65937612d0
commit 05ff1cff66

View File

@ -282,6 +282,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
#[derive(Debug)]
struct FalconRotaryEmbedding {
inv_freq: Tensor,
cache: Option<(usize, Tensor, Tensor)>,
}
impl FalconRotaryEmbedding {
@ -292,7 +293,8 @@ impl FalconRotaryEmbedding {
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
.collect();
let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?;
Ok(Self { inv_freq })
let cache = None;
Ok(Self { inv_freq, cache })
}
fn cos_sin(
@ -301,7 +303,12 @@ impl FalconRotaryEmbedding {
device: &Device,
dtype: DType,
) -> Result<(Tensor, Tensor)> {
// TODO: Add the cache.
match &self.cache {
Some((s, cos, sin)) if *s == seq_len => {
return Ok((cos.clone(), sin.clone()));
}
_ => {}
}
let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect();
let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?;
let inv_freq = self.inv_freq.to_dtype(dtype)?;
@ -309,6 +316,7 @@ impl FalconRotaryEmbedding {
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
let cos = emb.cos()?;
let sin = emb.sin()?;
self.cache = Some((seq_len, cos.clone(), sin.clone()));
Ok((cos, sin))
}