Add the layernorm specialized op. (#2212)

* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
This commit is contained in:
Laurent Mazare
2024-05-24 15:58:01 +02:00
committed by GitHub
parent 6f0b807ffd
commit 1df2bddccf
7 changed files with 547 additions and 6 deletions

View File

@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}
fn layer_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?;
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
assert_eq!(
to_vec3_round(&t, 4)?,
&[
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
]
);
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
assert_eq!(
to_vec3_round(&t2, 4)?,
&[
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
]
);
let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);