diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
index ed044627..bf33743c 100644
--- a/candle-examples/examples/marian-mt/main.rs
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -36,8 +36,6 @@ struct Args {
     text: String,
 }
 
-const SEP_TOKEN_ID: u32 = 102;
-
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
@@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
         model.encoder().forward(&tokens, 0)?
     };
 
-    let mut token_ids = vec![30522u32];
+    let mut token_ids = vec![config.decoder_start_token_id];
     for index in 0..1000 {
         // TODO: Add a kv cache.
         let context_size = if index >= 1000 { 1 } else { token_ids.len() };
@@ -72,7 +70,8 @@ pub fn main() -> anyhow::Result<()> {
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
-        if token == SEP_TOKEN_ID {
+        println!("{token}");
+        if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
         token_ids.push(token);
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index d48ce38b..71f17720 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -18,11 +18,11 @@ pub struct Config {
     pub is_encoder_decoder: bool,
     pub activation_function: candle_nn::Activation,
     pub d_model: usize,
-    pub decoder_start_token_id: usize,
+    pub decoder_start_token_id: u32,
     pub scale_embedding: bool,
-    pub pad_token_id: usize,
-    pub eos_token_id: usize,
-    pub forced_eos_token_id: usize,
+    pub pad_token_id: u32,
+    pub eos_token_id: u32,
+    pub forced_eos_token_id: u32,
     pub share_encoder_decoder_embeddings: bool,
 }
 
@@ -224,7 +224,8 @@ impl DecoderLayer {
         let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
         let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
         let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
-        let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let encoder_attn_layer_norm =
+            layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
         let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
         let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
@@ -249,7 +250,7 @@ impl DecoderLayer {
             Some(encoder_xs) => {
                 let residual = &xs;
                 let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
-                (residual + xs)?.apply(&self.self_attn_layer_norm)?
+                (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
             }
         };
         let residual = &xs;
@@ -257,7 +258,8 @@ impl DecoderLayer {
             .apply(&self.fc1)?
             .apply(&self.activation_fn)?
             .apply(&self.fc2)?;
-        (xs + residual)?.apply(&self.final_layer_norm)
+        let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
+        Ok(xs)
     }
 }
 
@@ -356,7 +358,7 @@ impl Decoder {
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, encoder_xs)?
+            xs = layer.forward(&xs, encoder_xs)?;
         }
         Ok(xs)
     }
@@ -385,6 +387,7 @@ impl Model {
 
 #[derive(Debug, Clone)]
 pub struct MTModel {
     model: Model,
+    lm_head: Linear,
     final_logits_bias: Tensor,
 }
@@ -393,8 +396,10 @@ impl MTModel {
         let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
         let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
         let model = Model::new(cfg, vb.pp("model"))?;
+        let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
         Ok(Self {
             model,
+            lm_head,
             final_logits_bias,
         })
     }
@@ -408,6 +413,10 @@ impl MTModel {
     }
 
     pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
-        self.model.decoder.forward(xs, Some(encoder_xs), 0)
+        self.model
+            .decoder
+            .forward(xs, Some(encoder_xs), 0)?
+            .apply(&self.lm_head)?
+            .broadcast_add(&self.final_logits_bias)
     }
 }