#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use candle_transformers::models::t5; use anyhow::{anyhow, Error as E, Result}; use candle::{DType, Tensor}; use candle_nn::VarBuilder; use clap::Parser; use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; use tokenizers::Tokenizer; const DTYPE: DType = DType::F32; const DEFAULT_PROMPT: &str = "Translate English to German: That is good."; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, /// Run offline (you must have the files already cached) #[arg(long)] offline: bool, /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, /// The model repository to use on the HuggingFace hub. #[arg(long)] model_id: Option, #[arg(long)] revision: Option, /// Compute embeddings for this prompt or use the DEFAULT_PROMPT. #[arg(long)] prompt: Option, /// The number of times to run the prompt. #[arg(long, default_value = "1")] n: usize, } impl Args { fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> { let device = candle_examples::device(self.cpu)?; let default_model = "t5-small".to_string(); let default_revision = "refs/pr/15".to_string(); let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), (None, Some(revision)) => (default_model, revision), (None, None) => (default_model, default_revision), }; let repo = Repo::with_revision(model_id, RepoType::Model, revision); let (config_filename, tokenizer_filename, weights_filename) = if self.offline { let cache = Cache::default().repo(repo); ( cache .get("config.json") .ok_or(anyhow!("Missing config file in cache"))?, cache .get("tokenizer.json") .ok_or(anyhow!("Missing tokenizer file in cache"))?, cache .get("model.safetensors") .ok_or(anyhow!("Missing weights file in cache"))?, ) } else { let api = Api::new()?; let api = api.repo(repo); ( api.get("config.json")?, api.get("tokenizer.json")?, api.get("model.safetensors")?, ) }; let config = std::fs::read_to_string(config_filename)?; let config: t5::Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let model = t5::T5EncoderModel::load(vb, &config)?; Ok((model, tokenizer)) } } fn main() -> Result<()> { let args = Args::parse(); let start = std::time::Instant::now(); let (model, mut tokenizer) = args.build_model_and_tokenizer()?; let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string()); let tokenizer = tokenizer .with_padding(None) .with_truncation(None) .map_err(E::msg)?; let tokens = tokenizer .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?; println!("Loaded and encoded {:?}", start.elapsed()); for idx in 0..args.n { let start = std::time::Instant::now(); let ys = model.forward(&token_ids)?; if idx == 0 { println!("{ys}"); } println!("Took {:?}", start.elapsed()); } Ok(()) }