mirror of https://github.com/huggingface/candle.git
Enable the new layer-norm. (#2213)

* Enable the new layer-norm.

* Shape fixes.
@@ -11,8 +11,8 @@
 //! use candle_nn::{LayerNorm, Module};
 //! # fn main() -> candle::Result<()> {
 //!
-//! let w = Tensor::new(1f32, &Cpu)?;
-//! let b = Tensor::new(0f32, &Cpu)?;
+//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
+//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
 //! let layer = LayerNorm::new(w, b, 1e-5);
 //!
 //! let xs = Tensor::new(
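The updated doc example uses per-element weight and bias matching the normalized width instead of scalars. A self-contained version of the new example (the truncated `xs` line is assumed to build the same 3x3 input used in the tests further down):

    use candle::{Device::Cpu, Result, Tensor};
    use candle_nn::{LayerNorm, Module};

    fn main() -> Result<()> {
        // Per-element scale and shift, matching the normalized width of 3.
        let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
        let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
        let layer = LayerNorm::new(w, b, 1e-5);
        // Assumed input shape (1, 3, 3); the doc comment is cut off above.
        let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], &Cpu)?;
        let ys = layer.forward(&xs)?;
        println!("{ys}");
        Ok(())
    }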
@@ -107,6 +107,11 @@ impl LayerNorm {
 
 impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        if x.is_contiguous() && self.remove_mean {
+            if let Some(bias) = self.bias.as_ref() {
+                return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
+            }
+        }
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
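The five added lines give `forward` a fast path: a contiguous input with mean removal enabled and a bias present dispatches to the fused `crate::ops::layer_norm` kernel; anything else falls through to the original elementwise code. A rough reference for what either path computes, assuming the standard layer-norm definition (names here are illustrative, not candle internals):

    use candle::{DType, Result, Tensor, D};

    fn layer_norm_ref(x: &Tensor, w: &Tensor, b: &Tensor, eps: f64) -> Result<Tensor> {
        let hidden = x.dim(D::Minus1)? as f64;
        // Work in f32 regardless of the input dtype, as the real forward does.
        let x = x.to_dtype(DType::F32)?;
        // Remove the mean over the last dimension.
        let mean = (x.sum_keepdim(D::Minus1)? / hidden)?;
        let x = x.broadcast_sub(&mean)?;
        // Normalize by sqrt(var + eps), then scale and shift per element.
        let var = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden)?;
        let x = x.broadcast_div(&(var + eps)?.sqrt()?)?;
        x.broadcast_mul(w)?.broadcast_add(b)
    }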
@@ -13,6 +13,12 @@ fn layer_norm() -> Result<()> {
     let device = &Device::Cpu;
     let w = Tensor::new(&[3f32], device)?;
     let b = Tensor::new(&[0.5f32], device)?;
+    let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8);
+    let ln3 = LayerNorm::new(
+        Tensor::cat(&[&w, &w, &w], 0)?,
+        Tensor::cat(&[&b, &b, &b], 0)?,
+        1e-8,
+    );
     let ln = LayerNorm::new(w, b, 1e-8);
 
     let two = Tensor::new(&[[[2f32]]], device)?;
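The test file gains per-width `LayerNorm` instances, presumably because the fused kernel expects weight and bias to have exactly the normalized width rather than broadcasting a 1-element tensor. A quick standalone check of the shapes the `cat` calls produce:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = &Device::Cpu;
        let w = Tensor::new(&[3f32], device)?;
        // A [1]-tensor tiled to match the width it normalizes over:
        let w2 = Tensor::cat(&[&w, &w], 0)?; // shape [2]
        let w3 = Tensor::cat(&[&w, &w, &w], 0)?; // shape [3]
        assert_eq!(w2.dims(), [2]);
        assert_eq!(w3.dims(), [3]);
        Ok(())
    }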
@@ -20,11 +26,11 @@ fn layer_norm() -> Result<()> {
     assert_eq!(res.to_vec1::<f32>()?, [0.5f32]);
 
     let inp = Tensor::new(&[[[4f32, 0f32]]], device)?;
-    let res = ln.forward(&inp)?;
+    let res = ln2.forward(&inp)?;
     assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]);
 
     let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
-    let res = ln.forward(&inp)?;
+    let res = ln3.forward(&inp)?;
     assert_eq!(
         test_utils::to_vec3_round(&res, 4)?,
         [[
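A worked check of the `ln3` expectation for the first row, as a hedged standalone sketch:

    use candle::{Device, Result, Tensor};
    use candle_nn::{LayerNorm, Module};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let w = Tensor::new(&[3f32, 3., 3.], &dev)?;
        let b = Tensor::new(&[0.5f32, 0.5, 0.5], &dev)?;
        let ln3 = LayerNorm::new(w, b, 1e-8);
        // Row [1, 2, 3]: mean 2, variance ((-1)^2 + 0 + 1^2)/3 = 2/3,
        // normalized ~ [-1.2247, 0.0, 1.2247]; scaled by w = 3 and shifted
        // by b = 0.5 gives ~ [-3.1742, 0.5, 4.1742].
        let xs = Tensor::new(&[[[1f32, 2., 3.]]], &dev)?;
        let ys = ln3.forward(&xs)?;
        println!("{ys}");
        Ok(())
    }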
@@ -35,7 +41,10 @@ fn layer_norm() -> Result<()> {
     );
     let mean = (res.sum_keepdim(2)? / 3.0)?;
     // The average value should be `b`.
-    assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
+    assert_eq!(
+        test_utils::to_vec3_round(&mean, 4)?,
+        [[[0.5], [0.5], [0.5]]]
+    );
     let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
     // The standard deviation should be sqrt(`w`).
     assert_eq!(
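The exact mean comparison becomes a rounded one, presumably because the fused kernel is not bit-identical to the elementwise path; the expected value itself is unchanged, since the normalized activations are zero-mean and `w`, `b` are constant across the row:

    \bar{y} \;=\; \frac{1}{N}\sum_{i=1}^{N}\bigl(w\,\hat{x}_i + b\bigr)
            \;=\; w\cdot\frac{1}{N}\sum_{i=1}^{N}\hat{x}_i + b
            \;=\; b,
    \qquad \text{since } \textstyle\sum_i \hat{x}_i = 0.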
@@ -56,24 +56,20 @@ impl RotaryEmbedding {
             .to_dtype(DType::F32)?
             .reshape((cfg.max_position_embeddings, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             dim,
-            sin: emb.sin()?,
-            cos: emb.cos()?,
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
         })
     }
 
     fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
         let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
-        let xs_rot = xs.i((.., .., .., ..self.dim))?;
+        let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
         let xs_pass = xs.i((.., .., .., self.dim..))?;
-        let xs12 = xs_rot.chunk(2, D::Minus1)?;
-        let (xs1, xs2) = (&xs12[0], &xs12[1]);
         let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
-        let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
+        let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
         Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
     }
 }
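In the `RotaryEmbedding` (a partial-rotary, phi-style module), the manual rotate-half code is replaced by the fused `candle_nn::rotary_emb::rope` kernel. The sin/cos tables shrink to half the rotary width (plain `freqs`, no concatenation) and the sliced `xs_rot` is made contiguous, both of which the fused op appears to require. A reference sketch of what the fused call computes, assuming half-width `cos`/`sin` tables of shape `(seq_len, dim / 2)` that are broadcast over both halves:

    use candle::{Result, Tensor, D};

    fn rope_reference(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
        // Split the rotary dims into two halves along the last axis.
        let last = xs.dim(D::Minus1)?;
        let xs1 = xs.narrow(D::Minus1, 0, last / 2)?;
        let xs2 = xs.narrow(D::Minus1, last / 2, last / 2)?;
        let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
        // Tile the half-width tables so they cover both halves, as the
        // removed `Tensor::cat(&[&freqs, &freqs], ..)` used to do up front.
        let c = Tensor::cat(&[cos, cos], D::Minus1)?;
        let s = Tensor::cat(&[sin, sin], D::Minus1)?;
        xs.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?
    }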