mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Simplify the KvCache api. (#2207)
This commit is contained in:
@ -203,7 +203,6 @@ fn precomput_freqs_cis(
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
batch_size: usize,
|
||||
use_flash_attn: bool,
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
@ -252,12 +251,7 @@ impl ModelWeights {
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let kv_cache = KvCache::new(
|
||||
2,
|
||||
(batch_size, head_count_kv, max_seq_len, head_dim),
|
||||
DType::F32,
|
||||
device,
|
||||
)?;
|
||||
let kv_cache = KvCache::new(2, max_seq_len);
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
|
Reference in New Issue
Block a user