mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Handle contiguity + bugfix + use in mimi.
This commit is contained in:
@ -224,23 +224,27 @@ impl RotatingCache {
|
||||
|
||||
self.current_seq_len += seq_len;
|
||||
if seq_len >= self.max_seq_len {
|
||||
let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?;
|
||||
ad.slice_set(&src, self.dim, 0)?;
|
||||
let to_copy = src
|
||||
.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
|
||||
.contiguous()?;
|
||||
ad.slice_set(&to_copy, self.dim, 0)?;
|
||||
self.offset = 0;
|
||||
// Here we return `src` rather than `ad` so that all the past can be used.
|
||||
Ok(src)
|
||||
Ok(src.clone())
|
||||
} else {
|
||||
let rem_len = self.max_seq_len - self.offset;
|
||||
if seq_len <= rem_len {
|
||||
ad.slice_set(src, self.dim, self.offset)?;
|
||||
ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
|
||||
self.offset = (self.offset + seq_len) % self.max_seq_len;
|
||||
} else {
|
||||
// We have to make two copies here as we go over the boundary of the cache.
|
||||
if rem_len > 0 {
|
||||
let src1 = src.narrow(self.dim, 0, rem_len)?;
|
||||
let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
|
||||
ad.slice_set(&src1, self.dim, self.offset)?;
|
||||
}
|
||||
let src2 = src.narrow(self.dim, rem_len, seq_len - rem_len)?;
|
||||
let src2 = src
|
||||
.narrow(self.dim, rem_len, seq_len - rem_len)?
|
||||
.contiguous()?;
|
||||
ad.slice_set(&src2, self.dim, 0)?;
|
||||
self.offset = seq_len - rem_len;
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ fn rotating_kv_cache() -> Result<()> {
|
||||
|
||||
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [3., 4., 5., 6., 7., 8.]);
|
||||
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
assert_eq!(cache.current_seq_len(), 22);
|
||||
assert_eq!(cache.offset(), 0);
|
||||
|
||||
|
@ -127,7 +127,7 @@ pub struct StreamingMultiheadAttention {
|
||||
context: usize,
|
||||
neg_inf: Tensor,
|
||||
rope: Option<Arc<RotaryEmbedding>>,
|
||||
kv_cache: candle_nn::kv_cache::KvCache,
|
||||
kv_cache: candle_nn::kv_cache::RotatingKvCache,
|
||||
pos: usize,
|
||||
use_flash_attn: bool,
|
||||
span: tracing::Span,
|
||||
@ -153,7 +153,7 @@ impl StreamingMultiheadAttention {
|
||||
num_heads: cfg.num_heads,
|
||||
context: cfg.context,
|
||||
neg_inf,
|
||||
kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
|
||||
kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
|
||||
pos: 0,
|
||||
use_flash_attn: false,
|
||||
span: tracing::span!(tracing::Level::TRACE, "mha"),
|
||||
@ -236,7 +236,7 @@ impl StreamingMultiheadAttention {
|
||||
self.kv_cache.reset()
|
||||
}
|
||||
|
||||
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
|
||||
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
|
||||
self.kv_cache = kv_cache
|
||||
}
|
||||
}
|
||||
@ -582,7 +582,7 @@ impl StreamingTransformerLayer {
|
||||
self.self_attn.reset_kv_cache()
|
||||
}
|
||||
|
||||
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
|
||||
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
|
||||
self.self_attn.set_kv_cache(kv_cache)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user