Optimizing decode matmul (Phi at 28tok/s on M3).

Adding some benchmarks to help check matmul performance.
This commit is contained in:
Nicolas Patry
2023-12-20 09:54:19 +01:00
parent 03641293ee
commit 9b5e4843a6
4 changed files with 66 additions and 5 deletions

View File

@@ -1297,11 +1297,21 @@ pub fn call_gemm(
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
let m_simd = 16;
let n_simd = 16;
let k_simd = 16;
let m_splits = 2;
let n_splits = 2;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
let m_simd = 16;
let n_simd = 8;
let k_simd = 64;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
} else {
let m_simd = 40;
let n_simd = 40;
let k_simd = 8;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
};
let constants = Some(ConstantValues::new(vec![
(0, Value::USize(m)),
(1, Value::USize(n)),