mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use the gelu-erf activation. (#969)
This commit is contained in:
@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
|
||||
let expected_blocks = xs.len() / block_size;
|
||||
let actual_blocks = ys.len();
|
||||
|
||||
//validate that the input is the right size
|
||||
// Validate that the input is the right size
|
||||
if expected_blocks != actual_blocks {
|
||||
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
|
||||
}
|
||||
@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
|
||||
|
||||
let actual_output_len = ys.len();
|
||||
let expected_output_len = xs.len() * block_size;
|
||||
//validate that the output is the right size
|
||||
// Validate that the output is the right size
|
||||
if expected_output_len != actual_output_len {
|
||||
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
|
||||
}
|
||||
|
||||
//zip the blocks and outputs together
|
||||
// Zip the blocks and outputs together
|
||||
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
|
||||
}
|
||||
|
||||
|
@ -16,9 +16,7 @@ pub enum Activation {
|
||||
impl super::Module for Activation {
|
||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Gelu => xs.gelu(),
|
||||
// TODO: This is "gelu_new", not the original "gelu".
|
||||
// There's some small numerical difference:
|
||||
Self::Gelu => xs.gelu_erf(),
|
||||
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
|
||||
Self::NewGelu => xs.gelu(),
|
||||
Self::Relu => xs.relu(),
|
||||
|
@ -25,10 +25,8 @@ impl HiddenActLayer {
|
||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match self.act {
|
||||
// TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
|
||||
// small numerical difference.
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||
HiddenAct::Gelu => xs.gelu(),
|
||||
HiddenAct::Gelu => xs.gelu_erf(),
|
||||
HiddenAct::Relu => xs.relu(),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user