mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Remove the unused pragma for marian. (#1236)
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
#![allow(unused)]
|
||||
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
|
||||
use candle::{Module, Result, Tensor};
|
||||
use super::with_tracing::{linear, Embedding, Linear};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -170,7 +169,6 @@ impl Attention {
|
||||
kv_states: Option<&Tensor>,
|
||||
attn_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let is_cross_attn = kv_states.is_some();
|
||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
||||
let (key_states, value_states) = match kv_states {
|
||||
@ -259,6 +257,10 @@ impl EncoderLayer {
|
||||
.apply(&self.fc2)?;
|
||||
(xs + residual)?.apply(&self.final_layer_norm)
|
||||
}
|
||||
|
||||
/// Resets this layer's KV cache by forwarding the call to the
/// self-attention module's own `reset_kv_cache`.
fn reset_kv_cache(&mut self) {
|
||||
self.self_attn.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -320,6 +322,11 @@ impl DecoderLayer {
|
||||
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
/// Resets this layer's KV caches by forwarding the call to both attention
/// modules it owns: `self_attn` first, then `encoder_attn`.
fn reset_kv_cache(&mut self) {
|
||||
self.self_attn.reset_kv_cache();
|
||||
self.encoder_attn.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -368,6 +375,12 @@ impl Encoder {
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
/// Resets the KV cache of every layer in `self.layers`, in order, by
/// calling each layer's `reset_kv_cache`.
pub fn reset_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -422,6 +435,12 @@ impl Decoder {
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
/// Resets the KV cache of every layer in `self.layers`, in order, by
/// calling each layer's `reset_kv_cache`.
pub fn reset_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -442,6 +461,11 @@ impl Model {
|
||||
decoder,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resets the KV caches of both halves of the model: the encoder first,
/// then the decoder, each via its own `reset_kv_cache`.
fn reset_kv_cache(&mut self) {
|
||||
self.encoder.reset_kv_cache();
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -489,4 +513,8 @@ impl MTModel {
|
||||
.apply(&self.lm_head)?
|
||||
.broadcast_add(&self.final_logits_bias)
|
||||
}
|
||||
|
||||
/// Public entry point for resetting cached key/value state: forwards the
/// call to the wrapped `model`'s `reset_kv_cache`.
pub fn reset_kv_cache(&mut self) {
|
||||
self.model.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user