Speed up bert with approx gelu (#1410)

Juarez Bochi, 2023-12-06 11:46:37 -05:00 (committed by GitHub)
parent 236b820e28, commit 9bd94c1ffa
3 changed files with 56 additions and 5 deletions

@@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32;
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "lowercase")]
-enum HiddenAct {
+pub enum HiddenAct {
     Gelu,
+    GeluApproximate,
     Relu,
 }
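
With the #[serde(rename_all = "lowercase")] attribute shown above, the new variant deserializes from the string "geluapproximate": serde's "lowercase" rule lowercases the whole variant name without inserting separators. A minimal standalone sketch (serde and serde_json assumed as dependencies; this mirrors, rather than reuses, the enum in the diff):

use serde::Deserialize;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    Gelu,
    GeluApproximate,
    Relu,
}

fn main() -> Result<(), serde_json::Error> {
    // "gelu" keeps the exact erf-based activation ...
    assert_eq!(serde_json::from_str::<HiddenAct>("\"gelu\"")?, HiddenAct::Gelu);
    // ... while "geluapproximate" selects the new tanh approximation.
    assert_eq!(
        serde_json::from_str::<HiddenAct>("\"geluapproximate\"")?,
        HiddenAct::GeluApproximate
    );
    Ok(())
}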
@@ -28,6 +29,7 @@ impl HiddenActLayer {
         match self.act {
             // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
             HiddenAct::Gelu => xs.gelu_erf(),
+            HiddenAct::GeluApproximate => xs.gelu(),
             HiddenAct::Relu => xs.relu(),
         }
     }
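
The match routes the exact variant to candle's gelu_erf op and the new variant to gelu, following the same split as the transformers code linked in the comment. A scalar sketch of the two formulas (standard definitions, not candle's kernels; the libm crate is assumed for erf):

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + libm::erf(x / std::f64::consts::SQRT_2))
}

// Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Avoiding erf is cheaper on most backends, which is where this commit's speedup comes from.
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // The two curves agree to a few decimal places over typical activation ranges.
    for &x in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
        println!("x = {x:5.2}  erf: {:+.6}  tanh: {:+.6}", gelu_erf(x), gelu_tanh(x));
    }
}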
@@ -48,7 +50,7 @@ pub struct Config {
     num_hidden_layers: usize,
     num_attention_heads: usize,
     intermediate_size: usize,
-    hidden_act: HiddenAct,
+    pub hidden_act: HiddenAct,
     hidden_dropout_prob: f64,
     max_position_embeddings: usize,
     type_vocab_size: usize,
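
Making hidden_act public lets downstream code opt a loaded config into the approximation. A hypothetical caller-side fragment (BertModel::load, candle_nn::VarBuilder, and the config.json path are assumptions, not shown in this diff):

use candle_transformers::models::bert::{BertModel, Config, HiddenAct};

fn build_fast_bert(vb: candle_nn::VarBuilder) -> anyhow::Result<BertModel> {
    let mut config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    // Swap the exact erf GELU for the cheaper tanh approximation before loading weights.
    config.hidden_act = HiddenAct::GeluApproximate;
    Ok(BertModel::load(vb, &config)?)
}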