From 43a007cba4c4f18cffb63ebeff7eefcb9e846922 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 4 Jul 2023 14:12:14 +0000 Subject: [PATCH 1/3] Upgrading bert example to work with `bert-base-uncased`. - Always take weights from the hub - Optional `model_id` + `revision` to use safetensors version potentially - Optional loading for `bert-base-uncased` (`weight` vs `gamma`). - Take the config from the hub. --- candle-examples/Cargo.toml | 2 + candle-examples/examples/bert/main.rs | 77 +++++++++++++++++++++------ 2 files changed, 63 insertions(+), 16 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index a71ca17b..53a1a150 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -12,6 +12,8 @@ readme = "README.md" [dependencies] candle = { path = "../candle-core", default-features=false } +serde = { version = "1.0.166", features = ["derive"] } +serde_json = "1.0.99" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index e5801314..968ecd0d 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -1,10 +1,9 @@ #![allow(dead_code)] -// The tokenizer.json and weights should be retrieved from: -// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 - use anyhow::{Error as E, Result}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; +use candle_hub::{api::Api, Repo, RepoType}; use clap::Parser; +use serde::Deserialize; use std::collections::HashMap; const DTYPE: DType = DType::F32; @@ -66,7 +65,8 @@ impl<'a> VarBuilder<'a> { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] enum HiddenAct { Gelu, Relu, @@ -84,13 +84,14 @@ impl HiddenAct { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, 
Deserialize)] +#[serde(rename_all = "lowercase")] enum PositionEmbeddingType { Absolute, } // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Deserialize)] struct Config { vocab_size: usize, hidden_size: usize, @@ -235,8 +236,22 @@ impl LayerNorm { } fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result { - let weight = vb.get(size, &format!("{p}.weight"))?; - let bias = vb.get(size, &format!("{p}.bias"))?; + let (weight, bias) = match ( + vb.get(size, &format!("{p}.weight")), + vb.get(size, &format!("{p}.bias")), + ) { + (Ok(weight), Ok(bias)) => (weight, bias), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(weight), Ok(bias)) = ( + vb.get(size, &format!("{p}.gamma")), + vb.get(size, &format!("{p}.beta")), + ) { + (weight, bias) + } else { + return Err(err.into()); + } + } + }; Ok(Self { weight, bias, eps }) } @@ -567,8 +582,21 @@ struct BertModel { impl BertModel { fn load(vb: &VarBuilder, config: &Config) -> Result { - let embeddings = BertEmbeddings::load("embeddings", vb, config)?; - let encoder = BertEncoder::load("encoder", vb, config)?; + let (embeddings, encoder) = match ( + BertEmbeddings::load("embeddings", vb, config), + BertEncoder::load("encoder", vb, config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + match ( + BertEmbeddings::load("bert.embeddings", vb, config), + BertEncoder::load("bert.encoder", vb, config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + _ => return Err(err), + } + } + }; Ok(Self { embeddings, encoder, @@ -590,13 +618,14 @@ struct Args { cpu: bool, #[arg(long)] - tokenizer_config: String, + model_id: Option, #[arg(long)] - weights: String, + revision: Option, } -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { use tokenizers::Tokenizer; let args = 
Args::parse(); @@ -606,13 +635,29 @@ fn main() -> Result<()> { Device::new_cuda(0)? }; - let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; + let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string(); + let default_revision = "refs/pr/21".to_string(); + let (model_id, revision) = match (args.model_id, args.revision) { + (Some(model_id), Some(revision)) => (model_id, revision), + (Some(model_id), None) => (model_id, "main".to_string()), + (None, Some(revision)) => (default_model, revision), + (None, None) => (default_model, default_revision), + }; + + let api = Api::new()?; + let repo = Repo::with_revision(model_id, RepoType::Model, revision); + println!("building the model"); + let config_filename = api.get(&repo, "config.json").await?; + let config = std::fs::read_to_string(config_filename)?; + let config: Config = serde_json::from_str(&config)?; + let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; + let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = tokenizer.with_padding(None).with_truncation(None); - let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; + let weights = api.get(&repo, "model.safetensors").await?; + let weights = unsafe { candle::safetensors::MmapedFile::new(weights)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone()); - let config = Config::all_mini_lm_l6_v2(); let model = BertModel::load(&vb, &config)?; let tokens = tokenizer From 963c75cb89c165fe5f1c8e24c481c58defcfb5a8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 5 Jul 2023 07:19:57 +0000 Subject: [PATCH 2/3] Adding offline mode. 
--- candle-examples/examples/bert/main.rs | 52 +++++++++++++++++++++------ 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 968ecd0d..9c9dc206 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] -use anyhow::{Error as E, Result}; +use anyhow::{anyhow, Error as E, Result}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; -use candle_hub::{api::Api, Repo, RepoType}; +use candle_hub::{api::Api, Cache, Repo, RepoType}; use clap::Parser; use serde::Deserialize; use std::collections::HashMap; @@ -617,6 +617,10 @@ struct Args { #[arg(long)] cpu: bool, + /// Run offline (you must have the files already cached) + #[arg(long)] + offline: bool, + #[arg(long)] model_id: Option, @@ -627,6 +631,8 @@ struct Args { #[tokio::main] async fn main() -> Result<()> { use tokenizers::Tokenizer; + let start = std::time::Instant::now(); + println!("Building {:?}", start.elapsed()); let args = Args::parse(); let device = if args.cpu { @@ -644,21 +650,41 @@ async fn main() -> Result<()> { (None, None) => (default_model, default_revision), }; - let api = Api::new()?; let repo = Repo::with_revision(model_id, RepoType::Model, revision); - println!("building the model"); - let config_filename = api.get(&repo, "config.json").await?; + let (config_filename, tokenizer_filename, weights_filename) = if args.offline { + let cache = Cache::default(); + ( + cache + .get(&repo, "config.json") + .ok_or(anyhow!("Missing config file in cache"))?, + cache + .get(&repo, "tokenizer.json") + .ok_or(anyhow!("Missing tokenizer file in cache"))?, + cache + .get(&repo, "model.safetensors") + .ok_or(anyhow!("Missing weights file in cache"))?, + ) + } else { + let api = Api::new()?; + ( + api.get(&repo, "config.json").await?, + api.get(&repo, "tokenizer.json").await?, + api.get(&repo, "model.safetensors").await?, + ) + }; 
+ println!("Building {:?}", start.elapsed()); let config = std::fs::read_to_string(config_filename)?; let config: Config = serde_json::from_str(&config)?; - let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; + println!("Config loaded {:?}", start.elapsed()); let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = tokenizer.with_padding(None).with_truncation(None); + println!("Tokenizer loaded {:?}", start.elapsed()); - let weights = api.get(&repo, "model.safetensors").await?; - let weights = unsafe { candle::safetensors::MmapedFile::new(weights)? }; + let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone()); let model = BertModel::load(&vb, &config)?; + println!("Loaded {:?}", start.elapsed()); let tokens = tokenizer .encode("This is an example sentence", true) @@ -666,9 +692,13 @@ async fn main() -> Result<()> { .get_ids() .to_vec(); let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; - println!("{token_ids}"); let token_type_ids = token_ids.zeros_like()?; - let ys = model.forward(&token_ids, &token_type_ids)?; - println!("{ys}"); + println!("Loaded and encoded {:?}", start.elapsed()); + for _ in 0..100 { + let start = std::time::Instant::now(); + let _ys = model.forward(&token_ids, &token_type_ids)?; + println!("Took {:?}", start.elapsed()); + // println!("Ys {:?}", ys.shape()); + } Ok(()) } From d8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 5 Jul 2023 07:41:14 +0000 Subject: [PATCH 3/3] Some polish. 
--- candle-examples/examples/bert/main.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 9c9dc206..4de0aeac 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -621,18 +621,26 @@ struct Args { #[arg(long)] offline: bool, + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending #[arg(long)] model_id: Option<String>, #[arg(long)] revision: Option<String>, + + /// The prompt to be used. + #[arg(long, default_value = "This is an example sentence")] + prompt: String, + + /// The number of times to run the prompt. + #[arg(long, default_value = "1")] + n: usize, } #[tokio::main] async fn main() -> Result<()> { use tokenizers::Tokenizer; let start = std::time::Instant::now(); - println!("Building {:?}", start.elapsed()); let args = Args::parse(); let device = if args.cpu { @@ -672,29 +680,25 @@ async fn main() -> Result<()> { api.get(&repo, "model.safetensors").await?, ) }; - println!("Building {:?}", start.elapsed()); let config = std::fs::read_to_string(config_filename)?; let config: Config = serde_json::from_str(&config)?; - println!("Config loaded {:?}", start.elapsed()); let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = tokenizer.with_padding(None).with_truncation(None); - println!("Tokenizer loaded {:?}", start.elapsed()); let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone()); let model = BertModel::load(&vb, &config)?; - println!("Loaded {:?}", start.elapsed()); let tokens = tokenizer - .encode("This is an example sentence", true) + .encode(args.prompt, true) .map_err(E::msg)? 
.get_ids() .to_vec(); let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; let token_type_ids = token_ids.zeros_like()?; println!("Loaded and encoded {:?}", start.elapsed()); - for _ in 0..100 { + for _ in 0..args.n { let start = std::time::Instant::now(); let _ys = model.forward(&token_ids, &token_type_ids)?; println!("Took {:?}", start.elapsed());