Files
candle/candle-core/src/cpu/simd128.rs
Laurent Mazare 495e0b7580 Simd support (#448)
* Import the simd intrinsics in candle-core.

* simd version of reduce-sum.

* Bugfix.

* Fix some clippy lints.
2023-08-15 09:50:38 +01:00

65 lines
1.5 KiB
Rust

use super::Cpu;
use core::arch::wasm32::*;
/// Zero-sized handle for the WebAssembly `simd128` backend.
///
/// The type carries no data; it exists only so that the `Cpu<ARR>` trait can be
/// implemented on it, and every operation is an associated function of that impl.
pub struct CurrentCpu {}
/// Value re-exported as `Cpu::STEP` by the trait impl in this file.
/// NOTE(review): presumably the number of `f32` elements consumed per unrolled
/// iteration by the generic kernels — confirm against the `Cpu` trait definition.
const STEP: usize = 16;
/// Value re-exported as `Cpu::EPR`; it matches the four `f32` lanes of one `v128`
/// (the impl uses the `f32x4_*` intrinsics throughout).
const EPR: usize = 4;
/// Length of the `[v128; ARR]` accumulator array: `STEP / EPR` = 16 / 4 = 4.
const ARR: usize = STEP / EPR;
/// `Cpu` implementation backed by the WebAssembly 128-bit SIMD intrinsics.
///
/// One `Unit` is a `v128` treated as four `f32` lanes, and an `Array` is `ARR`
/// such units.
impl Cpu<ARR> for CurrentCpu {
    type Unit = v128;
    type Array = [v128; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    /// Returns the number of `Unit`s held in one `Array`.
    fn n() -> usize {
        ARR
    }

    /// Returns a vector whose four `f32` lanes are all `0.0`.
    unsafe fn zero() -> Self::Unit {
        f32x4_splat(0.0)
    }

    /// Returns an array of `ARR` all-zero vectors.
    unsafe fn zero_array() -> Self::Array {
        let zeroed = Self::zero();
        [zeroed; ARR]
    }

    /// Broadcasts `v` into all four lanes of a vector.
    unsafe fn from_f32(v: f32) -> Self::Unit {
        f32x4_splat(v)
    }

    /// Loads one `v128` (four consecutive `f32`s) starting at `mem_addr`.
    ///
    /// # Safety
    ///
    /// `mem_addr` must be valid for reading 16 bytes.
    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        let src = mem_addr.cast::<v128>();
        v128_load(src)
    }

    /// Lane-wise `a + b`.
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        f32x4_add(a, b)
    }

    /// Lane-wise `b * c + a`, computed as a separate multiply followed by an add.
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        let product = f32x4_mul(b, c);
        f32x4_add(product, a)
    }

    /// Stores the four lanes of `a` to the 16 bytes starting at `mem_addr`.
    ///
    /// # Safety
    ///
    /// `mem_addr` must be valid for writing 16 bytes.
    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        let dst = mem_addr.cast::<v128>();
        v128_store(dst, a);
    }

    /// Sums every lane of every vector in `x` and writes the scalar total to `*y`.
    ///
    /// # Safety
    ///
    /// `y` must be valid for writing one `f32`.
    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        // Pairwise tree reduction: each pass folds the vector `half` slots away
        // into the first slot of every group of `2 * half` vectors. A pass whose
        // group width exceeds `ARR` runs zero iterations.
        for half in [1usize, 2, 4] {
            let width = 2 * half;
            for group in 0..ARR / width {
                let base = width * group;
                x[base] = f32x4_add(x[base], x[base + half]);
            }
        }

        // Horizontal sum of the surviving vector, accumulated lane 0 through lane 3.
        let total = x[0];
        let sum01 = f32x4_extract_lane::<0>(total) + f32x4_extract_lane::<1>(total);
        let sum012 = sum01 + f32x4_extract_lane::<2>(total);
        *y = sum012 + f32x4_extract_lane::<3>(total);
    }
}