mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add some caching to the causal mask. (#103)
This commit is contained in:
@ -282,6 +282,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
|
||||
/// Rotary-embedding state: an inverse-frequency tensor plus a single-entry
/// cache holding the most recently computed cos/sin tensors.
#[derive(Debug)]
|
||||
struct FalconRotaryEmbedding {
|
||||
/// Inverse frequencies, built by the constructor from values of the form
/// `1 / 10000^(i / head_dim)` and placed on the variable builder's device.
inv_freq: Tensor,
|
||||
/// Last `cos_sin` result stored as `(seq_len, cos, sin)`; initialised to
/// `None` and returned (cloned) when a later call uses the same `seq_len`.
/// NOTE(review): the visible lookup compares `seq_len` only, not the
/// `device`/`dtype` arguments — confirm callers never vary those.
cache: Option<(usize, Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl FalconRotaryEmbedding {
|
||||
@ -292,7 +293,8 @@ impl FalconRotaryEmbedding {
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?;
|
||||
Ok(Self { inv_freq })
|
||||
let cache = None;
|
||||
Ok(Self { inv_freq, cache })
|
||||
}
|
||||
|
||||
fn cos_sin(
|
||||
@ -301,7 +303,12 @@ impl FalconRotaryEmbedding {
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
// TODO: Add the cache.
|
||||
match &self.cache {
|
||||
Some((s, cos, sin)) if *s == seq_len => {
|
||||
return Ok((cos.clone(), sin.clone()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect();
|
||||
let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?;
|
||||
let inv_freq = self.inv_freq.to_dtype(dtype)?;
|
||||
@ -309,6 +316,7 @@ impl FalconRotaryEmbedding {
|
||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
let cos = emb.cos()?;
|
||||
let sin = emb.sin()?;
|
||||
self.cache = Some((seq_len, cos.clone(), sin.clone()));
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user