mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Simd support (#448)
* Import the simd intrinsics in candle-core. * simd version of reduce-sum. * Bugfix. * Fix some clippy lints.
This commit is contained in:
148
candle-core/src/cpu/avx.rs
Normal file
148
candle-core/src/cpu/avx.rs
Normal file
@ -0,0 +1,148 @@
|
||||
use super::{Cpu, CpuF16};
|
||||
#[cfg(target_arch = "x86")]
|
||||
use core::arch::x86::*;
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use core::arch::x86_64::*;
|
||||
|
||||
use half::f16;
|
||||
|
||||
/// Field-less marker type that carries the AVX `f32` implementation of the `Cpu` trait.
pub struct CurrentCpu {}
|
||||
|
||||
// Forwarded as the `STEP` associated constant of both trait impls in this file.
// Presumably the number of elements consumed per unrolled iteration — confirm against callers.
const STEP: usize = 32;
|
||||
// Forwarded as the `EPR` associated constant; 8 matches the number of `f32` lanes in a `__m256`.
// The name presumably stands for "elements per register" — confirm.
const EPR: usize = 8;
|
||||
// Number of `__m256` registers in the `Array` associated type of both impls (STEP / EPR).
const ARR: usize = STEP / EPR;
|
||||
|
||||
impl Cpu<ARR> for CurrentCpu {
|
||||
type Unit = __m256;
|
||||
type Array = [__m256; ARR];
|
||||
|
||||
const STEP: usize = STEP;
|
||||
const EPR: usize = EPR;
|
||||
|
||||
fn n() -> usize {
|
||||
ARR
|
||||
}
|
||||
|
||||
unsafe fn zero() -> Self::Unit {
|
||||
_mm256_setzero_ps()
|
||||
}
|
||||
|
||||
unsafe fn zero_array() -> Self::Array {
|
||||
[Self::zero(); ARR]
|
||||
}
|
||||
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit {
|
||||
_mm256_set1_ps(v)
|
||||
}
|
||||
|
||||
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
|
||||
_mm256_loadu_ps(mem_addr)
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(a, b)
|
||||
}
|
||||
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(_mm256_mul_ps(b, c), a)
|
||||
}
|
||||
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
|
||||
_mm256_storeu_ps(mem_addr, a);
|
||||
}
|
||||
|
||||
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
|
||||
for i in 0..ARR / 2 {
|
||||
x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
|
||||
}
|
||||
for i in 0..ARR / 4 {
|
||||
x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
|
||||
}
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
for i in 0..ARR / 8 {
|
||||
x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
|
||||
}
|
||||
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
|
||||
let t1 = _mm_hadd_ps(t0, t0);
|
||||
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
|
||||
}
|
||||
}
|
||||
|
||||
/// Field-less marker type that carries the AVX `f16` implementation of the `CpuF16` trait.
pub struct CurrentCpuF16 {}
|
||||
impl CpuF16<ARR> for CurrentCpuF16 {
|
||||
type Unit = __m256;
|
||||
type Array = [__m256; ARR];
|
||||
|
||||
const STEP: usize = STEP;
|
||||
const EPR: usize = EPR;
|
||||
|
||||
fn n() -> usize {
|
||||
ARR
|
||||
}
|
||||
|
||||
unsafe fn zero() -> Self::Unit {
|
||||
_mm256_setzero_ps()
|
||||
}
|
||||
|
||||
unsafe fn zero_array() -> Self::Array {
|
||||
[Self::zero(); ARR]
|
||||
}
|
||||
|
||||
unsafe fn from_f32(v: f32) -> Self::Unit {
|
||||
_mm256_set1_ps(v)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "f16c")]
|
||||
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
|
||||
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "f16c"))]
|
||||
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
|
||||
let mut tmp = [0.0f32; 8];
|
||||
for i in 0..8 {
|
||||
tmp[i] = (*mem_addr.add(i)).to_f32();
|
||||
}
|
||||
_mm_loadu_ps(tmp.as_ptr())
|
||||
}
|
||||
|
||||
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(a, b)
|
||||
}
|
||||
|
||||
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
|
||||
_mm256_add_ps(_mm256_mul_ps(b, c), a)
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "f16c")]
|
||||
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
|
||||
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
|
||||
}
|
||||
|
||||
#[cfg(not(target_feature = "f16c"))]
|
||||
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
|
||||
let mut tmp = [0.0f32; 8];
|
||||
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
|
||||
for i in 0..8 {
|
||||
*mem_addr.add(i) = f16::from_f32(tmp[i]);
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
|
||||
let mut offset = ARR >> 1;
|
||||
for i in 0..offset {
|
||||
x[i] = _mm256_add_ps(x[i], x[offset + i]);
|
||||
}
|
||||
offset >>= 1;
|
||||
for i in 0..offset {
|
||||
x[i] = _mm256_add_ps(x[i], x[offset + i]);
|
||||
}
|
||||
offset >>= 1;
|
||||
for i in 0..offset {
|
||||
x[i] = _mm256_add_ps(x[i], x[offset + i]);
|
||||
}
|
||||
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
|
||||
let t1 = _mm_hadd_ps(t0, t0);
|
||||
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user