From 109e95b189fc6a587ef7d2f901d194646f59b0c4 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Fri, 18 Aug 2023 10:45:37 +0200 Subject: [PATCH] Basic `qmatmul` parallelization (#492) * Basic `par_iter` parallelization * Pass errors up * Disable `avx` for x86 macs --- .cargo/config.toml | 8 ++++---- candle-core/src/quantized/k_quants.rs | 20 +++++++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 8ff190a4..ca9d853b 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,8 +1,8 @@ -[target.x86_64-unknown-linux-gnu] -rustflags = ["-C", "target-cpu=native"] - -[target.aarch64-apple-darwin] +[build] rustflags = ["-C", "target-cpu=native"] [target.wasm32-unknown-unknown] rustflags = ["-C", "target-feature=+simd128"] + +[target.x86_64-apple-darwin] +rustflags = ["-C", "target-feature=-avx,-avx2"] \ No newline at end of file diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index ee1a08a8..21001cfc 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1,6 +1,7 @@ use super::GgmlDType; use crate::Result; use half::f16; +use rayon::prelude::*; // Default to QK_K 256 rather than 64. 
pub const QK_K: usize = 256; @@ -13,7 +14,7 @@ pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; -pub trait GgmlType: Sized + Clone { +pub trait GgmlType: Sized + Clone + Send + Sync { const DTYPE: GgmlDType; const BLCK_SIZE: usize; type VecDotType: GgmlType; @@ -1030,10 +1031,19 @@ pub fn matmul( for row_idx in 0..m { let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; - for (col_idx, dst) in dst_row.iter_mut().enumerate() { - let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; - *dst = T::vec_dot(k, rhs_col, lhs_row)?; - } + + let result: Result<Vec<_>> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; } Ok(()) }