From 4c967b9184834cd1e166dfdd6d88450d16bad8f2 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 30 Oct 2023 18:29:36 +0100 Subject: [PATCH] Use the hub files for the marian example. (#1220) * Use the hub files for the marian example. * Use the secondary decoder. * Add a readme. * More readme. --- README.md | 4 ++ candle-examples/examples/marian-mt/README.md | 19 +++++++ candle-examples/examples/marian-mt/main.rs | 58 ++++++++++++++------ candle-transformers/src/models/marian.rs | 39 +++++++++---- 4 files changed, 93 insertions(+), 27 deletions(-) create mode 100644 candle-examples/examples/marian-mt/README.md diff --git a/README.md b/README.md index 8c076ec7..81cbfe62 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,8 @@ We also provide a some command line based examples using state of the art models evaluation, segmentation). - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to generate captions for an image. +- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation + model, generates the translated text from the input text. Run them using commands like: ``` @@ -174,6 +176,8 @@ If you have an addition to this list, please submit a pull request. - Wurstchen v2. - Image to text. - BLIP. + - Text to text. + - Marian MT (Machine Translation). - Computer Vision Models. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT. - yolo-v3, yolo-v8. diff --git a/candle-examples/examples/marian-mt/README.md b/candle-examples/examples/marian-mt/README.md new file mode 100644 index 00000000..e41bf007 --- /dev/null +++ b/candle-examples/examples/marian-mt/README.md @@ -0,0 +1,19 @@ +# candle-marian-mt + +`marian-mt` is a neural machine translation model. In this example it is used to +translate text from French to English. See the associated [model +card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on +the model itself. + +## Running an example + +```bash +cargo run --example marian-mt --release -- \ + --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps." +``` + +``` + Tomorrow, at dawn, at the time when the country is whitening, I will go. See, +I know you are waiting for me. I will go through the forest, I will go through the +mountain. I cannot stay far from you any longer. +``` diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs index bf33743c..c198777c 100644 --- a/candle-examples/examples/marian-mt/main.rs +++ b/candle-examples/examples/marian-mt/main.rs @@ -8,7 +8,6 @@ use anyhow::Error as E; use clap::Parser; use candle::{DType, Tensor}; -use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; use candle_transformers::models::marian; @@ -18,10 +17,13 @@ use tokenizers::Tokenizer; #[derive(Parser)] struct Args { #[arg(long)] - model: String, + model: Option, #[arg(long)] - tokenizer: String, + tokenizer: Option, + + #[arg(long)] + tokenizer_dec: Option, /// Run on CPU rather than on GPU. #[arg(long)] @@ -37,25 +39,52 @@ struct Args { } pub fn main() -> anyhow::Result<()> { + use hf_hub::api::sync::Api; let args = Args::parse(); let config = marian::Config::opus_mt_tc_big_fr_en(); + let tokenizer = { + let tokenizer = match args.tokenizer { + Some(tokenizer) => std::path::PathBuf::from(tokenizer), + None => Api::new()? + .model("lmz/candle-marian".to_string()) + .get("tokenizer-marian-fr.json")?, + }; + Tokenizer::from_file(&tokenizer).map_err(E::msg)? + }; + + let tokenizer_dec = { + let tokenizer = match args.tokenizer_dec { + Some(tokenizer) => std::path::PathBuf::from(tokenizer), + None => Api::new()? + .model("lmz/candle-marian".to_string()) + .get("tokenizer-marian-en.json")?, + }; + Tokenizer::from_file(&tokenizer).map_err(E::msg)? + }; let device = candle_examples::device(args.cpu)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? }; + let vb = { + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => Api::new()? + .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string()) + .get("model.safetensors")?, + }; + unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? } + }; let model = marian::MTModel::new(&config, vb)?; - let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?; - let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone()); let mut logits_processor = candle_transformers::generation::LogitsProcessor::new(1337, None, None); let encoder_xs = { - let tokens = tokenizer + let mut tokens = tokenizer .encode(args.text, true) .map_err(E::msg)? .get_ids() .to_vec(); + tokens.push(config.eos_token_id); let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; model.encoder().forward(&tokens, 0)? }; @@ -70,20 +99,15 @@ pub fn main() -> anyhow::Result<()> { let logits = logits.squeeze(0)?; let logits = logits.get(logits.dim(0)? - 1)?; let token = logits_processor.sample(&logits)?; + token_ids.push(token); println!("{token}"); if token == config.eos_token_id || token == config.forced_eos_token_id { break; } - token_ids.push(token); - if let Some(t) = tokenizer_dec.next_token(token)? { - use std::io::Write; - print!("{t}"); - std::io::stdout().flush()?; - } } - if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? { - print!("{rest}"); - } - + println!( + "{}", + tokenizer_dec.decode(&token_ids, true).map_err(E::msg)? + ); Ok(()) } diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs index 71f17720..2bcfd2f7 100644 --- a/candle-transformers/src/models/marian.rs +++ b/candle-transformers/src/models/marian.rs @@ -135,7 +135,12 @@ impl Attention { .contiguous() } - fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result { + fn forward( + &self, + xs: &Tensor, + kv_states: Option<&Tensor>, + attn_mask: Option<&Tensor>, + ) -> Result { let is_cross_attn = kv_states.is_some(); let (b_sz, tgt_len, _) = xs.dims3()?; let query_states = (xs.apply(&self.q_proj)? * self.scaling)?; @@ -156,7 +161,10 @@ impl Attention { let key_states = key_states.reshape(proj_shape)?; let value_states = value_states.reshape(proj_shape)?; let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; - // todo: attn_mask + let attn_weights = match attn_mask { + None => attn_weights, + Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?, + }; let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?; let attn_output = attn_probs.matmul(&value_states)?; attn_output @@ -196,8 +204,8 @@ impl EncoderLayer { fn forward(&self, xs: &Tensor) -> Result { let residual = xs; - let xs = - (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?; + let xs = (self.self_attn.forward(xs, None, None)? + residual)? + .apply(&self.self_attn_layer_norm)?; let residual = &xs; let xs = xs .apply(&self.fc1)? @@ -241,15 +249,20 @@ impl DecoderLayer { }) } - fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result { + fn forward( + &self, + xs: &Tensor, + encoder_xs: Option<&Tensor>, + attn_mask: &Tensor, + ) -> Result { let residual = xs; - let xs = - (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?; + let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)? + .apply(&self.self_attn_layer_norm)?; let xs = match encoder_xs { None => xs, Some(encoder_xs) => { let residual = &xs; - let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?; + let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?; (residual + xs)?.apply(&self.encoder_attn_layer_norm)? } }; @@ -346,6 +359,7 @@ impl Decoder { xs: &Tensor, encoder_xs: Option<&Tensor>, past_kv_len: usize, + attn_mask: &Tensor, ) -> Result { let xs = xs.apply(&self.embed_tokens)?; let xs = match self.embed_scale { @@ -358,7 +372,7 @@ impl Decoder { .unsqueeze(0)?; let mut xs = xs.broadcast_add(&embed_pos)?; for layer in self.layers.iter() { - xs = layer.forward(&xs, encoder_xs)?; + xs = layer.forward(&xs, encoder_xs, attn_mask)?; } Ok(xs) } @@ -413,9 +427,14 @@ impl MTModel { } pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result { + let seq_len = xs.dim(1)?; + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?; self.model .decoder - .forward(xs, Some(encoder_xs), 0)? + .forward(xs, Some(encoder_xs), 0, &mask)? .apply(&self.lm_head)? .broadcast_add(&self.final_logits_bias) }