Use the hub files for the marian example. (#1220)

* Use the hub files for the marian example.

* Use the secondary decoder.

* Add a readme.

* More readme.
Author: Laurent Mazare
Date: 2023-10-30 18:29:36 +01:00
Committed by: GitHub
Parent: c05c0a8213
Commit: 4c967b9184
4 changed files with 93 additions and 27 deletions
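The gist of the change: `--model` and `--tokenizer` are no longer required flags, and the example falls back to fetching the files from the Hugging Face Hub through `hf-hub`'s sync API. A minimal sketch of that fallback, using the repo and file names that appear in the diff below:

```rust
use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    // Files are cached locally and only downloaded on first use.
    let api = Api::new()?;
    let model = api
        .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
        .get("model.safetensors")?;
    let tokenizer = api
        .model("lmz/candle-marian".to_string())
        .get("tokenizer-marian-fr.json")?;
    println!("model: {}", model.display());
    println!("tokenizer: {}", tokenizer.display());
    Ok(())
}
```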

candle-examples/examples/marian-mt/README.md (new file)

@@ -0,0 +1,19 @@
+# candle-marian-mt
+
+`marian-mt` is a neural machine translation model. In this example it is used to
+translate text from French to English. See the associated [model
+card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
+the model itself.
+
+## Running an example
+
+```bash
+cargo run --example marian-mt --release -- \
+  --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
+```
+
+```
+<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
+I know you are waiting for me. I will go through the forest, I will go through the
+mountain. I cannot stay far from you any longer.</s>
+```

candle-examples/examples/marian-mt/main.rs

@@ -8,7 +8,6 @@ use anyhow::Error as E;
 use clap::Parser;
 use candle::{DType, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::marian;
@@ -18,10 +17,13 @@ use tokenizers::Tokenizer;
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: String,
+    model: Option<String>,
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
+    #[arg(long)]
+    tokenizer_dec: Option<String>,
     /// Run on CPU rather than on GPU.
     #[arg(long)]
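Since all three file arguments are now `Option<String>`, local files can still override the hub downloads. A hypothetical invocation (clap derives the `--tokenizer-dec` flag name from the `tokenizer_dec` field; the paths are placeholders):

```bash
cargo run --example marian-mt --release -- \
  --model /path/to/model.safetensors \
  --tokenizer /path/to/tokenizer-marian-fr.json \
  --tokenizer-dec /path/to/tokenizer-marian-en.json \
  --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai."
```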
@@ -37,25 +39,52 @@ struct Args {
 }
 pub fn main() -> anyhow::Result<()> {
+    use hf_hub::api::sync::Api;
     let args = Args::parse();
     let config = marian::Config::opus_mt_tc_big_fr_en();
+    let tokenizer = {
+        let tokenizer = match args.tokenizer {
+            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
+            None => Api::new()?
+                .model("lmz/candle-marian".to_string())
+                .get("tokenizer-marian-fr.json")?,
+        };
+        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
+    };
+    let tokenizer_dec = {
+        let tokenizer = match args.tokenizer_dec {
+            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
+            None => Api::new()?
+                .model("lmz/candle-marian".to_string())
+                .get("tokenizer-marian-en.json")?,
+        };
+        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
+    };
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
+    let vb = {
+        let model = match args.model {
+            Some(model) => std::path::PathBuf::from(model),
+            None => Api::new()?
+                .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
+                .get("model.safetensors")?,
+        };
+        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
+    };
     let model = marian::MTModel::new(&config, vb)?;
-    let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
-    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
     let mut logits_processor =
         candle_transformers::generation::LogitsProcessor::new(1337, None, None);
     let encoder_xs = {
-        let tokens = tokenizer
+        let mut tokens = tokenizer
             .encode(args.text, true)
             .map_err(E::msg)?
             .get_ids()
             .to_vec();
+        tokens.push(config.eos_token_id);
         let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
         model.encoder().forward(&tokens, 0)?
     };
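Two details in this hunk are worth calling out: the encoder consumes text tokenized with the French tokenizer while generation is later decoded with the separate English one (the secondary decoder from the commit message), and the EOS id is pushed by hand, presumably because the converted tokenizer does not append it during `encode`. A minimal sketch of that split, assuming local copies of the two tokenizer files:

```rust
use anyhow::Error as E;
use tokenizers::Tokenizer;

// Encode French source text, appending EOS manually, mirroring the
// `tokens.push(config.eos_token_id)` line in the hunk above.
fn encode_fr(text: &str, eos_token_id: u32) -> anyhow::Result<Vec<u32>> {
    let tok_fr = Tokenizer::from_file("tokenizer-marian-fr.json").map_err(E::msg)?;
    let mut tokens = tok_fr
        .encode(text, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    tokens.push(eos_token_id);
    Ok(tokens)
}

// Decode generated ids with the English tokenizer, not the French one.
fn decode_en(token_ids: &[u32]) -> anyhow::Result<String> {
    let tok_en = Tokenizer::from_file("tokenizer-marian-en.json").map_err(E::msg)?;
    tok_en.decode(token_ids, true).map_err(E::msg)
}
```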
@@ -70,20 +99,15 @@ pub fn main() -> anyhow::Result<()> {
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
-        token_ids.push(token);
-        println!("{token}");
         if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
+        token_ids.push(token);
-        if let Some(t) = tokenizer_dec.next_token(token)? {
-            use std::io::Write;
-            print!("{t}");
-            std::io::stdout().flush()?;
-        }
     }
-    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
-        print!("{rest}");
-    }
+    println!(
+        "{}",
+        tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
+    );
     Ok(())
 }
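Note that the sampling setup is untouched: `LogitsProcessor::new(1337, None, None)` with no temperature and no top-p reduces to greedy argmax decoding, so the translation is deterministic across runs. A small sketch of that behaviour, assuming a 1-D logits tensor:

```rust
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> anyhow::Result<()> {
    // Temperature None and top-p None: sample() picks the argmax,
    // so the seed has no effect here.
    let mut processor = LogitsProcessor::new(1337, None, None);
    let logits = Tensor::new(&[0.1f32, 2.5, 0.3], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    assert_eq!(token, 1); // index of the largest logit
    Ok(())
}
```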