mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel. * CPU implementation for rms-norm. * Cuda wrappers. * Add some validation. * Add some testing. * More testing.
This commit is contained in:
@ -4,11 +4,9 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{test_utils::to_vec3_round, Device, Result, Tensor};
|
||||
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
|
||||
|
||||
#[test]
|
||||
fn softmax() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
fn softmax(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
|
||||
@ -54,6 +52,31 @@ fn softmax() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rms_norm(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
|
||||
let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(&t, 4)?,
|
||||
&[
|
||||
[[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
|
||||
[[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
|
||||
]
|
||||
);
|
||||
let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(&t2, 4)?,
|
||||
&[
|
||||
[[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
|
||||
[[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
|
||||
]
|
||||
);
|
||||
let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert!(diff < 1e-5);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax_numerical_stability() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
@ -62,3 +85,6 @@ fn softmax_numerical_stability() -> Result<()> {
|
||||
assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Passes the `softmax(device)` function above to candle's `test_device!` macro
// together with three test names. NOTE(review): presumably this expands to one
// `#[test]` per backend (CPU, GPU, Metal); confirm against the macro definition.
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
||||
// Passes the `rms_norm(device)` function above to candle's `test_device!` macro
// together with three test names. NOTE(review): presumably this expands to one
// `#[test]` per backend (CPU, GPU, Metal); confirm against the macro definition.
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
||||
|
Reference in New Issue
Block a user