Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Marian MT model (#1210)
* Skeleton files for the marian MT model.
* Marian initialization.
* Implement the attention forward method (sketched below).
* Forward pass for the encoder side.
* Expose the encoder and decoder.
* Start plugging the decoder.
* Forward pass for the decoder layer.
* Set up the marian example.
* Add some missing backtraces.
* Bugfix.
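The attention forward method mentioned above reduces to standard scaled dot-product attention over projected query, key, and value tensors. As a rough illustration, here is a minimal, self-contained sketch written against candle's tensor API; the function name, the shapes, and the driver in main are illustrative assumptions rather than the crate's actual internals, and masking (needed for decoder self-attention) is omitted for brevity.

use candle::{Device, Result, Tensor, D};

/// Illustrative scaled dot-product attention: softmax(q k^T / sqrt(d)) v.
/// q, k, v are assumed to have shape (batch, heads, seq, head_dim).
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let head_dim = q.dim(D::Minus1)? as f64;
    // Attention scores, shape (batch, heads, seq, seq).
    let scores = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / head_dim.sqrt())?;
    let weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
    // Weighted sum of the values, back to (batch, heads, seq, head_dim).
    weights.matmul(v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let q = Tensor::randn(0f32, 1., (1, 8, 4, 64), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, 8, 4, 64), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, 8, 4, 64), &dev)?;
    let out = scaled_dot_product_attention(&q, &k, &v)?;
    println!("{:?}", out.dims()); // [1, 8, 4, 64]
    Ok(())
}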
candle-examples/examples/marian-mt/main.rs (new file, 90 lines)
@@ -0,0 +1,90 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,

    #[arg(long)]
    tokenizer: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated.
    #[arg(long)]
    text: String,
}

// Token id that marks the end of the generated sequence.
const SEP_TOKEN_ID: u32 = 102;

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let config = marian::Config::opus_mt_tc_big_fr_en();

    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
    let model = marian::MTModel::new(&config, vb)?;

    let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    // Run the encoder once over the full source sentence.
    let encoder_xs = {
        let tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    // Seed the decoder with its start-of-sequence token id.
    let mut token_ids = vec![30522u32];
    for index in 0..1000 {
        // TODO: Add a kv cache. Until then, the full token history is fed
        // back at every step: `index >= 1000` never holds inside this loop,
        // so `context_size` is always `token_ids.len()`.
        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs)?;
        let logits = logits.squeeze(0)?;
        // Keep only the logits of the last position before sampling.
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    // Flush whatever the streaming detokenizer has buffered.
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

    Ok(())
}
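To try the example from the repository root, an invocation along these lines should work, with model.safetensors and tokenizer.json standing in for a locally downloaded opus-mt-tc-big-fr-en checkpoint and its tokenizer (both file names are placeholders):

cargo run --example marian-mt --release -- --model model.safetensors --tokenizer tokenizer.json --text "Tout ce qui est rare est cher."

Note that --quantized is parsed but not yet used when building the model: as written, the example always loads f32 safetensors.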