mirror of https://github.com/huggingface/candle.git
More rotating kv-cache.
@@ -216,16 +216,6 @@ impl StreamingMultiheadAttention {
         let pre_ws = match mask {
             None => pre_ws,
             Some(mask) => {
-                // This is a bit cumbersome and slightly incorrect: when providing a new slice
-                // the kv cache will have a slice offset rather than offset + t. In the mimi
-                // context of an offset of 250, this would not make much difference though.
-                let mask_len = mask.dim(D::Minus1)?;
-                let pre_ws_len = pre_ws.dim(D::Minus1)?;
-                let mask = if pre_ws_len < mask_len {
-                    mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
-                } else {
-                    mask.clone()
-                };
                 let mask = mask.broadcast_left((b, self.num_heads))?;
                 let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                 mask.where_cond(&neg_inf, &pre_ws)?
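After this hunk, the mask path is pure broadcast-and-select: the mask is widened over the batch and head dimensions, then where_cond swaps masked logits for -inf before the softmax. The sketch below reproduces that path as a free function on candle_core tensors; apply_mask, the (b, h, t, k) shapes, and the demo values are illustrative assumptions, not the actual StreamingMultiheadAttention code.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical free-function version of the surviving mask path; a 1 in
// `mask` marks a position whose logit should be discarded.
fn apply_mask(pre_ws: &Tensor, mask: &Tensor, b: usize, num_heads: usize) -> Result<Tensor> {
    // Widen (t, k) -> (b, num_heads, t, k) so the mask lines up with the
    // per-head attention weights.
    let mask = mask.broadcast_left((b, num_heads))?;
    // Replace masked logits with -inf so they vanish after the softmax.
    let neg_inf = Tensor::new(f32::NEG_INFINITY, pre_ws.device())?.broadcast_as(pre_ws.shape())?;
    mask.where_cond(&neg_inf, pre_ws)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, k) = (1, 2, 3, 3);
    let pre_ws = Tensor::zeros((b, h, t, k), DType::F32, &dev)?;
    // Causal mask: query i may not look at key j > i.
    let mask = Tensor::new(&[[0u8, 1, 1], [0, 0, 1], [0, 0, 0]], &dev)?;
    println!("{}", apply_mask(&pre_ws, &mask, b, h)?);
    Ok(())
}

The deleted narrowing step becomes redundant presumably because the caller (see the next hunk) now sizes the mask to the attention weights up front.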
@@ -639,18 +629,22 @@ impl StreamingTransformer {
 
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        // We will extract at most "context" from the kv_cache.
-        // Note that the mask will discard the values that are before context.
         let pos = self.layers[0]
             .self_attn
             .kv_cache
             .k_cache()
-            .current_seq_len()
-            .min(self.context);
+            .current_seq_len();
         let mask = if t == 1 {
             None
         } else {
-            Some(get_mask(t, pos + t, self.context, xs.device())?)
+            let cache_out_len = if t < self.context {
+                (pos + t).min(self.context)
+            } else {
+                t
+            };
+            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
+            // nature.
+            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
         };
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
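The replacement mask width is easy to check in isolation. Below is a minimal sketch that lifts the cache_out_len computation into a free function over plain integers; the function name mirrors the diff but is otherwise hypothetical, and 250 is the mimi context mentioned in the comment removed by the first hunk.

// Hypothetical standalone version of the `cache_out_len` computation above.
fn cache_out_len(pos: usize, t: usize, context: usize) -> usize {
    if t < context {
        // The rotating cache holds at most `context` positions, so the
        // attention weights can never be wider than that.
        (pos + t).min(context)
    } else {
        // A chunk at least as long as the context attends within itself.
        t
    }
}

fn main() {
    // Steady state: a 10-step chunk at offset 248 with a 250-wide context.
    assert_eq!(cache_out_len(248, 10, 250), 250);
    // Warm-up: the mask only needs to cover what is actually cached.
    assert_eq!(cache_out_len(0, 10, 250), 10);
    // Oversized chunk: fall back to the chunk length itself.
    assert_eq!(cache_out_len(0, 300, 250), 300);
}

As the TODO in the hunk notes, this is still an approximation: once the rotating cache wraps, the valid positions depend on the write offset, not just on how many entries have been seen.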