Add a RotatingKVCache. (#2493)

* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the API for the rotating cache so that the whole src tensor gets returned when it's larger than the cache.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.

Author: Laurent Mazare
Date: 2024-09-23 13:14:32 +02:00 (committed via GitHub)
Parent: 8097559c1a
Commit: d01207dbf3
4 changed files with 379 additions and 38 deletions
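
For context on what the new cache provides (a sketch of the general idea, not the candle-nn implementation): a rotating KV cache keeps at most a fixed number of timesteps of keys and values in a ring buffer, overwriting the oldest slot once it is full, so streaming inference runs in bounded memory instead of growing the cache at every step. A minimal illustration of the indexing, using plain values instead of tensors:

// Ring-buffer sketch of the rotating idea; the real
// candle_nn::kv_cache::RotatingKvCache additionally handles tensor
// concatenation, contiguity and attention-mask generation.
struct RotatingBuffer<T> {
    data: Vec<T>,
    max_len: usize,
    total: usize, // number of items ever appended
}

impl<T> RotatingBuffer<T> {
    fn new(max_len: usize) -> Self {
        Self { data: Vec::with_capacity(max_len), max_len, total: 0 }
    }

    fn append(&mut self, item: T) {
        if self.data.len() < self.max_len {
            self.data.push(item); // still filling up
        } else {
            self.data[self.total % self.max_len] = item; // overwrite the oldest slot
        }
        self.total += 1;
    }

    // Number of cached positions attention can currently use.
    fn current_len(&self) -> usize {
        self.data.len()
    }
}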


@@ -101,21 +101,6 @@ impl Module for LayerScale {
     }
 }
 
-pub(crate) fn get_mask(
-    size1: usize,
-    size2: usize,
-    context: usize,
-    device: &Device,
-) -> Result<Tensor> {
-    let mask: Vec<_> = (0..size1)
-        .flat_map(|i| {
-            (0..size2)
-                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
-        })
-        .collect();
-    Tensor::from_slice(&mask, (size1, size2), device)
-}
-
 #[derive(Debug, Clone)]
 pub struct StreamingMultiheadAttention {
     q_proj: Linear,
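
For reference, the get_mask helper removed above encodes the rule that the rotating cache's attn_mask now provides: a query must not see keys in its future, nor keys more than `context` steps in its past (1 = masked). A small self-contained sketch of the same formula, with made-up sizes (t = 2 new queries, pos = 3 previously cached steps so size2 = pos + t = 5, context = 3):

fn main() {
    let (size1, size2, context) = (2usize, 5usize, 3usize);
    for i in 0..size1 {
        let row: Vec<u8> = (0..size2)
            .map(|j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
            .collect();
        println!("{row:?}");
    }
    // Prints:
    // [0, 0, 0, 0, 1]   query at absolute position 3 attends to keys 0..=3
    // [1, 0, 0, 0, 0]   query at absolute position 4 attends to keys 1..=4
}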
@@ -127,7 +112,7 @@ pub struct StreamingMultiheadAttention {
     context: usize,
     neg_inf: Tensor,
     rope: Option<Arc<RotaryEmbedding>>,
-    kv_cache: candle_nn::kv_cache::KvCache,
+    kv_cache: candle_nn::kv_cache::RotatingKvCache,
     pos: usize,
     use_flash_attn: bool,
     span: tracing::Span,
@@ -153,7 +138,7 @@ impl StreamingMultiheadAttention {
             num_heads: cfg.num_heads,
             context: cfg.context,
             neg_inf,
-            kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
+            kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
             pos: 0,
             use_flash_attn: false,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
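
A hedged sketch of driving the new cache on its own: new, reset, current_seq_len and attn_mask all appear in this diff, while the append signature (returning the keys/values accumulated so far, as with the non-rotating KvCache) and the capping behaviour once the cache is full are assumptions here; the import paths use the crate alias (`candle` for candle-core) used inside the candle workspace.

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::RotatingKvCache;

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    // Concatenate along dim 2 (the time axis of (batch, heads, time, head_dim)
    // tensors) and keep at most 4 timesteps, mirroring RotatingKvCache::new(2, cfg.context).
    let mut cache = RotatingKvCache::new(2, 4);
    for step in 0..6 {
        let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        // Assumed: append returns the keys/values to attend over for this step.
        let (ks, vs) = cache.append(&k, &v)?;
        // Once the cache has filled up, the returned length should stay capped at 4.
        println!("step {step}: {} cached k steps, {} cached v steps", ks.dim(2)?, vs.dim(2)?);
        // Mask for this step's attention; None when nothing needs to be masked.
        let _mask = cache.attn_mask(1, &dev)?;
    }
    cache.reset(); // start a fresh stream
    Ok(())
}

The change below follows this pattern at the transformer level: instead of tracking `context` itself, StreamingTransformer asks the first layer's cache for current_seq_len and attn_mask.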
@@ -236,7 +221,7 @@ impl StreamingMultiheadAttention {
         self.kv_cache.reset()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
        self.kv_cache = kv_cache
     }
 }
@@ -582,7 +567,7 @@ impl StreamingTransformerLayer {
         self.self_attn.reset_kv_cache()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.self_attn.set_kv_cache(kv_cache)
     }
 }
@@ -590,7 +575,6 @@ impl StreamingTransformerLayer {
 #[derive(Debug, Clone)]
 pub struct StreamingTransformer {
     layers: Vec<StreamingTransformerLayer>,
-    context: usize,
     positional_embedding: PositionalEmbedding,
     max_period: usize,
 }
@@ -617,7 +601,6 @@ impl StreamingTransformer {
         }
         Ok(Self {
             layers,
-            context: cfg.context,
             positional_embedding: cfg.positional_embedding,
             max_period: cfg.max_period,
         })
@@ -629,19 +612,11 @@
 
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        // We will extract at most "context" from the kv_cache.
-        // Note that the mask will discard the values that are before context.
-        let pos = self.layers[0]
+        let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
+        let mask = self.layers[0]
             .self_attn
             .kv_cache
-            .k_cache()
-            .current_seq_len()
-            .min(self.context);
-        let mask = if t == 1 {
-            None
-        } else {
-            Some(get_mask(t, pos + t, self.context, xs.device())?)
-        };
+            .attn_mask(t, xs.device())?;
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
             PositionalEmbedding::Sin => {