mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Mimi streaming fixes.
This commit is contained in:
@ -126,7 +126,10 @@ fn main() -> Result<()> {
|
|||||||
for chunk_start in (0..seq_len).step_by(chunk_size) {
|
for chunk_start in (0..seq_len).step_by(chunk_size) {
|
||||||
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
|
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
|
||||||
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
|
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
|
||||||
pcm_chunks.push(model.decode(&codes)?)
|
let pcm = model.decode_step(&codes.into())?;
|
||||||
|
if let Some(pcm) = pcm.as_option() {
|
||||||
|
pcm_chunks.push(pcm.clone())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
|
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
|
||||||
}
|
}
|
||||||
|
@ -216,6 +216,16 @@ impl StreamingMultiheadAttention {
|
|||||||
let pre_ws = match mask {
|
let pre_ws = match mask {
|
||||||
None => pre_ws,
|
None => pre_ws,
|
||||||
Some(mask) => {
|
Some(mask) => {
|
||||||
|
// This is a bit cumbersome and slightly incorrect: when a new slice is provided,
|
||||||
|
// the KV cache will have a slice offset rather than offset + t. In the Mimi
|
||||||
|
// context of an offset of 250, this would not make much difference though.
|
||||||
|
let mask_len = mask.dim(D::Minus1)?;
|
||||||
|
let pre_ws_len = pre_ws.dim(D::Minus1)?;
|
||||||
|
let mask = if pre_ws_len < mask_len {
|
||||||
|
mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
|
||||||
|
} else {
|
||||||
|
mask.clone()
|
||||||
|
};
|
||||||
let mask = mask.broadcast_left((b, self.num_heads))?;
|
let mask = mask.broadcast_left((b, self.num_heads))?;
|
||||||
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
|
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
|
||||||
mask.where_cond(&neg_inf, &pre_ws)?
|
mask.where_cond(&neg_inf, &pre_ws)?
|
||||||
|
Reference in New Issue
Block a user