Remove the unused pragma for marian. (#1236)

This commit is contained in:
Laurent Mazare
2023-11-01 21:04:52 +01:00
committed by GitHub
parent 1704f1b3ae
commit 6c990a33ea

View File

@ -1,6 +1,5 @@
#![allow(unused)]
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{Module, Result, Tensor};
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
@ -170,7 +169,6 @@ impl Attention {
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let is_cross_attn = kv_states.is_some();
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states {
@ -259,6 +257,10 @@ impl EncoderLayer {
.apply(&self.fc2)?;
(xs + residual)?.apply(&self.final_layer_norm)
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache()
}
}
#[derive(Debug, Clone)]
@ -320,6 +322,11 @@ impl DecoderLayer {
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
Ok(xs)
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache();
self.encoder_attn.reset_kv_cache()
}
}
#[derive(Debug, Clone)]
@ -368,6 +375,12 @@ impl Encoder {
}
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
}
#[derive(Debug, Clone)]
@ -422,6 +435,12 @@ impl Decoder {
}
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
}
#[derive(Debug, Clone)]
@ -442,6 +461,11 @@ impl Model {
decoder,
})
}
fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache();
self.decoder.reset_kv_cache();
}
}
#[derive(Debug, Clone)]
@ -489,4 +513,8 @@ impl MTModel {
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}
pub fn reset_kv_cache(&mut self) {
self.model.reset_kv_cache();
}
}