Marian MT model (#1210)

* Skeleton files for the marian MT model.

* Marian initialization.

* Implement the attention forward method (sketched after this list).

* Forward pass for the encoder side.

* Expose the encoder and decoder.

* Start plugging in the decoder.

* Forward pass for the decoder layer (sketched after this list).

* Set up the marian example.

* Add some missing backtraces.

* Bugfix.
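
The attention module follows the multi-head scaled-dot-product pattern used by BART-style encoder-decoders, and the same block serves both self-attention and cross-attention. Below is a minimal sketch of such a forward method in candle; the struct layout and names such as q_proj and num_heads are illustrative rather than taken from this commit, and attention masking is omitted for brevity:

use candle::{Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};

struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
    head_dim: usize,
}

impl Attention {
    fn new(hidden: usize, num_heads: usize, vb: VarBuilder) -> Result<Self> {
        let head_dim = hidden / num_heads;
        Ok(Self {
            q_proj: linear(hidden, hidden, vb.pp("q_proj"))?,
            k_proj: linear(hidden, hidden, vb.pp("k_proj"))?,
            v_proj: linear(hidden, hidden, vb.pp("v_proj"))?,
            out_proj: linear(hidden, hidden, vb.pp("out_proj"))?,
            num_heads,
            head_dim,
        })
    }

    // xs: (batch, seq_q, hidden). For self-attention, kv is xs itself; for
    // cross-attention it is the encoder output.
    fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
        let (b, t, _) = xs.dims3()?;
        let t_kv = kv.dim(1)?;
        // Split the hidden dimension into (num_heads, head_dim) and move the
        // head dimension next to the batch dimension.
        let split = |x: Tensor, t: usize| {
            x.reshape((b, t, self.num_heads, self.head_dim))?
                .transpose(1, 2)?
                .contiguous()
        };
        let q = split(self.q_proj.forward(xs)?, t)?;
        let k = split(self.k_proj.forward(kv)?, t_kv)?;
        let v = split(self.v_proj.forward(kv)?, t_kv)?;
        // Scaled dot-product attention.
        let scale = 1f64 / (self.head_dim as f64).sqrt();
        let attn = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        // Merge the heads back and apply the output projection.
        let ys = attn
            .matmul(&v)?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((b, t, self.num_heads * self.head_dim))?;
        self.out_proj.forward(&ys)
    }
}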
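
A decoder layer then chains self-attention over the tokens decoded so far, cross-attention into the encoder output, and a position-wise feed-forward block, each with a residual connection. Another hedged sketch, reusing the Attention struct above; the post-norm placement and relu activation are assumptions, and the causal mask is again left out:

use candle::{Result, Tensor};
use candle_nn::{LayerNorm, Linear, Module};

struct DecoderLayer {
    self_attn: Attention, // from the sketch above
    self_attn_layer_norm: LayerNorm,
    cross_attn: Attention,
    cross_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl DecoderLayer {
    fn forward(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
        // Self-attention over the generated prefix, plus residual.
        let ys = self.self_attn.forward(xs, xs)?;
        let xs = self.self_attn_layer_norm.forward(&(xs + ys)?)?;
        // Cross-attention: queries from the decoder, keys/values from the
        // encoder output, plus residual.
        let ys = self.cross_attn.forward(&xs, encoder_xs)?;
        let xs = self.cross_attn_layer_norm.forward(&(&xs + ys)?)?;
        // Position-wise feed-forward block, plus residual.
        let ys = self.fc2.forward(&self.fc1.forward(&xs)?.relu()?)?;
        self.final_layer_norm.forward(&(&xs + ys)?)
    }
}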

Laurent Mazare
2023-10-29 16:12:22 +01:00
committed by GitHub
parent c3f2676d49
commit 7bbde55c61
5 changed files with 521 additions and 10 deletions

@@ -0,0 +1,90 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;
use tokenizers::Tokenizer;

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,

    #[arg(long)]
    tokenizer: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated.
    #[arg(long)]
    text: String,
}

// Token id that marks the end of the generated sequence.
const SEP_TOKEN_ID: u32 = 102;

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let config = marian::Config::opus_mt_tc_big_fr_en();
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
    let model = marian::MTModel::new(&config, vb)?;

    let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
    // No temperature or top-p: the logits processor samples greedily.
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    // Encode the source text once; the decoder cross-attends to it at each step.
    let encoder_xs = {
        let tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    // Seed the decoder with its start-of-sequence token id.
    let mut token_ids = vec![30522u32];
    for _ in 0..1000 {
        // TODO: Add a kv cache. Without one, the whole generated sequence is
        // re-decoded at every step.
        let input_ids = Tensor::new(token_ids.as_slice(), &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs)?;
        let logits = logits.squeeze(0)?;
        // Only the logits for the last position are needed to pick the next token.
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    Ok(())
}
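
Assuming the example is wired up as marian-mt in candle-examples, a run would look something like the following; model.safetensors and tokenizer.json are placeholder names for the converted fr-en weights and tokenizer:

cargo run --example marian-mt --release -- \
    --model model.safetensors \
    --tokenizer tokenizer.json \
    --text "Je voudrais une tasse de café."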