From 6c990a33ea4635bf98b180f6e4c99e6795ccfbab Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 1 Nov 2023 21:04:52 +0100 Subject: [PATCH] Remove the unused pragma for marian. (#1236) --- candle-transformers/src/models/marian.rs | 36 +++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs index ebab3dbc..05804a1c 100644 --- a/candle-transformers/src/models/marian.rs +++ b/candle-transformers/src/models/marian.rs @@ -1,6 +1,5 @@ -#![allow(unused)] -use super::with_tracing::{linear, linear_no_bias, Embedding, Linear}; -use candle::{Module, Result, Tensor}; +use super::with_tracing::{linear, Embedding, Linear}; +use candle::{Result, Tensor}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; #[derive(Debug, Clone)] @@ -170,7 +169,6 @@ impl Attention { kv_states: Option<&Tensor>, attn_mask: Option<&Tensor>, ) -> Result { - let is_cross_attn = kv_states.is_some(); let (b_sz, tgt_len, _) = xs.dims3()?; let query_states = (xs.apply(&self.q_proj)? * self.scaling)?; let (key_states, value_states) = match kv_states { @@ -259,6 +257,10 @@ impl EncoderLayer { .apply(&self.fc2)?; (xs + residual)?.apply(&self.final_layer_norm) } + + fn reset_kv_cache(&mut self) { + self.self_attn.reset_kv_cache() + } } #[derive(Debug, Clone)] @@ -320,6 +322,11 @@ impl DecoderLayer { let xs = (xs + residual)?.apply(&self.final_layer_norm)?; Ok(xs) } + + fn reset_kv_cache(&mut self) { + self.self_attn.reset_kv_cache(); + self.encoder_attn.reset_kv_cache() + } } #[derive(Debug, Clone)] @@ -368,6 +375,12 @@ impl Encoder { } Ok(xs) } + + pub fn reset_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.reset_kv_cache() + } + } } #[derive(Debug, Clone)] @@ -422,6 +435,12 @@ impl Decoder { } Ok(xs) } + + pub fn reset_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.reset_kv_cache() + } + } } #[derive(Debug, Clone)] @@ -442,6 +461,11 @@ impl Model { decoder, }) } + + fn reset_kv_cache(&mut self) { + self.encoder.reset_kv_cache(); + self.decoder.reset_kv_cache(); + } } #[derive(Debug, Clone)] @@ -489,4 +513,8 @@ impl MTModel { .apply(&self.lm_head)? .broadcast_add(&self.final_logits_bias) } + + pub fn reset_kv_cache(&mut self) { + self.model.reset_kv_cache(); + } }