diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index a3e64a17..a39ee3a3 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -12,6 +12,8 @@ readme = "README.md"
 
 [dependencies]
 candle = { path = "../candle-core", default-features=false }
+serde = { version = "1.0.166", features = ["derive"] }
+serde_json = "1.0.99"
 num-traits = "0.2.15"
 
 [dev-dependencies]
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index e5801314..4de0aeac 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,10 +1,9 @@
 #![allow(dead_code)]
-// The tokenizer.json and weights should be retrieved from:
-// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
-
-use anyhow::{Error as E, Result};
+use anyhow::{anyhow, Error as E, Result};
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
+use serde::Deserialize;
 use std::collections::HashMap;
 
 const DTYPE: DType = DType::F32;
@@ -66,7 +65,8 @@ impl<'a> VarBuilder<'a> {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
 enum HiddenAct {
     Gelu,
     Relu,
@@ -84,13 +84,14 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
     Absolute,
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 struct Config {
     vocab_size: usize,
     hidden_size: usize,
@@ -235,8 +236,22 @@ impl LayerNorm {
     }
 
     fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let weight = vb.get(size, &format!("{p}.weight"))?;
-        let bias = vb.get(size, &format!("{p}.bias"))?;
+        let (weight, bias) = match (
+            vb.get(size, &format!("{p}.weight")),
+            vb.get(size, &format!("{p}.bias")),
+        ) {
+            (Ok(weight), Ok(bias)) => (weight, bias),
+            (Err(err), _) | (_, Err(err)) => {
+                if let (Ok(weight), Ok(bias)) = (
+                    vb.get(size, &format!("{p}.gamma")),
+                    vb.get(size, &format!("{p}.beta")),
+                ) {
+                    (weight, bias)
+                } else {
+                    return Err(err.into());
+                }
+            }
+        };
         Ok(Self { weight, bias, eps })
     }
 
@@ -567,8 +582,21 @@ struct BertModel {
 
 impl BertModel {
     fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
-        let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
-        let encoder = BertEncoder::load("encoder", vb, config)?;
+        let (embeddings, encoder) = match (
+            BertEmbeddings::load("embeddings", vb, config),
+            BertEncoder::load("encoder", vb, config),
+        ) {
+            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+            (Err(err), _) | (_, Err(err)) => {
+                match (
+                    BertEmbeddings::load("bert.embeddings", vb, config),
+                    BertEncoder::load("bert.encoder", vb, config),
+                ) {
+                    (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+                    _ => return Err(err),
+                }
+            }
+        };
         Ok(Self {
             embeddings,
             encoder,
@@ -589,15 +617,30 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Run offline (you must have the files already cached)
     #[arg(long)]
-    tokenizer_config: String,
+    offline: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
 
     #[arg(long)]
-    weights: String,
+    revision: Option<String>,
+
+    /// The prompt to compute embeddings for.
+    #[arg(long, default_value = "This is an example sentence")]
+    prompt: String,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
 }
 
-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
     use tokenizers::Tokenizer;
+    let start = std::time::Instant::now();
 
     let args = Args::parse();
     let device = if args.cpu {
@@ -606,24 +649,60 @@ fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
 
-    let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+    let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+    let default_revision = "refs/pr/21".to_string();
+    let (model_id, revision) = match (args.model_id, args.revision) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+    let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
+        let cache = Cache::default();
+        (
+            cache
+                .get(&repo, "config.json")
+                .ok_or(anyhow!("Missing config file in cache"))?,
+            cache
+                .get(&repo, "tokenizer.json")
+                .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+            cache
+                .get(&repo, "model.safetensors")
+                .ok_or(anyhow!("Missing weights file in cache"))?,
+        )
+    } else {
+        let api = Api::new()?;
+        (
+            api.get(&repo, "config.json").await?,
+            api.get(&repo, "tokenizer.json").await?,
+            api.get(&repo, "model.safetensors").await?,
+        )
+    };
+    let config = std::fs::read_to_string(config_filename)?;
+    let config: Config = serde_json::from_str(&config)?;
+    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let tokenizer = tokenizer.with_padding(None).with_truncation(None);
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
-    let config = Config::all_mini_lm_l6_v2();
     let model = BertModel::load(&vb, &config)?;
 
     let tokens = tokenizer
-        .encode("This is an example sentence", true)
+        .encode(args.prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
     let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
-    println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
-    let ys = model.forward(&token_ids, &token_type_ids)?;
-    println!("{ys}");
+    println!("Loaded and encoded {:?}", start.elapsed());
+    for _ in 0..args.n {
+        let start = std::time::Instant::now();
+        let _ys = model.forward(&token_ids, &token_type_ids)?;
+        println!("Took {:?}", start.elapsed());
+        // println!("Ys {:?}", ys.shape());
+    }
     Ok(())
 }
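
For reference, a minimal sketch of the online download path this patch wires up, using only the candle_hub calls that appear in the diff (Api::new, Repo::with_revision, api.get). It assumes a tokio runtime and that api.get returns a local file path, as the example's subsequent read_to_string implies; this is an illustration of the API usage, not part of the patch.

use anyhow::Result;
use candle_hub::{api::Api, Repo, RepoType};

#[tokio::main]
async fn main() -> Result<()> {
    // The default model and revision used by the bert example.
    let repo = Repo::with_revision(
        "sentence-transformers/all-MiniLM-L6-v2".to_string(),
        RepoType::Model,
        "refs/pr/21".to_string(),
    );
    // Download the three files the example needs and print their local paths.
    let api = Api::new()?;
    let config = api.get(&repo, "config.json").await?;
    let tokenizer = api.get(&repo, "tokenizer.json").await?;
    let weights = api.get(&repo, "model.safetensors").await?;
    println!("config: {config:?}\ntokenizer: {tokenizer:?}\nweights: {weights:?}");
    Ok(())
}

With --offline, the example instead resolves the same three files through Cache::default(), which only succeeds if a previous online run already fetched them.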