Simplify the KvCache api. (#2207)

This commit is contained in:
Laurent Mazare
2024-05-23 17:07:21 +02:00
committed by GitHub
parent 31cf64147b
commit 45e235a747
3 changed files with 54 additions and 44 deletions

View File

@ -203,7 +203,6 @@ fn precomput_freqs_cis(
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
batch_size: usize,
use_flash_attn: bool,
ct: gguf_file::Content,
reader: &mut R,
@ -252,12 +251,7 @@ impl ModelWeights {
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let kv_cache = KvCache::new(
2,
(batch_size, head_count_kv, max_seq_len, head_dim),
DType::F32,
device,
)?;
let kv_cache = KvCache::new(2, max_seq_len);
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,