mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Basic qmatmul parallelization (#492)
* Basic `par_iter` parallelization
* Pass errors up
* Disable `avx` for x86 macs
This commit is contained in:
@ -1,8 +1,8 @@
|
|||||||
[target.x86_64-unknown-linux-gnu]
|
[build]
|
||||||
rustflags = ["-C", "target-cpu=native"]
|
|
||||||
|
|
||||||
[target.aarch64-apple-darwin]
|
|
||||||
rustflags = ["-C", "target-cpu=native"]
|
rustflags = ["-C", "target-cpu=native"]
|
||||||
|
|
||||||
[target.wasm32-unknown-unknown]
|
[target.wasm32-unknown-unknown]
|
||||||
rustflags = ["-C", "target-feature=+simd128"]
|
rustflags = ["-C", "target-feature=+simd128"]
|
||||||
|
|
||||||
|
[target.x86_64-apple-darwin]
|
||||||
|
rustflags = ["-C", "target-feature=-avx,-avx2"]
|
@ -1,6 +1,7 @@
|
|||||||
use super::GgmlDType;
|
use super::GgmlDType;
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
// Default to QK_K 256 rather than 64.
|
// Default to QK_K 256 rather than 64.
|
||||||
pub const QK_K: usize = 256;
|
pub const QK_K: usize = 256;
|
||||||
@ -13,7 +14,7 @@ pub const QK5_1: usize = 32;
|
|||||||
pub const QK8_0: usize = 32;
|
pub const QK8_0: usize = 32;
|
||||||
pub const QK8_1: usize = 32;
|
pub const QK8_1: usize = 32;
|
||||||
|
|
||||||
pub trait GgmlType: Sized + Clone {
|
pub trait GgmlType: Sized + Clone + Send + Sync {
|
||||||
const DTYPE: GgmlDType;
|
const DTYPE: GgmlDType;
|
||||||
const BLCK_SIZE: usize;
|
const BLCK_SIZE: usize;
|
||||||
type VecDotType: GgmlType;
|
type VecDotType: GgmlType;
|
||||||
@ -1030,10 +1031,19 @@ pub fn matmul<T: GgmlType>(
|
|||||||
for row_idx in 0..m {
|
for row_idx in 0..m {
|
||||||
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
|
||||||
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
|
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
|
||||||
for (col_idx, dst) in dst_row.iter_mut().enumerate() {
|
|
||||||
|
let result: Result<Vec<_>> = dst_row
|
||||||
|
.into_par_iter()
|
||||||
|
.enumerate()
|
||||||
|
.with_min_len(128)
|
||||||
|
.with_max_len(512)
|
||||||
|
.map(|(col_idx, dst)| {
|
||||||
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
|
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
|
||||||
*dst = T::vec_dot(k, rhs_col, lhs_row)?;
|
T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value)
|
||||||
}
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
result?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user