mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Speed up bert with approx gelu (#1410)
This commit is contained in:
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
Bert is a general large language model. In this example it can be used for two
|
Bert is a general large language model. In this example it can be used for two
|
||||||
different tasks:
|
different tasks:
|
||||||
|
|
||||||
- Compute sentence embeddings for a prompt.
|
- Compute sentence embeddings for a prompt.
|
||||||
- Compute similarities between a set of sentences.
|
- Compute similarities between a set of sentences.
|
||||||
|
|
||||||
|
|
||||||
## Sentence embeddings
|
## Sentence embeddings
|
||||||
|
|
||||||
Bert is used to compute the sentence embeddings for a prompt. The model weights
|
Bert is used to compute the sentence embeddings for a prompt. The model weights
|
||||||
@ -24,6 +24,48 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
|
|||||||
> Tensor[[1, 7, 384], f32]
|
> Tensor[[1, 7, 384], f32]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Custom models
|
||||||
|
|
||||||
|
You can specify different models, such as BGE, with the `--model-id` flag:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example bert --release -- \
|
||||||
|
--model-id BAAI/bge-large-zh-v1.5 \
|
||||||
|
--prompt "Here is a test sentence"
|
||||||
|
Loaded and encoded 435.70775ms
|
||||||
|
[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
|
||||||
|
[-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
|
||||||
|
[ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
|
||||||
|
...
|
||||||
|
[ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
|
||||||
|
[ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
|
||||||
|
[ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
|
||||||
|
Tensor[[1, 9, 1024], f32]
|
||||||
|
Took 176.744667ms
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gelu approximation
|
||||||
|
|
||||||
|
You can get a speedup by using an approximation of the gelu activation, at the cost of a
|
||||||
|
small loss of precision, by passing the `--approximate-gelu` flag:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example bert --release -- \
|
||||||
|
--model-id BAAI/bge-large-zh-v1.5 \
|
||||||
|
--prompt "Here is a test sentence" \
|
||||||
|
--approximate-gelu
|
||||||
|
Loaded and encoded 244.388042ms
|
||||||
|
[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
|
||||||
|
[-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
|
||||||
|
[ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
|
||||||
|
...
|
||||||
|
[ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
|
||||||
|
[ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
|
||||||
|
[ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
|
||||||
|
Tensor[[1, 9, 1024], f32]
|
||||||
|
Took 116.840791ms
|
||||||
|
```
|
||||||
|
|
||||||
## Similarities
|
## Similarities
|
||||||
|
|
||||||
In this example, Bert is used to compute the sentence embeddings for a set of
|
In this example, Bert is used to compute the sentence embeddings for a set of
|
||||||
|
@ -3,7 +3,7 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::Tensor;
|
use candle::Tensor;
|
||||||
@ -45,6 +45,10 @@ struct Args {
|
|||||||
/// L2 normalization for embeddings.
|
/// L2 normalization for embeddings.
|
||||||
#[arg(long, default_value = "true")]
|
#[arg(long, default_value = "true")]
|
||||||
normalize_embeddings: bool,
|
normalize_embeddings: bool,
|
||||||
|
|
||||||
|
/// Use tanh based approximation for Gelu instead of erf implementation.
|
||||||
|
#[arg(long, default_value = "false")]
|
||||||
|
approximate_gelu: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -73,7 +77,7 @@ impl Args {
|
|||||||
(config, tokenizer, weights)
|
(config, tokenizer, weights)
|
||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
let mut config: Config = serde_json::from_str(&config)?;
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let vb = if self.use_pth {
|
let vb = if self.use_pth {
|
||||||
@ -81,6 +85,9 @@ impl Args {
|
|||||||
} else {
|
} else {
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||||
};
|
};
|
||||||
|
if self.approximate_gelu {
|
||||||
|
config.hidden_act = HiddenAct::GeluApproximate;
|
||||||
|
}
|
||||||
let model = BertModel::load(vb, &config)?;
|
let model = BertModel::load(vb, &config)?;
|
||||||
Ok((model, tokenizer))
|
Ok((model, tokenizer))
|
||||||
}
|
}
|
||||||
|
@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32;
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum HiddenAct {
|
pub enum HiddenAct {
|
||||||
Gelu,
|
Gelu,
|
||||||
|
GeluApproximate,
|
||||||
Relu,
|
Relu,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,6 +29,7 @@ impl HiddenActLayer {
|
|||||||
match self.act {
|
match self.act {
|
||||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||||
HiddenAct::Gelu => xs.gelu_erf(),
|
HiddenAct::Gelu => xs.gelu_erf(),
|
||||||
|
HiddenAct::GeluApproximate => xs.gelu(),
|
||||||
HiddenAct::Relu => xs.relu(),
|
HiddenAct::Relu => xs.relu(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,7 +50,7 @@ pub struct Config {
|
|||||||
num_hidden_layers: usize,
|
num_hidden_layers: usize,
|
||||||
num_attention_heads: usize,
|
num_attention_heads: usize,
|
||||||
intermediate_size: usize,
|
intermediate_size: usize,
|
||||||
hidden_act: HiddenAct,
|
pub hidden_act: HiddenAct,
|
||||||
hidden_dropout_prob: f64,
|
hidden_dropout_prob: f64,
|
||||||
max_position_embeddings: usize,
|
max_position_embeddings: usize,
|
||||||
type_vocab_size: usize,
|
type_vocab_size: usize,
|
||||||
|
Reference in New Issue
Block a user