mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Simd support (#448)
* Import the simd intrinsics in candle-core. * simd version of reduce-sum. * Bugfix. * Fix some clippy lints.
This commit is contained in:
179
candle-core/src/cpu/mod.rs
Normal file
179
candle-core/src/cpu/mod.rs
Normal file
@ -0,0 +1,179 @@
|
||||
pub mod kernels;
|
||||
|
||||
trait Cpu<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
const STEP: usize;
|
||||
const EPR: usize;
|
||||
|
||||
fn n() -> usize;
|
||||
unsafe fn zero() -> Self::Unit;
|
||||
unsafe fn zero_array() -> Self::Array;
|
||||
unsafe fn load(mem_addr: *const f32) -> Self::Unit;
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit;
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
|
||||
}
|
||||
|
||||
trait CpuF16<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
const STEP: usize;
|
||||
const EPR: usize;
|
||||
|
||||
fn n() -> usize;
|
||||
unsafe fn zero() -> Self::Unit;
|
||||
unsafe fn zero_array() -> Self::Array;
|
||||
unsafe fn load(mem_addr: *const f16) -> Self::Unit;
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
|
||||
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit;
|
||||
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
|
||||
}
|
||||
use half::f16;
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub use avx::{CurrentCpu, CurrentCpuF16};
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub use simd128::CurrentCpu;
|
||||
|
||||
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub use neon::CurrentCpu;
|
||||
|
||||
#[cfg(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
|
||||
let np = k & !(CurrentCpu::STEP - 1);
|
||||
|
||||
let mut sum = CurrentCpu::zero_array();
|
||||
let mut ax = CurrentCpu::zero_array();
|
||||
let mut ay = CurrentCpu::zero_array();
|
||||
|
||||
for i in (0..np).step_by(CurrentCpu::STEP) {
|
||||
for j in 0..CurrentCpu::n() {
|
||||
ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
|
||||
ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));
|
||||
|
||||
sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
|
||||
}
|
||||
}
|
||||
|
||||
CurrentCpu::vec_reduce(sum, c);
|
||||
|
||||
// leftovers
|
||||
for i in np..k {
|
||||
*c += *a_row.add(i) * (*b_row.add(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
)))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
|
||||
// leftovers
|
||||
for i in 0..k {
|
||||
*c += *a_row.add(i) * (*b_row.add(i));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
|
||||
let np = k & !(CurrentCpu::STEP - 1);
|
||||
|
||||
let mut sum = CurrentCpu::zero_array();
|
||||
let mut x = CurrentCpu::zero_array();
|
||||
|
||||
for i in (0..np).step_by(CurrentCpu::STEP) {
|
||||
for j in 0..CurrentCpu::n() {
|
||||
x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
|
||||
sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
|
||||
}
|
||||
}
|
||||
|
||||
CurrentCpu::vec_reduce(sum, b);
|
||||
|
||||
// leftovers
|
||||
for i in np..k {
|
||||
*b += *row.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
target_feature = "neon",
|
||||
target_feature = "avx",
|
||||
target_feature = "simd128"
|
||||
)))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
|
||||
*b = 0f32;
|
||||
for i in 0..k {
|
||||
*b += *row.add(i)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
|
||||
let mut sumf = 0.0f32;
|
||||
let np = k & !(CurrentCpuF16::STEP - 1);
|
||||
|
||||
let mut sum = CurrentCpuF16::zero_array();
|
||||
let mut ax = CurrentCpuF16::zero_array();
|
||||
let mut ay = CurrentCpuF16::zero_array();
|
||||
|
||||
for i in (0..np).step_by(CurrentCpuF16::STEP) {
|
||||
for j in 0..CurrentCpuF16::n() {
|
||||
ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
|
||||
ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));
|
||||
|
||||
sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
|
||||
}
|
||||
}
|
||||
|
||||
CurrentCpuF16::vec_reduce(sum, &mut sumf);
|
||||
|
||||
// leftovers
|
||||
for i in np..k {
|
||||
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
|
||||
}
|
||||
*c = sumf;
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "avx"))]
|
||||
#[inline(always)]
|
||||
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
|
||||
// leftovers
|
||||
let mut sum = 0.0;
|
||||
for i in 0..k {
|
||||
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
|
||||
}
|
||||
*c = sum;
|
||||
}
|
Reference in New Issue
Block a user