Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)
Simd support (#448)

* Import the simd intrinsics in candle-core.
* Simd version of reduce-sum.
* Bugfix.
* Fix some clippy lints.
candle-core/src/cpu/avx.rs (new file, 148 lines)
@@ -0,0 +1,148 @@
use super::{Cpu, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;

pub struct CurrentCpu {}

// One step processes STEP f32 values: EPR lanes per 256-bit register, ARR registers per step.
const STEP: usize = 32;
const EPR: usize = 8;
const ARR: usize = STEP / EPR;

impl Cpu<ARR> for CurrentCpu {
    type Unit = __m256;
    type Array = [__m256; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        _mm256_setzero_ps()
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        _mm256_set1_ps(v)
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        _mm256_loadu_ps(mem_addr)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        _mm256_add_ps(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        _mm256_add_ps(_mm256_mul_ps(b, c), a)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        _mm256_storeu_ps(mem_addr, a);
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        // Pairwise tree reduction of the ARR accumulators into x[0]...
        for i in 0..ARR / 2 {
            x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
        }
        #[allow(clippy::reversed_empty_ranges)]
        for i in 0..ARR / 8 {
            x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
        }
        // ...then a horizontal add over the 8 lanes of x[0].
        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
        let t1 = _mm_hadd_ps(t0, t0);
        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
    }
}

pub struct CurrentCpuF16 {}
impl CpuF16<ARR> for CurrentCpuF16 {
    type Unit = __m256;
    type Array = [__m256; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        _mm256_setzero_ps()
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        _mm256_set1_ps(v)
    }

    #[cfg(target_feature = "f16c")]
    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
        _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
    }

    #[cfg(not(target_feature = "f16c"))]
    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
        // Without F16C, convert through a scalar f32 buffer, then do a 256-bit load
        // (the full-width load is required to produce a __m256).
        let mut tmp = [0.0f32; 8];
        for i in 0..8 {
            tmp[i] = (*mem_addr.add(i)).to_f32();
        }
        _mm256_loadu_ps(tmp.as_ptr())
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        _mm256_add_ps(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        _mm256_add_ps(_mm256_mul_ps(b, c), a)
    }

    #[cfg(target_feature = "f16c")]
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
        _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
    }

    #[cfg(not(target_feature = "f16c"))]
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
        let mut tmp = [0.0f32; 8];
        _mm256_storeu_ps(tmp.as_mut_ptr(), a);
        for i in 0..8 {
            *mem_addr.add(i) = f16::from_f32(tmp[i]);
        }
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        // Same tree reduction as above, written with a halving offset.
        let mut offset = ARR >> 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        offset >>= 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        offset >>= 1;
        for i in 0..offset {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
        let t1 = _mm_hadd_ps(t0, t0);
        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
    }
}
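Both `vec_reduce` implementations above compute the same thing: a pairwise tree reduction of the ARR = 4 accumulator registers into x[0], followed by a horizontal add of its eight lanes. A minimal scalar model of that order of operations (illustrative only, not part of the diff; `acc` and `reduce_model` stand in for the four __m256 values and the intrinsic version):

// Scalar model of `vec_reduce` for ARR = 4, EPR = 8.
fn reduce_model(mut acc: [[f32; 8]; 4]) -> f32 {
    // First round: x[0] += x[1], x[2] += x[3] (lane-wise).
    for i in 0..2 {
        for l in 0..8 {
            acc[2 * i][l] += acc[2 * i + 1][l];
        }
    }
    // Second round: x[0] += x[2] (lane-wise); the ARR / 8 round is empty.
    for l in 0..8 {
        acc[0][l] += acc[2][l];
    }
    // Horizontal add, done with extractf128/hadd in the intrinsic version.
    acc[0].iter().sum()
}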
candle-core/src/cpu/kernels.rs (new file, 89 lines)
@@ -0,0 +1,89 @@
pub trait VecDot: num_traits::NumAssign + Copy {
    /// Dot-product of two vectors.
    ///
    /// # Safety
    ///
    /// The lengths of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        *res = Self::zero();
        for i in 0..len {
            *res += *lhs.add(i) * *rhs.add(i)
        }
    }

    /// Sum of all elements in a vector.
    ///
    /// # Safety
    ///
    /// The length of `xs` must be at least `len`. `res` has to point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
        *res = Self::zero();
        for i in 0..len {
            *res += *xs.add(i)
        }
    }
}

impl VecDot for f32 {
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        super::vec_dot_f32(lhs, rhs, res, len)
    }

    #[inline(always)]
    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
        super::vec_sum(xs, res, len)
    }
}

impl VecDot for half::f16 {
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        let mut res_f32 = 0f32;
        super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
        *res = half::f16::from_f32(res_f32);
    }
}

// These types use the default scalar implementations.
impl VecDot for f64 {}
impl VecDot for half::bf16 {}
impl VecDot for u8 {}
impl VecDot for u32 {}

/// Runs `func` once per thread index on `n_threads` rayon-scoped threads.
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    if n_threads == 1 {
        func(0)
    } else {
        rayon::scope(|s| {
            for thread_idx in 0..n_threads {
                let func = &func;
                s.spawn(move |_| func(thread_idx));
            }
        })
    }
}

/// Applies `func` to every index in `lo..up`, distributing indices round-robin
/// across `n_threads` threads.
#[inline(always)]
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    if n_threads == 1 {
        for i in lo..up {
            func(i)
        }
    } else {
        rayon::scope(|s| {
            for thread_idx in 0..n_threads {
                let func = &func;
                s.spawn(move |_| {
                    // Thread t handles lo + t, lo + t + n_threads, ... so the whole
                    // range lo..up is covered, matching the single-threaded branch.
                    for i in (lo + thread_idx..up).step_by(n_threads) {
                        func(i)
                    }
                });
            }
        })
    }
}
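As a usage sketch, here is how a hypothetical crate-internal caller (not part of the diff; `dot_f32` is an illustrative name, and the path assumes the file sits inside candle-core where `crate::cpu::kernels` is reachable) might uphold the documented safety contract on `VecDot`:

use crate::cpu::kernels::VecDot;

// Safe wrapper around the raw-pointer dot-product kernel.
fn dot_f32(lhs: &[f32], rhs: &[f32]) -> f32 {
    assert_eq!(lhs.len(), rhs.len());
    let mut res = 0f32;
    // Safety: both pointers are valid for `len` reads and `res` points to a
    // valid element, as the trait documentation requires.
    unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut res, lhs.len()) };
    res
}

Note the scheduling choice in `par_range`: indices go round-robin, so thread t sees lo + t, lo + t + n_threads, and so on; chunking the range contiguously would be the usual alternative when cache locality across neighboring indices matters.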
candle-core/src/cpu/mod.rs (new file, 179 lines)
@@ -0,0 +1,179 @@
pub mod kernels;

use half::f16;

trait Cpu<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const f32) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}

trait CpuF16<const ARR: usize> {
    type Unit;
    type Array;
    const STEP: usize;
    const EPR: usize;

    fn n() -> usize;
    unsafe fn zero() -> Self::Unit;
    unsafe fn zero_array() -> Self::Array;
    unsafe fn load(mem_addr: *const f16) -> Self::Unit;
    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
    unsafe fn from_f32(v: f32) -> Self::Unit;
    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;

#[cfg(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // np is k rounded down to a multiple of STEP (STEP is a power of two).
    let np = k & !(CurrentCpu::STEP - 1);

    let mut sum = CurrentCpu::zero_array();
    let mut ax = CurrentCpu::zero_array();
    let mut ay = CurrentCpu::zero_array();

    for i in (0..np).step_by(CurrentCpu::STEP) {
        for j in 0..CurrentCpu::n() {
            ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
            ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));

            sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    CurrentCpu::vec_reduce(sum, c);

    // leftovers
    for i in np..k {
        *c += *a_row.add(i) * (*b_row.add(i));
    }
}

#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
    // Scalar fallback. Zero the accumulator first, matching the SIMD path,
    // which overwrites *c via vec_reduce.
    *c = 0f32;
    for i in 0..k {
        *c += *a_row.add(i) * (*b_row.add(i));
    }
}

#[cfg(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    let np = k & !(CurrentCpu::STEP - 1);

    let mut sum = CurrentCpu::zero_array();
    let mut x = CurrentCpu::zero_array();

    for i in (0..np).step_by(CurrentCpu::STEP) {
        for j in 0..CurrentCpu::n() {
            x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
            sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
        }
    }

    CurrentCpu::vec_reduce(sum, b);

    // leftovers
    for i in np..k {
        *b += *row.add(i)
    }
}

#[cfg(not(any(
    target_feature = "neon",
    target_feature = "avx",
    target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
    *b = 0f32;
    for i in 0..k {
        *b += *row.add(i)
    }
}

#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    let mut sumf = 0.0f32;
    let np = k & !(CurrentCpuF16::STEP - 1);

    let mut sum = CurrentCpuF16::zero_array();
    let mut ax = CurrentCpuF16::zero_array();
    let mut ay = CurrentCpuF16::zero_array();

    for i in (0..np).step_by(CurrentCpuF16::STEP) {
        for j in 0..CurrentCpuF16::n() {
            ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
            ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));

            sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
        }
    }

    CurrentCpuF16::vec_reduce(sum, &mut sumf);

    // leftovers
    for i in np..k {
        sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sumf;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
    // Scalar fallback.
    let mut sum = 0.0;
    for i in 0..k {
        sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
    }
    *c = sum;
}
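All three SIMD kernels split the range the same way: np = k & !(STEP - 1) rounds k down to a multiple of STEP (valid because STEP is a power of two), the vectorized loop covers 0..np, and a scalar tail handles np..k. A small worked example of that arithmetic (illustrative only):

// The bit trick behind the SIMD/scalar split.
fn main() {
    const STEP: usize = 32; // the AVX value; NEON and simd128 use 16
    let k = 100usize;
    let np = k & !(STEP - 1); // clears the low five bits: 100 & !31
    assert_eq!(np, 96); // 3 full SIMD steps of 32 elements each
    assert_eq!(k - np, 4); // 4 leftover elements for the scalar tail
}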
candle-core/src/cpu/neon.rs (new file, 77 lines)
@@ -0,0 +1,77 @@
use super::Cpu;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

pub struct CurrentCpu {}

const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;

impl CurrentCpu {
    #[cfg(target_arch = "aarch64")]
    unsafe fn reduce_one(x: float32x4_t) -> f32 {
        vaddvq_f32(x)
    }

    #[cfg(target_arch = "arm")]
    unsafe fn reduce_one(x: float32x4_t) -> f32 {
        vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
    }
}

impl Cpu<ARR> for CurrentCpu {
    type Unit = float32x4_t;
    type Array = [float32x4_t; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        vdupq_n_f32(0.0)
    }

    unsafe fn from_f32(x: f32) -> Self::Unit {
        vdupq_n_f32(x)
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        vld1q_f32(mem_addr)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        vaddq_f32(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        vfmaq_f32(a, b, c)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        vst1q_f32(mem_addr, a);
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        for i in 0..ARR / 2 {
            x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
        }
        for i in 0..ARR / 8 {
            x[8 * i] = vaddq_f32(x[8 * i], x[8 * i + 4]);
        }
        *y = Self::reduce_one(x[0]);
    }
}
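A crate-internal check of the NEON path might look like the following sketch (illustrative, not part of the diff; it assumes it lives inside candle-core, where the private `Cpu` trait is visible, and a NEON-enabled aarch64 target):

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[test]
fn neon_vec_reduce_matches_scalar() {
    use crate::cpu::{neon::CurrentCpu, Cpu};
    let xs: [f32; 16] = core::array::from_fn(|i| i as f32);
    unsafe {
        let mut acc = CurrentCpu::zero_array();
        for j in 0..CurrentCpu::n() {
            // Load lanes j*EPR..(j+1)*EPR into accumulator j.
            acc[j] = CurrentCpu::load(xs.as_ptr().add(j * 4));
        }
        let mut sum = 0f32;
        CurrentCpu::vec_reduce(acc, &mut sum);
        // Exact equality is fine here: small integers sum exactly in f32.
        assert_eq!(sum, xs.iter().sum::<f32>());
    }
}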
candle-core/src/cpu/simd128.rs (new file, 64 lines)
@@ -0,0 +1,64 @@
use super::Cpu;
use core::arch::wasm32::*;

pub struct CurrentCpu {}

const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;

impl Cpu<ARR> for CurrentCpu {
    type Unit = v128;
    type Array = [v128; ARR];

    const STEP: usize = STEP;
    const EPR: usize = EPR;

    fn n() -> usize {
        ARR
    }

    unsafe fn zero() -> Self::Unit {
        f32x4_splat(0.0)
    }

    unsafe fn zero_array() -> Self::Array {
        [Self::zero(); ARR]
    }

    unsafe fn from_f32(v: f32) -> Self::Unit {
        f32x4_splat(v)
    }

    unsafe fn load(mem_addr: *const f32) -> Self::Unit {
        v128_load(mem_addr as *const v128)
    }

    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
        f32x4_add(a, b)
    }

    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
        f32x4_add(f32x4_mul(b, c), a)
    }

    unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
        v128_store(mem_addr as *mut v128, a);
    }

    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
        for i in 0..ARR / 2 {
            x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
        }
        for i in 0..ARR / 4 {
            x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
        }
        for i in 0..ARR / 8 {
            x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
        }
        *y = f32x4_extract_lane::<0>(x[0])
            + f32x4_extract_lane::<1>(x[0])
            + f32x4_extract_lane::<2>(x[0])
            + f32x4_extract_lane::<3>(x[0]);
    }
}