mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a function to clear the KV cache in falcon. (#2066)
* Add a function to clear the KV cache in falcon. * Clippy.
This commit is contained in:
@ -217,6 +217,7 @@ fn mul_mat_vec_via_q8_1(
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn mul_mat_via_q8_1(
|
||||
data: &CudaSlice<u8>,
|
||||
y: &CudaView<f32>,
|
||||
|
@ -315,6 +315,10 @@ impl FalconAttention {
|
||||
let attn_output = self.dense.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -402,6 +406,10 @@ impl FalconDecoderLayer {
|
||||
let output = (mlp_output + residual)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.self_attention.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -477,4 +485,10 @@ impl Falcon {
|
||||
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
|
||||
Ok(logits)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for block in self.blocks.iter_mut() {
|
||||
block.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user