mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Fix for the RWKV models. (#1955)
* Fix for the RWKV models. * More general fix + revert the rwkv hack. * Remove the old hack.
This commit is contained in:
@ -1454,9 +1454,9 @@ pub fn call_gemm(
|
||||
// lhs has shape b, m, k
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
|
||||
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
false
|
||||
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
|
||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||
true
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
@ -1466,9 +1466,9 @@ pub fn call_gemm(
|
||||
})?;
|
||||
};
|
||||
// rhs has shape b, k, n
|
||||
let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
|
||||
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
false
|
||||
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
|
||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||
true
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
|
Reference in New Issue
Block a user