Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 18:28:24 +00:00)
Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.
* Add some KvCache tests.
* Test the reset too.
* More kv-cache testing.
* More tests for the rotating kv-cache.
* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.
* Handle contiguity + bugfix + use in mimi.
* Add a way to test the mimi streaming mode.
* Mimi streaming fixes.
* More rotating kv-cache.
* Fix the attn mask generation.
* Handle the abs case.
* Add some tests for the generated mask.
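
The new cache is a ring buffer: keys and values are appended until the configured context length is reached, after which the oldest entries are overwritten in place. The toy sketch below illustrates just that mechanism with plain vectors; ToyRotatingCache and its methods are made-up names for illustration, not part of the candle_nn API.

// A minimal ring-buffer sketch of the idea behind a rotating KV cache.
// Illustration only: names and layout are assumptions, not candle_nn code.
struct ToyRotatingCache<T> {
    buf: Vec<T>,     // at most `capacity` entries
    capacity: usize, // the context length
    offset: usize,   // next write slot, wraps modulo `capacity`
    seq_len: usize,  // total entries ever pushed (the absolute position)
}

impl<T> ToyRotatingCache<T> {
    fn new(capacity: usize) -> Self {
        Self { buf: Vec::with_capacity(capacity), capacity, offset: 0, seq_len: 0 }
    }

    fn push(&mut self, value: T) {
        if self.buf.len() < self.capacity {
            self.buf.push(value);
        } else {
            // Once full, overwrite the oldest entry in place.
            self.buf[self.offset] = value;
        }
        self.offset = (self.offset + 1) % self.capacity;
        self.seq_len += 1;
    }
}

fn main() {
    let mut cache = ToyRotatingCache::new(3);
    (0..5).for_each(|v| cache.push(v));
    // The last three values survive, but not in chronological order:
    assert_eq!(cache.buf, vec![3, 4, 2]);
    assert_eq!(cache.seq_len, 5);
}

In a layout like this the attention mask can no longer be read off the buffer index alone, which is presumably why mask generation moves into the cache itself in this commit ("Fix the attn mask generation").
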
@@ -101,21 +101,6 @@ impl Module for LayerScale {
     }
 }
 
-pub(crate) fn get_mask(
-    size1: usize,
-    size2: usize,
-    context: usize,
-    device: &Device,
-) -> Result<Tensor> {
-    let mask: Vec<_> = (0..size1)
-        .flat_map(|i| {
-            (0..size2)
-                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
-        })
-        .collect();
-    Tensor::from_slice(&mask, (size1, size2), device)
-}
-
 #[derive(Debug, Clone)]
 pub struct StreamingMultiheadAttention {
     q_proj: Linear,
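
For reference, the deleted helper produced a banded causal mask: entry (i, j) is 1 (meaning "masked out") when key j lies in the future of query i, or more than `context` positions behind it. A small self-contained reproduction of that formula, with hand-picked sizes purely for illustration:

// The exact predicate from the removed get_mask, evaluated for
// size1 = 2 queries, size2 = 5 keys and context = 2.
fn main() {
    let (size1, size2, context) = (2usize, 5usize, 2usize);
    let mask: Vec<Vec<u8>> = (0..size1)
        .map(|i| {
            (0..size2)
                .map(|j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
                .collect()
        })
        .collect();
    // Query i may attend to keys i + (size2 - size1) - context ..= i + (size2 - size1),
    // i.e. a causal band of width context + 1.
    assert_eq!(mask, vec![vec![1, 0, 0, 0, 1], vec![1, 1, 0, 0, 0]]);
}

This formulation works when the cached keys sit in chronological order; the rotating cache replaces it with a mask computed by the cache itself (see the forward_ca change below).
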
@@ -127,7 +112,7 @@ pub struct StreamingMultiheadAttention {
     context: usize,
     neg_inf: Tensor,
     rope: Option<Arc<RotaryEmbedding>>,
-    kv_cache: candle_nn::kv_cache::KvCache,
+    kv_cache: candle_nn::kv_cache::RotatingKvCache,
     pos: usize,
     use_flash_attn: bool,
     span: tracing::Span,
@@ -153,7 +138,7 @@ impl StreamingMultiheadAttention {
             num_heads: cfg.num_heads,
             context: cfg.context,
             neg_inf,
-            kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
+            kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
             pos: 0,
             use_flash_attn: false,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
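
The cache is now sized by cfg.context rather than cfg.max_seq_len, so memory stays bounded however long the stream runs. A rough usage sketch of the new cache type follows; it assumes RotatingKvCache keeps the same append / current_seq_len / reset surface this diff relies on, and the shapes and the 4-token context are made up for the example:

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::RotatingKvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Concatenate along dim 2, the sequence axis of (b, h, t, d) tensors,
    // and keep at most 4 positions (mirroring `new(2, cfg.context)` above).
    let mut cache = RotatingKvCache::new(2, 4);
    for step in 0..6 {
        // One decoding step appends a single key/value pair.
        let k = Tensor::zeros((1, 1, 1, 8), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 1, 1, 8), DType::F32, &dev)?;
        let (k_all, _v_all) = cache.append(&k, &v)?;
        // Once the cache wraps, the returned length should stay at the context size (4 here).
        println!("step {step}: cached keys = {}", k_all.dim(2)?);
    }
    // current_seq_len() is the absolute position, which forward_ca reads below.
    println!("positions seen: {}", cache.current_seq_len());
    cache.reset();
    Ok(())
}

Bounding the cache by cfg.context also explains the later hunks that drop the separate context field from StreamingTransformer: the limit now lives inside the cache.
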
@@ -236,7 +221,7 @@ impl StreamingMultiheadAttention {
         self.kv_cache.reset()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.kv_cache = kv_cache
     }
 }
@@ -582,7 +567,7 @@ impl StreamingTransformerLayer {
         self.self_attn.reset_kv_cache()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.self_attn.set_kv_cache(kv_cache)
     }
 }
@@ -590,7 +575,6 @@ impl StreamingTransformerLayer {
 #[derive(Debug, Clone)]
 pub struct StreamingTransformer {
     layers: Vec<StreamingTransformerLayer>,
-    context: usize,
     positional_embedding: PositionalEmbedding,
     max_period: usize,
 }
@@ -617,7 +601,6 @@ impl StreamingTransformer {
         }
         Ok(Self {
             layers,
-            context: cfg.context,
             positional_embedding: cfg.positional_embedding,
             max_period: cfg.max_period,
         })
@@ -629,19 +612,11 @@ impl StreamingTransformer {
 
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        // We will extract at most "context" from the kv_cache.
-        // Note that the mask will discard the values that are before context.
-        let pos = self.layers[0]
+        let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
+        let mask = self.layers[0]
             .self_attn
             .kv_cache
-            .k_cache()
-            .current_seq_len()
-            .min(self.context);
-        let mask = if t == 1 {
-            None
-        } else {
-            Some(get_mask(t, pos + t, self.context, xs.device())?)
-        };
+            .attn_mask(t, xs.device())?;
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
             PositionalEmbedding::Sin => {
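
With the rotation handled inside the cache, forward_ca no longer builds the mask itself: it asks the cache for one sized to the current query length t and, as before, an attention layer only has to neutralize the masked scores before the softmax. The helper below shows a common candle pattern for applying such an optional 0/1 mask with a negative-infinity fill, much like the neg_inf tensor this attention struct already carries; it is a generic sketch, not the body of StreamingMultiheadAttention:

use candle::{DType, Device, Result, Tensor, D};

// Replace masked positions (mask == 1) with -inf before the softmax.
fn apply_attn_mask(scores: &Tensor, mask: Option<&Tensor>, neg_inf: &Tensor) -> Result<Tensor> {
    match mask {
        None => Ok(scores.clone()),
        Some(mask) => {
            let mask = mask.broadcast_as(scores.shape())?;
            let neg_inf = neg_inf.broadcast_as(scores.shape())?;
            mask.where_cond(&neg_inf, scores)
        }
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 2 queries attending over 5 cached keys, batch and head dims of 1.
    let scores = Tensor::zeros((1, 1, 2, 5), DType::F32, &dev)?;
    // The same banded mask as in the worked get_mask example above.
    let mask = Tensor::new(&[[1u8, 0, 0, 0, 1], [1u8, 1, 0, 0, 0]], &dev)?;
    let neg_inf = Tensor::new(f32::NEG_INFINITY, &dev)?;
    let masked = apply_attn_mask(&scores, Some(&mask), &neg_inf)?;
    let probs = candle_nn::ops::softmax(&masked, D::Minus1)?;
    println!("{probs}");
    Ok(())
}
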