mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Handle contiguity + bugfix + use in mimi.
This commit is contained in:
@@ -127,7 +127,7 @@ pub struct StreamingMultiheadAttention {
|
||||
context: usize,
|
||||
neg_inf: Tensor,
|
||||
rope: Option<Arc<RotaryEmbedding>>,
|
||||
- kv_cache: candle_nn::kv_cache::KvCache,
|
||||
+ kv_cache: candle_nn::kv_cache::RotatingKvCache,
|
||||
pos: usize,
|
||||
use_flash_attn: bool,
|
||||
span: tracing::Span,
|
||||
@@ -153,7 +153,7 @@ impl StreamingMultiheadAttention {
|
||||
num_heads: cfg.num_heads,
|
||||
context: cfg.context,
|
||||
neg_inf,
|
||||
- kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
|
||||
+ kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
|
||||
pos: 0,
|
||||
use_flash_attn: false,
|
||||
span: tracing::span!(tracing::Level::TRACE, "mha"),
|
||||
@@ -236,7 +236,7 @@ impl StreamingMultiheadAttention {
|
||||
self.kv_cache.reset()
|
||||
}
|
||||
|
||||
- pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
|
||||
+ pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
|
||||
self.kv_cache = kv_cache
|
||||
}
|
||||
}
|
||||
@@ -582,7 +582,7 @@ impl StreamingTransformerLayer {
|
||||
self.self_attn.reset_kv_cache()
|
||||
}
|
||||
|
||||
- pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
|
||||
+ pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
|
||||
self.self_attn.set_kv_cache(kv_cache)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user