mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Use the faster rms-norm kernel for llama. (#2107)

* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -105,7 +105,7 @@ impl LayerNorm {
     }
 }
 
-impl crate::Module for LayerNorm {
+impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -162,11 +162,20 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
     }
+
+    /// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
+    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
 }
 
-impl crate::Module for RmsNorm {
+impl Module for RmsNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.0.forward(xs)
+        if xs.is_contiguous() {
+            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
+        } else {
+            self.0.forward(xs)
+        }
     }
 }
 
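Taken together, the hunk above makes `RmsNorm::forward` dispatch to the fused `crate::ops::rms_norm` kernel whenever the input tensor is contiguous, and fall back to the generic `LayerNorm` path otherwise; the new `forward_diff` always takes the generic path. Below is a minimal sketch of how a caller could exercise both paths. The shapes, the eps value, and the all-ones weight are illustrative only, not taken from the commit:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::RmsNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Arbitrary weight and eps, just to build a layer over a hidden size of 8.
    let weight = Tensor::ones(8usize, DType::F32, &dev)?;
    let rms = RmsNorm::new(weight, 1e-5);

    let xs = Tensor::randn(0f32, 1f32, (8, 8), &dev)?;
    // Contiguous input: forward() now hits the fused kernel.
    let fast = rms.forward(&xs)?;
    // A transposed view is not contiguous, so forward() silently falls back
    // to the generic path; forward_diff() always uses that path.
    let _fallback = rms.forward(&xs.t()?)?;
    let slow = rms.forward_diff(&xs)?;

    // The fused and generic paths should agree up to float tolerance.
    let delta = (&fast - &slow)?.abs()?.mean_all()?.to_scalar::<f32>()?;
    println!("mean abs delta: {delta}");
    Ok(())
}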
@@ -180,6 +180,11 @@ impl RmsNorm {
         let inner = candle_nn::rms_norm(size, eps, vb)?;
         Ok(Self { inner, span })
     }
+
+    pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward_diff(x)
+    }
 }
 
 impl Module for RmsNorm {
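On the llama side, the model's internal `RmsNorm` wrapper gains a matching `forward_diff` that records its tracing span before delegating to the inner layer. That wrapper is private to the model, but the `candle_nn::rms_norm` constructor shown in the context lines can be used directly. A hedged sketch, with a `VarMap`-backed `VarBuilder` and an arbitrary hidden size standing in for real checkpoint weights:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A VarMap-backed builder creates default (all-ones) weights on first
    // access, standing in for weights loaded from a real llama checkpoint.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let rms = candle_nn::rms_norm(64, 1e-5, vb.pp("model.norm"))?;

    let xs = Tensor::randn(0f32, 1f32, (1, 64), &dev)?;
    // With the fast variant now the default, contiguous activations go
    // through the fused kernel without any call-site changes.
    let ys = rms.forward(&xs)?;
    println!("{:?}", ys.shape());
    Ok(())
}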