mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example.
* Use the secondary decoder.
* Add a readme.
* More readme.
This commit is contained in:
@ -135,7 +135,12 @@ impl Attention {
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
kv_states: Option<&Tensor>,
|
||||
attn_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let is_cross_attn = kv_states.is_some();
|
||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
||||
@ -156,7 +161,10 @@ impl Attention {
|
||||
let key_states = key_states.reshape(proj_shape)?;
|
||||
let value_states = value_states.reshape(proj_shape)?;
|
||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||
// todo: attn_mask
|
||||
let attn_weights = match attn_mask {
|
||||
None => attn_weights,
|
||||
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
|
||||
};
|
||||
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_probs.matmul(&value_states)?;
|
||||
attn_output
|
||||
@ -196,8 +204,8 @@ impl EncoderLayer {
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs =
|
||||
(self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
|
||||
let xs = (self.self_attn.forward(xs, None, None)? + residual)?
|
||||
.apply(&self.self_attn_layer_norm)?;
|
||||
let residual = &xs;
|
||||
let xs = xs
|
||||
.apply(&self.fc1)?
|
||||
@ -241,15 +249,20 @@ impl DecoderLayer {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result<Tensor> {
|
||||
fn forward(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
encoder_xs: Option<&Tensor>,
|
||||
attn_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs =
|
||||
(self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
|
||||
let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
|
||||
.apply(&self.self_attn_layer_norm)?;
|
||||
let xs = match encoder_xs {
|
||||
None => xs,
|
||||
Some(encoder_xs) => {
|
||||
let residual = &xs;
|
||||
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
|
||||
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
|
||||
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
|
||||
}
|
||||
};
|
||||
@ -346,6 +359,7 @@ impl Decoder {
|
||||
xs: &Tensor,
|
||||
encoder_xs: Option<&Tensor>,
|
||||
past_kv_len: usize,
|
||||
attn_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.embed_tokens)?;
|
||||
let xs = match self.embed_scale {
|
||||
@ -358,7 +372,7 @@ impl Decoder {
|
||||
.unsqueeze(0)?;
|
||||
let mut xs = xs.broadcast_add(&embed_pos)?;
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, encoder_xs)?;
|
||||
xs = layer.forward(&xs, encoder_xs, attn_mask)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
@ -413,9 +427,14 @@ impl MTModel {
|
||||
}
|
||||
|
||||
pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
|
||||
let seq_len = xs.dim(1)?;
|
||||
let mask: Vec<_> = (0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
|
||||
self.model
|
||||
.decoder
|
||||
.forward(xs, Some(encoder_xs), 0)?
|
||||
.forward(xs, Some(encoder_xs), 0, &mask)?
|
||||
.apply(&self.lm_head)?
|
||||
.broadcast_add(&self.final_logits_bias)
|
||||
}
|
||||
|
Reference in New Issue
Block a user