#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use anyhow::Error as E; use clap::{Parser, ValueEnum}; use candle::{DType, Tensor}; use candle_nn::VarBuilder; use candle_transformers::models::marian; use tokenizers::Tokenizer; #[derive(Clone, Debug, Copy, ValueEnum)] enum Which { Base, Big, } // TODO: Maybe add support for the conditional prompt. #[derive(Parser)] struct Args { #[arg(long)] model: Option, #[arg(long)] tokenizer: Option, #[arg(long)] tokenizer_dec: Option, /// Choose the variant of the model to run. #[arg(long, default_value = "big")] which: Which, /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, /// Use the quantized version of the model. #[arg(long)] quantized: bool, /// Text to be translated #[arg(long)] text: String, } pub fn main() -> anyhow::Result<()> { use hf_hub::api::sync::Api; let args = Args::parse(); let config = match args.which { Which::Base => marian::Config::opus_mt_fr_en(), Which::Big => marian::Config::opus_mt_tc_big_fr_en(), }; let tokenizer = { let tokenizer = match args.tokenizer { Some(tokenizer) => std::path::PathBuf::from(tokenizer), None => { let name = match args.which { Which::Base => "tokenizer-marian-base-fr.json", Which::Big => "tokenizer-marian-fr.json", }; Api::new()? .model("lmz/candle-marian".to_string()) .get(name)? } }; Tokenizer::from_file(&tokenizer).map_err(E::msg)? }; let tokenizer_dec = { let tokenizer = match args.tokenizer_dec { Some(tokenizer) => std::path::PathBuf::from(tokenizer), None => { let name = match args.which { Which::Base => "tokenizer-marian-base-en.json", Which::Big => "tokenizer-marian-en.json", }; Api::new()? .model("lmz/candle-marian".to_string()) .get(name)? } }; Tokenizer::from_file(&tokenizer).map_err(E::msg)? }; let device = candle_examples::device(args.cpu)?; let vb = { let model = match args.model { Some(model) => std::path::PathBuf::from(model), None => match args.which { Which::Base => Api::new()? .repo(hf_hub::Repo::with_revision( "Helsinki-NLP/opus-mt-fr-en".to_string(), hf_hub::RepoType::Model, "refs/pr/4".to_string(), )) .get("model.safetensors")?, Which::Big => Api::new()? .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string()) .get("model.safetensors")?, }, }; unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? } }; let model = marian::MTModel::new(&config, vb)?; let mut logits_processor = candle_transformers::generation::LogitsProcessor::new(1337, None, None); let encoder_xs = { let mut tokens = tokenizer .encode(args.text, true) .map_err(E::msg)? .get_ids() .to_vec(); tokens.push(config.eos_token_id); let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; model.encoder().forward(&tokens, 0)? }; let mut token_ids = vec![config.decoder_start_token_id]; for index in 0..1000 { // TODO: Add a kv cache. let context_size = if index >= 1000 { 1 } else { token_ids.len() }; let start_pos = token_ids.len().saturating_sub(context_size); let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; let logits = model.decode(&input_ids, &encoder_xs)?; let logits = logits.squeeze(0)?; let logits = logits.get(logits.dim(0)? - 1)?; let token = logits_processor.sample(&logits)?; token_ids.push(token); println!("{token}"); if token == config.eos_token_id || token == config.forced_eos_token_id { break; } } println!( "{}", tokenizer_dec.decode(&token_ids, true).map_err(E::msg)? ); Ok(()) }