mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix the matmul layout for accelerate & mkl. (#2011)
* Fix the matmul layout for accelerate & mkl. * Reduce the required precision for pow (because of accelerate). * And a fix for the gelu f16 test.
This commit is contained in:
@ -1330,7 +1330,7 @@ impl Map2 for MatMul {
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
(n as i32, b'N')
|
(n as i32, b'N')
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, b'T')
|
(k as i32, b'T')
|
||||||
@ -1338,7 +1338,7 @@ impl Map2 for MatMul {
|
|||||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
(k as i32, b'N')
|
(k as i32, b'N')
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, b'T')
|
(m as i32, b'T')
|
||||||
@ -1421,7 +1421,7 @@ impl Map2 for MatMul {
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
(n as i32, b'N')
|
(n as i32, b'N')
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, b'T')
|
(k as i32, b'T')
|
||||||
@ -1429,7 +1429,7 @@ impl Map2 for MatMul {
|
|||||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
(k as i32, b'N')
|
(k as i32, b'N')
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, b'T')
|
(m as i32, b'T')
|
||||||
|
@ -73,20 +73,7 @@ fn squeeze_mm(device: &Device) -> Result<()> {
|
|||||||
let seq_len = 8_usize;
|
let seq_len = 8_usize;
|
||||||
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
|
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
|
||||||
let x = a.i((.., seq_len - 1, ..))?;
|
let x = a.i((.., seq_len - 1, ..))?;
|
||||||
println!(
|
|
||||||
"x shape:{:?}, stride:{:?}, is_contiguous:{}",
|
|
||||||
x.shape(),
|
|
||||||
x.stride(),
|
|
||||||
x.is_contiguous()
|
|
||||||
);
|
|
||||||
|
|
||||||
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
|
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
|
||||||
println!(
|
|
||||||
"w shape:{:?}, stride:{:?}, is_contiguous:{}",
|
|
||||||
w.shape(),
|
|
||||||
w.stride(),
|
|
||||||
w.is_contiguous()
|
|
||||||
);
|
|
||||||
let x = x.matmul(&w)?;
|
let x = x.matmul(&w)?;
|
||||||
assert_eq!(x.dims(), &[1, 32]);
|
assert_eq!(x.dims(), &[1, 32]);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -107,13 +107,8 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
|
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
|
||||||
assert_eq!(
|
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
|
||||||
test_utils::to_vec2_round(&t_f16, 2)?,
|
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
|
||||||
[
|
|
||||||
[-0.0, 0.84, 4.0, -0.05, 0.35],
|
|
||||||
[2.69, -0.07, -0.11, 1.73, 2.79]
|
|
||||||
],
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||||
[
|
[
|
||||||
@ -1255,8 +1250,8 @@ fn pow() -> Result<()> {
|
|||||||
let rhs = (&lhs - 2.)?;
|
let rhs = (&lhs - 2.)?;
|
||||||
let res = lhs.pow(&rhs)?;
|
let res = lhs.pow(&rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&res, 4)?,
|
test_utils::to_vec2_round(&res, 3)?,
|
||||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user