use super::{Cpu, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;

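/// AVX `f32` backend for the `Cpu` trait: one SIMD unit is a 256-bit `__m256`
/// register holding eight `f32` lanes.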
pub struct CurrentCpu {}

// STEP f32 elements are handled per outer iteration, EPR per `__m256`
// register, so ARR = STEP / EPR accumulator registers are kept in flight.
const STEP: usize = 32;
const EPR: usize = 8;
const ARR: usize = STEP / EPR;

impl Cpu<ARR> for CurrentCpu {
    type Unit = __m256;
    type Array = [__m256; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        _mm256_setzero_ps()
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        _mm256_set1_ps(v)
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        _mm256_loadu_ps(mem_addr)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        _mm256_add_ps(a, b)
    }

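    // Unfused multiply-add: computes a + b * c with separate _mm256_mul_ps and
    // _mm256_add_ps rather than a single FMA instruction.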
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        _mm256_add_ps(_mm256_mul_ps(b, c), a)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        _mm256_storeu_ps(mem_addr, a);
    }

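    // Reduce the ARR partial accumulators: pairwise-add them into x[0], then
    // horizontally sum the eight lanes of x[0] and write the scalar to *y.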
    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        for i in 0..ARR / 2 {
            x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
        }
        #[allow(clippy::reversed_empty_ranges)]
        for i in 0..ARR / 8 {
            x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
        }
        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
        let t1 = _mm_hadd_ps(t0, t0);
        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
    }
}

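/// AVX backend reading and writing `f16`: values are widened to `f32` lanes in
/// `__m256` registers for arithmetic, using F16C conversions when available.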
pub struct CurrentCpuF16 {}
impl CpuF16<ARR> for CurrentCpuF16 {
    type Unit = __m256;
    type Array = [__m256; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        _mm256_setzero_ps()
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        _mm256_set1_ps(v)
    }

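    // Eight f16 values are widened to f32 lanes on load: a single hardware
    // conversion with F16C, otherwise scalar conversion through a temporary
    // buffer.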
    #[cfg(target_feature = "f16c")]
    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
        _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
    }

    #[cfg(not(target_feature = "f16c"))]
    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
        let mut tmp = [0.0f32; 8];
        for i in 0..8 {
            tmp[i] = (*mem_addr.add(i)).to_f32();
        }
        _mm256_loadu_ps(tmp.as_ptr())
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        _mm256_add_ps(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        _mm256_add_ps(_mm256_mul_ps(b, c), a)
    }

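    // Mirror of `load`: the eight f32 lanes are narrowed back to f16 on store,
    // in one instruction under F16C and lane by lane otherwise.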
    #[cfg(target_feature = "f16c")]
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
        _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
    }

    #[cfg(not(target_feature = "f16c"))]
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
        let mut tmp = [0.0f32; 8];
        _mm256_storeu_ps(tmp.as_mut_ptr(), a);
        for i in 0..8 {
            *mem_addr.add(i) = f16::from_f32(tmp[i]);
        }
    }

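    // Tree reduction: repeatedly fold the upper half of the accumulator array
    // onto the lower half until the sum sits in x[0], then horizontally add
    // its eight lanes and write the scalar to *y.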
    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        let mut offset = ARR >> 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        offset >>= 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        offset >>= 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
        let t1 = _mm_hadd_ps(t0, t0);
        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
    }
}
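
// A minimal usage sketch, not part of the upstream file: how these primitives
// are typically combined into a dot-product kernel. `example_vec_dot_f32` is a
// hypothetical helper; it assumes `len` is a multiple of `CurrentCpu::STEP`
// and that both pointers are valid for `len` reads.
#[allow(dead_code)]
unsafe fn example_vec_dot_f32(a: *const f32, b: *const f32, len: usize) -> f32 {
    // Keep ARR independent accumulator registers to hide multiply-add latency.
    let mut sum = CurrentCpu::zero_array();
    for chunk in (0..len).step_by(CurrentCpu::STEP) {
        for (r, acc) in sum.iter_mut().enumerate() {
            let off = chunk + r * CurrentCpu::EPR;
            let va = CurrentCpu::load(a.add(off));
            let vb = CurrentCpu::load(b.add(off));
            // vec_fma(acc, va, vb) computes acc + va * vb.
            *acc = CurrentCpu::vec_fma(*acc, va, vb);
        }
    }
    // Collapse the ARR accumulators and their eight lanes into a scalar.
    let mut out = 0.0f32;
    CurrentCpu::vec_reduce(sum, &mut out);
    out
}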