Mirror of https://github.com/huggingface/candle.git
Speed up bert with approx gelu (#1410)
@@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32;
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "lowercase")]
-enum HiddenAct {
+pub enum HiddenAct {
     Gelu,
+    GeluApproximate,
     Relu,
 }
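With `#[serde(rename_all = "lowercase")]`, the new variant corresponds to the string `"geluapproximate"` when `hidden_act` is deserialized from a config file. A minimal sketch of that mapping, assuming `serde_json` as the loader (the loader and the standalone copy of the enum are not part of this diff):

```rust
use serde::Deserialize;

// Standalone copy of the enum from the diff, so the sketch compiles on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    Gelu,
    GeluApproximate,
    Relu,
}

fn main() -> serde_json::Result<()> {
    // rename_all = "lowercase" lowercases the variant name without separators,
    // so the JSON spelling of the new variant is "geluapproximate".
    let act: HiddenAct = serde_json::from_str("\"geluapproximate\"")?;
    assert_eq!(act, HiddenAct::GeluApproximate);
    Ok(())
}
```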
@@ -28,6 +29,7 @@ impl HiddenActLayer {
         match self.act {
             // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
             HiddenAct::Gelu => xs.gelu_erf(),
+            HiddenAct::GeluApproximate => xs.gelu(),
             HiddenAct::Relu => xs.relu(),
         }
     }
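The speedup comes from routing `GeluApproximate` to the tanh-based `gelu` kernel while `Gelu` keeps the exact erf-based `gelu_erf`. A minimal sketch comparing the two activations on a small tensor, assuming the `candle_core` crate on the CPU device (only the two method names are taken from the diff; everything else is illustrative):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 0.5, 2.0], &Device::Cpu)?;

    // Exact GELU: x * 0.5 * (1 + erf(x / sqrt(2))).
    let exact = xs.gelu_erf()?;

    // Tanh approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    // cheaper to evaluate than erf and very close numerically.
    let approx = xs.gelu()?;

    println!("exact:  {exact}");
    println!("approx: {approx}");
    Ok(())
}
```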
@@ -48,7 +50,7 @@ pub struct Config {
     num_hidden_layers: usize,
     num_attention_heads: usize,
     intermediate_size: usize,
-    hidden_act: HiddenAct,
+    pub hidden_act: HiddenAct,
     hidden_dropout_prob: f64,
     max_position_embeddings: usize,
     type_vocab_size: usize,
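Making `HiddenAct` and the `hidden_act` field public lets callers opt into the approximate activation after loading a config. A hypothetical usage sketch; the module path and the pattern of mutating a deserialized `Config` are assumptions, not shown in this diff:

```rust
// Hypothetical: module path assumed from the crate layout, not shown in the diff.
use candle_transformers::models::bert::{Config, HiddenAct};

/// Switch a loaded BERT config to the faster tanh-approximated GELU.
fn prefer_approx_gelu(mut config: Config) -> Config {
    // `hidden_act` and `HiddenAct` are public after this change, so callers
    // (e.g. an example binary) can override whatever the config file specified.
    config.hidden_act = HiddenAct::GeluApproximate;
    config
}
```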