Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00
Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.
* Use the fast variant by default.
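For context, an RMS norm rescales each vector along the last dimension by the reciprocal of its root-mean-square and then multiplies by a learned weight; the fused kernel wired up below computes this in one pass over contiguous data. A rough reference sketch using plain candle tensor ops (the helper name rms_norm_reference is hypothetical and not part of this commit):

use candle::{DType, Result, Tensor, D};

// Reference computation: y = x / sqrt(mean(x^2, last dim) + eps) * weight.
// This is roughly what the fused fast path computes in a single pass; shown
// only as an illustrative sketch, not the code touched by this commit.
fn rms_norm_reference(x: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let x_dtype = x.dtype();
    // Accumulate in f32 for numerical stability, then cast back.
    let x = x.to_dtype(DType::F32)?;
    let mean_x2 = x.sqr()?.mean_keepdim(D::Minus1)?;
    let x_normed = x.broadcast_div(&(mean_x2 + eps)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)?.broadcast_mul(weight)
}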
@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -105,7 +105,7 @@ impl LayerNorm {
     }
 }
 
-impl crate::Module for LayerNorm {
+impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -162,13 +162,22 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
     }
+
+    /// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
+    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
 }
 
-impl crate::Module for RmsNorm {
+impl Module for RmsNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        if xs.is_contiguous() {
+            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
+        } else {
             self.0.forward(xs)
+        }
     }
 }
 
 pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
     let config = LayerNormConfig {
@@ -180,6 +180,11 @@ impl RmsNorm {
         let inner = candle_nn::rms_norm(size, eps, vb)?;
         Ok(Self { inner, span })
     }
+
+    pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward_diff(x)
+    }
 }
 
 impl Module for RmsNorm {
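At the call site nothing changes: Module::forward on a candle_nn::RmsNorm now dispatches to the fused kernel whenever the input is contiguous, while forward_diff keeps the LayerNorm-based path. A minimal usage sketch, with illustrative sizes and a CPU device assumed:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Fresh variables; `candle_nn::rms_norm` pulls a `weight` of the given size
    // from the VarBuilder (created on the fly by the VarMap here).
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let rms = candle_nn::rms_norm(16, 1e-5, vb)?;

    // Contiguous input: `Module::forward` takes the fused rms-norm kernel path.
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 16), &dev)?;
    let ys = rms.forward(&xs)?;
    println!("{:?}", ys.dims());

    // `forward_diff` keeps using the LayerNorm-based implementation.
    let _ys_diff = rms.forward_diff(&xs)?;
    Ok(())
}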