Remove the unused pragma for marian. (#1236)

Author: Laurent Mazare
Date: 2023-11-01 21:04:52 +01:00
Committed by: GitHub
Parent: 1704f1b3ae
Commit: 6c990a33ea


@@ -1,6 +1,5 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
-use candle::{Module, Result, Tensor};
+use super::with_tracing::{linear, Embedding, Linear};
+use candle::{Result, Tensor};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 #[derive(Debug, Clone)]
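`#![allow(unused)]` is an inner, file-wide attribute: while it is present the compiler stays silent about unused imports, bindings, and functions in the module, which is how the stray `linear_no_bias` and `Module` imports (and the `is_cross_attn` binding removed below) could accumulate unnoticed. A minimal standalone sketch of the effect, unrelated to the marian sources themselves:

// Standalone sketch (not the marian module): with the inner attribute in
// place, the compiler reports nothing about the dead code below.
#![allow(unused)]

use std::collections::HashMap; // would normally warn: unused import

fn helper() -> u32 {
    42 // would normally warn: function is never used
}

fn main() {
    let is_cross_attn = true; // would normally warn: unused variable
}

Dropping the attribute restores the usual unused_imports / unused_variables / dead_code warnings, so the dead imports have to go in the same commit.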
@@ -170,7 +169,6 @@ impl Attention {
         kv_states: Option<&Tensor>,
         attn_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
-        let is_cross_attn = kv_states.is_some();
         let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
         let (key_states, value_states) = match kv_states {
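The deleted `let is_cross_attn = kv_states.is_some();` computed a flag that nothing read: the `match kv_states` right below already branches on the same information. A toy, self-contained sketch of that pattern, using a stand-in `Tensor` rather than candle's:

// Toy sketch, not the candle implementation: the Option<&Tensor> argument
// itself encodes whether this is self- or cross-attention, so the removed
// `is_cross_attn` boolean carried no extra information.
#[derive(Clone, Debug)]
struct Tensor(Vec<f32>);

fn key_value_source(xs: &Tensor, kv_states: Option<&Tensor>) -> (Tensor, Tensor) {
    match kv_states {
        // No encoder states supplied: self-attention, keys/values come from `xs`.
        None => (xs.clone(), xs.clone()),
        // Encoder states supplied: cross-attention, keys/values come from them.
        Some(encoder_out) => (encoder_out.clone(), encoder_out.clone()),
    }
}

fn main() {
    let xs = Tensor(vec![0.0; 4]);
    let enc = Tensor(vec![1.0; 4]);
    println!("{:?}", key_value_source(&xs, None));
    println!("{:?}", key_value_source(&xs, Some(&enc)));
}

Keeping the distinction inside the `match` avoids a boolean that can drift out of sync with the actual argument.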
@@ -259,6 +257,10 @@ impl EncoderLayer {
             .apply(&self.fc2)?;
         (xs + residual)?.apply(&self.final_layer_norm)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_attn.reset_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -320,6 +322,11 @@ impl DecoderLayer {
         let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
         Ok(xs)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_attn.reset_kv_cache();
+        self.encoder_attn.reset_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -368,6 +375,12 @@ impl Encoder {
         }
         Ok(xs)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.reset_kv_cache()
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -422,6 +435,12 @@ impl Decoder {
         }
         Ok(xs)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.reset_kv_cache()
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -442,6 +461,11 @@ impl Model {
             decoder,
         })
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.encoder.reset_kv_cache();
+        self.decoder.reset_kv_cache();
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -489,4 +513,8 @@ impl MTModel {
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.model.reset_kv_cache();
+    }
 }
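Beyond the pragma cleanup, the diff threads a `reset_kv_cache` method from `MTModel` down through `Model`, `Encoder`, `Decoder`, and the per-layer attention modules. The point is reuse: the decoder's self-attention caches grow with every generated token, so a model instance kept around for a second, unrelated input has to drop them first. A runnable toy illustrating the failure mode and the fix, with a stand-in decoder rather than the candle structs:

// Toy sketch, not candle code: `TinyDecoder` "caches" every token it has seen,
// the way the real self-attention caches keys/values across decoding steps.
struct TinyDecoder {
    kv_cache: Vec<u32>,
}

impl TinyDecoder {
    fn step(&mut self, token: u32) -> usize {
        // Each step appends to the cache and attends over everything cached so far.
        self.kv_cache.push(token);
        self.kv_cache.len()
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache.clear()
    }
}

fn main() {
    let mut decoder = TinyDecoder { kv_cache: Vec::new() };
    for t in [1, 2, 3] {
        decoder.step(t);
    }
    assert_eq!(decoder.kv_cache.len(), 3);

    // Without a reset, a second, unrelated sequence would attend over the
    // first one's keys/values; clearing the cache restores a fresh state.
    decoder.reset_kv_cache();
    assert_eq!(decoder.step(7), 1);
}

In the real model the call simply cascades: `MTModel::reset_kv_cache` forwards to `Model`, which resets the encoder and decoder, which in turn loop over their layers and reset each attention module.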