Mirror of https://github.com/huggingface/candle.git
Use the gelu-erf activation. (#969)
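For context on what this commit changes: candle's `gelu` implements the tanh-based "gelu_new" approximation, while `gelu_erf` computes the exact, erf-based GELU; the TODO comments removed in the hunks below note the small numerical difference between the two. A minimal standalone sketch of the two formulas, assuming the `libm` crate for `erf` (this is not candle code):

// Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + libm::erf(x / std::f64::consts::SQRT_2))
}

// Tanh approximation ("gelu_new"):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // The two curves agree closely; the difference only shows up
    // in the lower decimal places.
    for x in [-2.0, -0.5, 0.0, 0.5, 2.0] {
        println!("x = {x:+.1}: erf = {:.6}, tanh = {:.6}", gelu_erf(x), gelu_tanh(x));
    }
}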
@@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
     let expected_blocks = xs.len() / block_size;
     let actual_blocks = ys.len();

-    //validate that the input is the right size
+    // Validate that the input is the right size
     if expected_blocks != actual_blocks {
         crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
     }
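The check above enforces that quantization gets exactly `xs.len() / block_size` output blocks, one per `block_size` input elements. A standalone re-statement of the same check, detached from candle's types (names here are illustrative):

fn check_quantize_shape(n_inputs: usize, n_blocks: usize, block_size: usize) -> Result<(), String> {
    let expected_blocks = n_inputs / block_size;
    if expected_blocks != n_blocks {
        return Err(format!(
            "expected {expected_blocks} blocks but only {n_blocks} were provided!"
        ));
    }
    Ok(())
}

fn main() {
    // With a hypothetical block size of 32, 64 inputs need exactly 2 blocks.
    assert!(check_quantize_shape(64, 2, 32).is_ok());
    assert!(check_quantize_shape(64, 1, 32).is_err());
}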
@@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(

     let actual_output_len = ys.len();
     let expected_output_len = xs.len() * block_size;
-    //validate that the output is the right size
+    // Validate that the output is the right size
     if expected_output_len != actual_output_len {
         crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
     }

-    //zip the blocks and outputs together
+    // Zip the blocks and outputs together
     Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
 }
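The `Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())` line is the heart of the grouping: each quantized block is paired with the mutable output window it will dequantize into. A minimal sketch of the same pattern with plain slices (not candle's actual signature):

// Pair each element of `xs` with a disjoint, mutable `block_size`-wide
// window of `ys`; chunks_exact_mut guarantees the windows don't overlap,
// so each pair can be processed independently.
fn group<'a, T>(xs: &'a [T], ys: &'a mut [f32], block_size: usize) -> Vec<(&'a T, &'a mut [f32])> {
    xs.iter().zip(ys.chunks_exact_mut(block_size)).collect()
}

fn main() {
    let xs = ["a", "b"];
    let mut ys = [0.0f32; 8];
    for (block, out) in group(&xs, &mut ys, 4) {
        // A real dequantizer would expand `block` into `out` here.
        out.fill(if *block == "a" { 1.0 } else { 2.0 });
    }
    assert_eq!(ys, [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]);
}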
@@ -16,9 +16,7 @@ pub enum Activation {
 impl super::Module for Activation {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
-            Self::Gelu => xs.gelu(),
-            // TODO: This is "gelu_new", not the original "gelu".
-            // There's some small numerical difference:
+            Self::Gelu => xs.gelu_erf(),
             // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
             Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
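With this change, `Activation::Gelu` maps to the exact erf-based kernel while `Activation::NewGelu` keeps the tanh approximation. A hedged usage sketch; the crate and item names below are assumptions about the public API around this commit, not confirmed by the diff:

use candle_core::{Device, Result, Tensor};
use candle_nn::{Activation, Module};

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1.0f32, 0.0, 1.0], &Device::Cpu)?;
    // Exact, erf-based GELU after this commit.
    let exact = Activation::Gelu.forward(&xs)?;
    // Still the tanh-based "gelu_new" approximation.
    let approx = Activation::NewGelu.forward(&xs)?;
    println!("{exact}\n{approx}");
    Ok(())
}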
@@ -25,10 +25,8 @@ impl HiddenActLayer {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         match self.act {
-            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
-            // small numerical difference.
             // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
-            HiddenAct::Gelu => xs.gelu(),
+            HiddenAct::Gelu => xs.gelu_erf(),
             HiddenAct::Relu => xs.relu(),
         }
     }