diff --git a/candle-examples/examples/bert/README.md b/candle-examples/examples/bert/README.md index 82ca5f40..5a75b516 100644 --- a/candle-examples/examples/bert/README.md +++ b/candle-examples/examples/bert/README.md @@ -2,10 +2,10 @@ Bert is a general large language model. In this example it can be used for two different tasks: + - Compute sentence embeddings for a prompt. - Compute similarities between a set of sentences. - ## Sentence embeddings Bert is used to compute the sentence embeddings for a prompt. The model weights @@ -24,6 +24,48 @@ cargo run --example bert --release -- --prompt "Here is a test sentence" > Tensor[[1, 7, 384], f32] ``` +### Custom models + +You can specify different models, such as BGE, with the `--model-id` flag: + +```bash +$ cargo run --example bert --release -- \ +--model-id BAAI/bge-large-zh-v1.5 \ +--prompt "Here is a test sentence" +Loaded and encoded 435.70775ms +[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1], + [-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0], + [ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1], + ... 
+ [ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1], + [ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1], + [ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]] +Tensor[[1, 9, 1024], f32] +Took 176.744667ms +``` + +### Gelu approximation + +You can get a speedup by using an approximation of the gelu activation, with a +small loss of precision, by passing the `--approximate-gelu` flag: + +```bash +$ cargo run --example bert --release -- \ +--model-id BAAI/bge-large-zh-v1.5 \ +--prompt "Here is a test sentence" \ +--approximate-gelu +Loaded and encoded 244.388042ms +[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1], + [-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0], + [ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1], + ... + [ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1], + [ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1], + [ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]] +Tensor[[1, 9, 1024], f32] +Took 116.840791ms +``` + ## Similarities In this example, Bert is used to compute the sentence embeddings for a set of diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index fcd2eab9..88e29718 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -3,7 +3,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE}; use anyhow::{Error as E, Result}; use candle::Tensor; @@ -45,6 +45,10 @@ struct Args { /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, + + /// Use tanh based approximation for Gelu instead of erf implementation. 
+ #[arg(long, default_value = "false")] + approximate_gelu: bool, } impl Args { @@ -73,7 +77,7 @@ impl Args { (config, tokenizer, weights) }; let config = std::fs::read_to_string(config_filename)?; - let config: Config = serde_json::from_str(&config)?; + let mut config: Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let vb = if self.use_pth { @@ -81,6 +85,9 @@ impl Args { } else { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } }; + if self.approximate_gelu { + config.hidden_act = HiddenAct::GeluApproximate; + } let model = BertModel::load(vb, &config)?; Ok((model, tokenizer)) } diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index d6826a16..51c524f5 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32; #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] #[serde(rename_all = "lowercase")] -enum HiddenAct { +pub enum HiddenAct { Gelu, + GeluApproximate, Relu, } @@ -28,6 +29,7 @@ impl HiddenActLayer { match self.act { // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 HiddenAct::Gelu => xs.gelu_erf(), + HiddenAct::GeluApproximate => xs.gelu(), HiddenAct::Relu => xs.relu(), } } @@ -48,7 +50,7 @@ pub struct Config { num_hidden_layers: usize, num_attention_heads: usize, intermediate_size: usize, - hidden_act: HiddenAct, + pub hidden_act: HiddenAct, hidden_dropout_prob: f64, max_position_embeddings: usize, type_vocab_size: usize,