mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Fix the matmul layout for accelerate & mkl. (#2011)
* Fix the matmul layout for accelerate & mkl. * Reduce the required precision for pow (because of accelerate). * And a fix the gelu f16 test.
This commit is contained in:
@ -1330,7 +1330,7 @@ impl Map2 for MatMul {
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
|
||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
(n as i32, b'N')
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
(k as i32, b'T')
|
||||
@ -1338,7 +1338,7 @@ impl Map2 for MatMul {
|
||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||
};
|
||||
// The b tensor has dims batching, m, k (lhs)
|
||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
(k as i32, b'N')
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
(m as i32, b'T')
|
||||
@ -1421,7 +1421,7 @@ impl Map2 for MatMul {
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
|
||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
(n as i32, b'N')
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
(k as i32, b'T')
|
||||
@ -1429,7 +1429,7 @@ impl Map2 for MatMul {
|
||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||
};
|
||||
// The b tensor has dims batching, m, k (lhs)
|
||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
(k as i32, b'N')
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
(m as i32, b'T')
|
||||
|
Reference in New Issue
Block a user