Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 03:28:50 +00:00
More rotating kv-cache.
@@ -216,16 +216,6 @@ impl StreamingMultiheadAttention {
         let pre_ws = match mask {
             None => pre_ws,
             Some(mask) => {
-                // This is a bit cumbersome and slightly incorrect: when providing a new slice
-                // the kv cache will have a slice offset rather than offset + t. In the mimi
-                // context of an offset of 250, this would not make much difference though.
-                let mask_len = mask.dim(D::Minus1)?;
-                let pre_ws_len = pre_ws.dim(D::Minus1)?;
-                let mask = if pre_ws_len < mask_len {
-                    mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
-                } else {
-                    mask.clone()
-                };
                 let mask = mask.broadcast_left((b, self.num_heads))?;
                 let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
                 mask.where_cond(&neg_inf, &pre_ws)?
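For readers following the first hunk: the deleted branch realigned a too-wide mask with the attention weights by keeping only its trailing columns (the most recent key positions). A minimal sketch of that narrowing using plain slices instead of candle tensors; the helper name and row-major layout are assumptions for illustration:

// Hypothetical sketch of the narrowing removed above: when the attention
// weights span fewer key positions than the precomputed mask, keep the
// mask's last `pre_ws_len` columns so the most recent positions line up.
// `mask` is a (t, mask_len) boolean matrix flattened row-major.
fn narrow_mask(mask: &[bool], t: usize, mask_len: usize, pre_ws_len: usize) -> Vec<bool> {
    assert!(pre_ws_len <= mask_len && mask.len() == t * mask_len);
    let start = mask_len - pre_ws_len; // drop the oldest columns
    (0..t)
        .flat_map(|i| mask[i * mask_len + start..(i + 1) * mask_len].iter().copied())
        .collect()
}

fn main() {
    // 2x4 mask, keep the last 3 columns of each row.
    let mask = [false, true, true, true, false, false, true, true];
    println!("{:?}", narrow_mask(&mask, 2, 4, 3)); // [true, true, true, false, true, true]
}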
@@ -639,18 +629,22 @@ impl StreamingTransformer {
 
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        // We will extract at most "context" from the kv_cache.
-        // Note that the mask will discard the values that are before context.
         let pos = self.layers[0]
             .self_attn
             .kv_cache
             .k_cache()
-            .current_seq_len()
-            .min(self.context);
+            .current_seq_len();
         let mask = if t == 1 {
             None
         } else {
-            Some(get_mask(t, pos + t, self.context, xs.device())?)
+            let cache_out_len = if t < self.context {
+                (pos + t).min(self.context)
+            } else {
+                t
+            };
+            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
+            // nature.
+            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
         };
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
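The second hunk calls `get_mask(t, cache_out_len, self.context, xs.device())` but does not show `get_mask` itself. Below is a sketch of the sliding-window causal rule the call sites suggest; the function name `build_mask` and the exact inequalities are assumptions, not candle's implementation:

// Hypothetical sketch of the windowed causal mask implied by the
// `get_mask(t, cache_out_len, context, device)` call sites: entry (i, j)
// is masked when key j lies in query i's future, or more than `context`
// steps in its past.
fn build_mask(t: usize, cache_out_len: usize, context: usize) -> Vec<Vec<bool>> {
    assert!(t <= cache_out_len);
    (0..t)
        .map(|i| {
            // Align the last query with the last cache slot.
            let q = cache_out_len - t + i;
            (0..cache_out_len)
                .map(|j| j > q || q >= j + context)
                .collect()
        })
        .collect()
}

fn main() {
    // t = 4 new steps, 6 cached positions, context window of 3.
    for row in build_mask(4, 6, 3) {
        let line: String = row.iter().map(|&m| if m { 'x' } else { '.' }).collect();
        println!("{line}"); // 'x' = masked out, '.' = attended
    }
}

Each query row ends up with a band of exactly `context` attended keys, shifting right by one per step.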
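The TODO in the second hunk notes the mask is still wrong for a rotating cache: once the ring buffer wraps, cache column j no longer means "j steps ago". A tiny sketch of that slot permutation, assuming the standard `step % capacity` ring-buffer layout (the helper name is hypothetical):

// Hypothetical sketch of why the TODO matters: a rotating kv-cache of
// capacity `cap` stores the key for absolute step `s` in slot `s % cap`,
// so after wrap-around the oldest live entry sits mid-buffer and the
// cache columns are a rotation of temporal order.
fn slot_for_step(step: usize, cap: usize) -> usize {
    step % cap
}

fn main() {
    let cap = 4;
    // After 6 steps, steps 2..=5 survive, stored in slots [2, 3, 0, 1]:
    // column order no longer matches temporal order, so a mask indexed
    // by raw column must be rotated by the cache offset as well.
    let slots: Vec<usize> = (2..6).map(|s| slot_for_step(s, cap)).collect();
    println!("{slots:?}"); // [2, 3, 0, 1]
}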