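// Example invocation (the example name and file paths below are assumptions for
// illustration, not taken from this file; only the flags come from `Args`):
//
//   cargo run --example marian-mt --release -- \
//     --model model.safetensors \
//     --tokenizer fr-en-vocab.json \
//     --text "Demain, dès l'aube, à l'heure où blanchit la campagne."
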
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: String,

    #[arg(long)]
    tokenizer: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated.
    #[arg(long)]
    text: String,
}

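/// Generation stops once this token id is sampled; it is assumed to be the
/// separator / end-of-sequence id in the wordpiece vocab used by this example.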
const SEP_TOKEN_ID: u32 = 102;

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let config = marian::Config::opus_mt_tc_big_fr_en();

    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
    let model = marian::MTModel::new(&config, vb)?;

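    // Build a wordpiece tokenizer from a raw vocab JSON file (a token -> id map)
    // rather than loading a full tokenizer.json.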
    let vocab = std::fs::read_to_string(args.tokenizer)?;
    let vocab = serde_json::from_str(&vocab)?;
    let tokenizer = tokenizers::models::wordpiece::WordPieceBuilder::new()
        .unk_token("<unk>".to_string())
        .vocab(vocab)
        .build()
        .map_err(E::msg)?;
    let tokenizer = Tokenizer::new(tokenizer);
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

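    // Encode the source sentence once; the encoder output is reused as the
    // cross-attention context at every decoding step below.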
    let encoder_xs = {
        let tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

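    // Auto-regressive decoding: the sequence is seeded with 30522 (presumably the
    // decoder-start token id for this vocab) and grown one token at a time until
    // SEP_TOKEN_ID is produced or the step limit is reached.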
    let mut token_ids = vec![30522u32];
    for index in 0..1000 {
        // TODO: Add a kv cache.
        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs)?;
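        // Keep only the logits for the last position; with neither temperature nor
        // top-p configured, LogitsProcessor should amount to greedy (argmax) sampling.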
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
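    // Print any text the streaming decoder is still buffering.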
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

    Ok(())
}