mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Fix the matmul layout for accelerate & mkl. (#2011)
* Fix the matmul layout for accelerate & mkl. * Reduce the required precision for pow (because of accelerate). * And a fix the gelu f16 test.
This commit is contained in:
@ -107,13 +107,8 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
]
|
||||
);
|
||||
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&t_f16, 2)?,
|
||||
[
|
||||
[-0.0, 0.84, 4.0, -0.05, 0.35],
|
||||
[2.69, -0.07, -0.11, 1.73, 2.79]
|
||||
],
|
||||
);
|
||||
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
|
||||
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||
[
|
||||
@ -1255,8 +1250,8 @@ fn pow() -> Result<()> {
|
||||
let rhs = (&lhs - 2.)?;
|
||||
let res = lhs.pow(&rhs)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res, 4)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||
test_utils::to_vec2_round(&res, 3)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user