Optimize the gelu f16 opt. (#2008)

* Optimize the gelu f16 opt.

* And add a test.
This commit is contained in:
Laurent Mazare
2024-04-04 16:28:23 +02:00
committed by GitHub
parent f48c07e242
commit 30b145150f
2 changed files with 19 additions and 8 deletions

View File

@ -106,6 +106,14 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
assert_eq!(
test_utils::to_vec2_round(&t_f16, 2)?,
[
[-0.0, 0.84, 4.0, -0.05, 0.35],
[2.69, -0.07, -0.11, 1.73, 2.79]
],
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[