Add a function to clear the KV cache in falcon. (#2066)

* Add a function to clear the KV cache in falcon.

* Clippy.
This commit is contained in:
Laurent Mazare
2024-04-15 09:29:25 +02:00
committed by GitHub
parent e198bb0816
commit 8ad822a983
2 changed files with 15 additions and 0 deletions

View File

@ -217,6 +217,7 @@ fn mul_mat_vec_via_q8_1(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,

View File

@ -315,6 +315,10 @@ impl FalconAttention {
let attn_output = self.dense.forward(&attn_output)?;
Ok(attn_output)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug)]
@ -402,6 +406,10 @@ impl FalconDecoderLayer {
let output = (mlp_output + residual)?;
Ok(output)
}
pub fn clear_kv_cache(&mut self) {
self.self_attention.clear_kv_cache()
}
}
#[derive(Debug)]
@ -477,4 +485,10 @@ impl Falcon {
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
Ok(logits)
}
pub fn clear_kv_cache(&mut self) {
for block in self.blocks.iter_mut() {
block.clear_kv_cache()
}
}
}