#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use anyhow::Error as E; use clap::Parser; use candle::{DType, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; use candle_transformers::models::marian; use tokenizers::Tokenizer; // TODO: Maybe add support for the conditional prompt. #[derive(Parser)] struct Args { #[arg(long)] model: String, #[arg(long)] tokenizer: String, /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, /// Use the quantized version of the model. #[arg(long)] quantized: bool, /// Text to be translated #[arg(long)] text: String, } const SEP_TOKEN_ID: u32 = 102; pub fn main() -> anyhow::Result<()> { let args = Args::parse(); let config = marian::Config::opus_mt_tc_big_fr_en(); let device = candle_examples::device(args.cpu)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? }; let model = marian::MTModel::new(&config, vb)?; let vocab = std::fs::read_to_string(args.tokenizer)?; let vocab = serde_json::from_str(&vocab)?; let tokenizer = tokenizers::models::wordpiece::WordPieceBuilder::new() .unk_token("".to_string()) .vocab(vocab) .build() .map_err(E::msg)?; let tokenizer = Tokenizer::new(tokenizer); let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone()); let mut logits_processor = candle_transformers::generation::LogitsProcessor::new(1337, None, None); let encoder_xs = { let tokens = tokenizer .encode(args.text, true) .map_err(E::msg)? .get_ids() .to_vec(); let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; model.encoder().forward(&tokens, 0)? }; let mut token_ids = vec![30522u32]; for index in 0..1000 { // TODO: Add a kv cache. let context_size = if index >= 1000 { 1 } else { token_ids.len() }; let start_pos = token_ids.len().saturating_sub(context_size); let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; let logits = model.decode(&input_ids, &encoder_xs)?; let logits = logits.squeeze(0)?; let logits = logits.get(logits.dim(0)? - 1)?; let token = logits_processor.sample(&logits)?; if token == SEP_TOKEN_ID { break; } token_ids.push(token); if let Some(t) = tokenizer_dec.next_token(token)? { use std::io::Write; print!("{t}"); std::io::stdout().flush()?; } } if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? { print!("{rest}"); } Ok(()) }