//! Traits and methods for CPU-backed Tensors pub mod erf; pub mod kernels; #[allow(unused)] trait Cpu { type Unit; type Array; const STEP: usize; const EPR: usize; fn n() -> usize; unsafe fn zero() -> Self::Unit; unsafe fn zero_array() -> Self::Array; unsafe fn load(mem_addr: *const f32) -> Self::Unit; unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; unsafe fn vec_reduce(x: Self::Array, y: *mut f32); unsafe fn from_f32(v: f32) -> Self::Unit; unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit); } #[allow(unused)] trait CpuF16 { type Unit; type Array; const STEP: usize; const EPR: usize; fn n() -> usize; unsafe fn zero() -> Self::Unit; unsafe fn zero_array() -> Self::Array; unsafe fn load(mem_addr: *const f16) -> Self::Unit; unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; unsafe fn vec_reduce(x: Self::Array, y: *mut f32); unsafe fn from_f32(v: f32) -> Self::Unit; unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit); } use half::f16; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] pub mod avx; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(target_feature = "avx")] pub use avx::{CurrentCpu, CurrentCpuF16}; #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] pub mod simd128; #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] pub use simd128::CurrentCpu; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] #[cfg(target_feature = "neon")] pub mod neon; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] #[cfg(target_feature = "neon")] pub use neon::CurrentCpu; #[cfg(any( target_feature = "neon", target_feature = "avx", target_feature = "simd128" ))] #[inline(always)] pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) { let np = k & !(CurrentCpu::STEP - 1); let mut sum = CurrentCpu::zero_array(); let mut ax = CurrentCpu::zero_array(); let mut ay = CurrentCpu::zero_array(); for i in (0..np).step_by(CurrentCpu::STEP) { for j in 0..CurrentCpu::n() { ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR)); ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR)); sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]); } } CurrentCpu::vec_reduce(sum, c); // leftovers for i in np..k { *c += *a_row.add(i) * (*b_row.add(i)); } } #[cfg(not(any( target_feature = "neon", target_feature = "avx", target_feature = "simd128" )))] #[inline(always)] pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) { // leftovers for i in 0..k { *c += *a_row.add(i) * (*b_row.add(i)); } } #[cfg(any( target_feature = "neon", target_feature = "avx", target_feature = "simd128" ))] #[inline(always)] pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) { let np = k & !(CurrentCpu::STEP - 1); let mut sum = CurrentCpu::zero_array(); let mut x = CurrentCpu::zero_array(); for i in (0..np).step_by(CurrentCpu::STEP) { for j in 0..CurrentCpu::n() { x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR)); sum[j] = CurrentCpu::vec_add(sum[j], x[j]); } } CurrentCpu::vec_reduce(sum, b); // leftovers for i in np..k { *b += *row.add(i) } } #[cfg(not(any( target_feature = "neon", target_feature = "avx", target_feature = "simd128" )))] #[inline(always)] pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) { *b = 0f32; for i in 0..k { *b += *row.add(i) } } #[cfg(target_feature = "avx")] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { let mut sumf = 0.0f32; let np = k & !(CurrentCpuF16::STEP - 1); let mut sum = CurrentCpuF16::zero_array(); let mut ax = CurrentCpuF16::zero_array(); let mut ay = CurrentCpuF16::zero_array(); for i in (0..np).step_by(CurrentCpuF16::STEP) { for j in 0..CurrentCpuF16::n() { ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR)); ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR)); sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]); } } CurrentCpuF16::vec_reduce(sum, &mut sumf); // leftovers for i in np..k { sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); } *c = sumf; } #[cfg(not(target_feature = "avx"))] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { // leftovers let mut sum = 0.0; for i in 0..k { sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); } *c = sum; }