mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Basic qmatmul
parallelization (#492)
* Basic `par_iter` parallelization * Pass errors up * Disable `avx` for x86 macs
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
[target.x86_64-unknown-linux-gnu]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
[build]
|
||||
rustflags = ["-C", "target-cpu=native"]
|
||||
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = ["-C", "target-feature=+simd128"]
|
||||
|
||||
[target.x86_64-apple-darwin]
|
||||
rustflags = ["-C", "target-feature=-avx,-avx2"]
|
@ -1,6 +1,7 @@
|
||||
use super::GgmlDType;
|
||||
use crate::Result;
|
||||
use half::f16;
|
||||
use rayon::prelude::*;
|
||||
|
||||
// Default to QK_K 256 rather than 64.
|
||||
pub const QK_K: usize = 256;
|
||||
@ -13,7 +14,7 @@ pub const QK5_1: usize = 32;
|
||||
pub const QK8_0: usize = 32;
|
||||
pub const QK8_1: usize = 32;
|
||||
|
||||
pub trait GgmlType: Sized + Clone {
|
||||
pub trait GgmlType: Sized + Clone + Send + Sync {
|
||||
const DTYPE: GgmlDType;
|
||||
const BLCK_SIZE: usize;
|
||||
type VecDotType: GgmlType;
|
||||
@ -1030,10 +1031,19 @@ pub fn matmul<T: GgmlType>(
|
||||
for row_idx in 0..m {
|
||||
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
||||
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
|
||||
for (col_idx, dst) in dst_row.iter_mut().enumerate() {
|
||||
|
||||
let result: Result<Vec<_>> = dst_row
|
||||
.into_par_iter()
|
||||
.enumerate()
|
||||
.with_min_len(128)
|
||||
.with_max_len(512)
|
||||
.map(|(col_idx, dst)| {
|
||||
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
|
||||
*dst = T::vec_dot(k, rhs_col, lhs_row)?;
|
||||
}
|
||||
T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value)
|
||||
})
|
||||
.collect();
|
||||
|
||||
result?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user