From 485ddf2996169f49e3da3af22f3a56188678cf43 Mon Sep 17 00:00:00 2001
From: Snake <47769817+nosnakeob@users.noreply.github.com>
Date: Tue, 13 May 2025 11:53:42 +0800
Subject: [PATCH] Fixed Quantized Qwen3 Model (#2951)

* optimize KV cache to reduce GPU memory usage

* revert to using candle_nn::kv_cache::KvCache with initial capacity of 512
---
 candle-transformers/src/models/quantized_qwen3.rs | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs
index 34dba8cd..00f7c03d 100644
--- a/candle-transformers/src/models/quantized_qwen3.rs
+++ b/candle-transformers/src/models/quantized_qwen3.rs
@@ -160,12 +160,9 @@ impl AttentionWeights {
         let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
         let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
 
-        let max_position_embeddings = gg
-            .metadata()
-            .get("qwen3.context_length")
-            .and_then(|v| v.to_u32().ok())
-            .unwrap_or(4096) as usize;
-        let kv_cache = KvCache::new(2, max_position_embeddings);
+        // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
+        // The cache will grow in chunks of 512 tokens when needed.
+        let kv_cache = KvCache::new(2, 512);
 
         let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
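
Not part of the patch: a minimal sketch of how the `candle_nn::kv_cache::KvCache` constructed above is typically driven during decoding, to illustrate why the smaller initial capacity is safe. It assumes the external crate names `candle_core`/`candle_nn` and hypothetical tensor sizes (batch 1, 8 KV heads, head_dim 128); `KvCache::new(2, 512)` concatenates along dimension 2, the sequence axis of a `[batch, heads, seq, head_dim]` tensor, starting from a 512-token capacity rather than the model's full context length.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Dim 2 is the sequence axis; 512 is the initial capacity in tokens,
    // so only ~512 tokens' worth of K/V storage is allocated up front.
    let mut cache = KvCache::new(2, 512);

    // One decoding step: a single new token's key/value with hypothetical
    // shape [batch=1, kv_heads=8, seq=1, head_dim=128].
    let k = Tensor::zeros((1, 8, 1, 128), DType::F32, &dev)?;
    let v = Tensor::zeros((1, 8, 1, 128), DType::F32, &dev)?;

    // `append` stores the new entries and returns the keys/values
    // accumulated so far, ready to be fed into attention.
    let (k_all, v_all) = cache.append(&k, &v)?;
    println!("cached seq len: {}", k_all.dim(2)?);
    let _ = v_all;
    Ok(())
}
```

Per the comment added in the patch, the cache grows in 512-token chunks once a sequence exceeds the initial capacity, so long prompts still work; the change only avoids eagerly allocating `qwen3.context_length` (up to the model's full context window) per attention layer.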