Mirror of https://github.com/huggingface/candle.git
Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.
* Use the fast variant by default.
@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -105,7 +105,7 @@ impl LayerNorm {
     }
 }
 
-impl crate::Module for LayerNorm {
+impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -162,11 +162,20 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
    }
+
+    /// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
+    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
 }
 
-impl crate::Module for RmsNorm {
+impl Module for RmsNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.0.forward(xs)
+        if xs.is_contiguous() {
+            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
+        } else {
+            self.0.forward(xs)
+        }
     }
 }
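Not part of the commit: a minimal Rust sketch of how the new dispatch can be exercised from user code, assuming `candle` and `candle-nn` as dependencies and the public `RmsNorm::new(weight, eps)` constructor; the tensor shapes, eps value, and the comparison at the end are illustrative only.

// Minimal sketch, assuming candle / candle-nn; shapes and eps are illustrative.
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::RmsNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let weight = Tensor::ones(64, DType::F32, &dev)?;
    let rms = RmsNorm::new(weight, 1e-5);

    // Contiguous input: Module::forward now dispatches to the fused
    // rms_norm kernel (crate::ops::rms_norm in the diff above).
    let xs = Tensor::randn(0f32, 1f32, (2, 16, 64), &dev)?;
    let fast = rms.forward(&xs)?;

    // forward_diff keeps the previous composed-ops path.
    let reference = rms.forward_diff(&xs)?;

    // Both paths should agree up to floating-point rounding.
    let diff = (&fast - &reference)?.abs()?.mean_all()?.to_scalar::<f32>()?;
    println!("mean abs diff: {diff}");

    // A non-contiguous input (e.g. a transposed view) falls back to the
    // reference implementation inside forward.
    let xs_t = xs.transpose(1, 2)?;
    let _ys = rms.forward(&xs_t)?;
    Ok(())
}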