Speed up bert with approx gelu (#1410)

This commit is contained in:
Juarez Bochi
2023-12-06 11:46:37 -05:00
committed by GitHub
parent 236b820e28
commit 9bd94c1ffa
3 changed files with 56 additions and 5 deletions

View File

@ -3,7 +3,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use anyhow::{Error as E, Result};
use candle::Tensor;
@ -45,6 +45,10 @@ struct Args {
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// Use tanh based approximation for Gelu instead of erf implementation.
#[arg(long, default_value = "false")]
approximate_gelu: bool,
}
impl Args {
@ -73,7 +77,7 @@ impl Args {
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
@ -81,6 +85,9 @@ impl Args {
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
if self.approximate_gelu {
config.hidden_act = HiddenAct::GeluApproximate;
}
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}