mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add some caching to the causal mask. (#103)
This commit is contained in:
@ -282,6 +282,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct FalconRotaryEmbedding {
|
struct FalconRotaryEmbedding {
|
||||||
inv_freq: Tensor,
|
inv_freq: Tensor,
|
||||||
|
cache: Option<(usize, Tensor, Tensor)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FalconRotaryEmbedding {
|
impl FalconRotaryEmbedding {
|
||||||
@ -292,7 +293,8 @@ impl FalconRotaryEmbedding {
|
|||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?;
|
let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?;
|
||||||
Ok(Self { inv_freq })
|
let cache = None;
|
||||||
|
Ok(Self { inv_freq, cache })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cos_sin(
|
fn cos_sin(
|
||||||
@ -301,7 +303,12 @@ impl FalconRotaryEmbedding {
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
// TODO: Add the cache.
|
match &self.cache {
|
||||||
|
Some((s, cos, sin)) if *s == seq_len => {
|
||||||
|
return Ok((cos.clone(), sin.clone()));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect();
|
let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect();
|
||||||
let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?;
|
let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?;
|
||||||
let inv_freq = self.inv_freq.to_dtype(dtype)?;
|
let inv_freq = self.inv_freq.to_dtype(dtype)?;
|
||||||
@ -309,6 +316,7 @@ impl FalconRotaryEmbedding {
|
|||||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||||
let cos = emb.cos()?;
|
let cos = emb.cos()?;
|
||||||
let sin = emb.sin()?;
|
let sin = emb.sin()?;
|
||||||
|
self.cache = Some((seq_len, cos.clone(), sin.clone()));
|
||||||
Ok((cos, sin))
|
Ok((cos, sin))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user