Fixed Quantized Qwen3 Model (#2951)

* optimize KV cache to reduce GPU memory usage

* revert to using candle_nn::kv_cache::KvCache with initial capacity of 512
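Why the smaller initial capacity matters: sizing the cache to the model's full context length reserves memory for tokens a typical session never generates. The sketch below is a rough back-of-the-envelope comparison using hypothetical Qwen3-like dimensions; the layer count, KV-head count, head size, and 40960-token context are illustrative assumptions, not values taken from this diff.

```rust
// Back-of-the-envelope KV cache sizing. All model dimensions below are
// hypothetical stand-ins, not read from the GGUF metadata.
fn kv_bytes(capacity_tokens: usize) -> usize {
    let layers = 36; // hypothetical layer count
    let kv_heads = 8; // hypothetical GQA KV-head count
    let head_dim = 128; // hypothetical head dimension
    let elem_bytes = 2; // f16/bf16 element size
    // K and V each hold kv_heads * head_dim elements per token per layer.
    2 * layers * kv_heads * head_dim * elem_bytes * capacity_tokens
}

fn main() {
    // Capacity sized to a hypothetical 40960-token context vs. the
    // 512-token initial capacity this commit switches to.
    println!("full context: {} MiB", kv_bytes(40_960) >> 20); // ~5760 MiB
    println!("512 tokens:   {} MiB", kv_bytes(512) >> 20); // 72 MiB
}
```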
Author: Snake
Date: 2025-05-13 11:53:42 +08:00
Committed by: GitHub
Parent: 36508a2c93
Commit: 485ddf2996

@@ -160,12 +160,9 @@ impl AttentionWeights {
 let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
 let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
-let max_position_embeddings = gg
-    .metadata()
-    .get("qwen3.context_length")
-    .and_then(|v| v.to_u32().ok())
-    .unwrap_or(4096) as usize;
-let kv_cache = KvCache::new(2, max_position_embeddings);
+// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
+// The cache will grow in chunks of 512 tokens when needed.
+let kv_cache = KvCache::new(2, 512);
 let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
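For context, a minimal sketch of the usage pattern this code relies on, assuming the candle_nn::kv_cache::KvCache API at the time (KvCache::new(concat_dim, max_seq_len) and append); the batch, head, and head-dim sizes are illustrative, not taken from the model.

```rust
// Minimal sketch of the KvCache usage pattern, assuming candle's
// KvCache::new(concat_dim, max_seq_len) / append API; shapes are illustrative.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Dim 2 is the sequence axis of [batch, kv_heads, seq, head_dim];
    // 512 is the initial token capacity, matching the commit.
    let mut cache = KvCache::new(2, 512);

    // One decode step appends a single token's K/V and gets back
    // everything cached so far.
    let k = Tensor::zeros((1, 8, 1, 128), DType::F32, &dev)?;
    let v = Tensor::zeros((1, 8, 1, 128), DType::F32, &dev)?;
    let (k_all, v_all) = cache.append(&k, &v)?;
    assert_eq!(k_all.dims(), &[1, 8, 1, 128]);
    assert_eq!(v_all.dims(), &[1, 8, 1, 128]);
    Ok(())
}
```

Per the new comment in the diff, an append that would overflow the current capacity grows the cache by further 512-token chunks rather than failing, so the smaller initial allocation trades an occasional reallocation for a much lower resident footprint.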