From d8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 5 Jul 2023 07:41:14 +0000 Subject: [PATCH] Some polish. --- candle-examples/examples/bert/main.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 9c9dc206..4de0aeac 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -621,18 +621,26 @@ struct Args { #[arg(long)] offline: bool, + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending #[arg(long)] model_id: Option, #[arg(long)] revision: Option, + + /// The number of times to run the prompt. + #[arg(long, default_value = "This is an example sentence")] + prompt: String, + + /// The number of times to run the prompt. + #[arg(long, default_value = "1")] + n: usize, } #[tokio::main] async fn main() -> Result<()> { use tokenizers::Tokenizer; let start = std::time::Instant::now(); - println!("Building {:?}", start.elapsed()); let args = Args::parse(); let device = if args.cpu { @@ -672,29 +680,25 @@ async fn main() -> Result<()> { api.get(&repo, "model.safetensors").await?, ) }; - println!("Building {:?}", start.elapsed()); let config = std::fs::read_to_string(config_filename)?; let config: Config = serde_json::from_str(&config)?; - println!("Config loaded {:?}", start.elapsed()); let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = tokenizer.with_padding(None).with_truncation(None); - println!("Tokenizer loaded {:?}", start.elapsed()); let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone()); let model = BertModel::load(&vb, &config)?; - println!("Loaded {:?}", start.elapsed()); let tokens = tokenizer - .encode("This is an example sentence", true) + .encode(args.prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; let token_type_ids = token_ids.zeros_like()?; println!("Loaded and encoded {:?}", start.elapsed()); - for _ in 0..100 { + for _ in 0..args.n { let start = std::time::Instant::now(); let _ys = model.forward(&token_ids, &token_type_ids)?; println!("Took {:?}", start.elapsed());