// Mirror of https://github.com/huggingface/candle.git
use crate::backend::{BackendDevice, BackendStorage};
|
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
|
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
|
use half::{bf16, f16};
|
|
|
|
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
|
// intercept the oom errors to avoid panicking and provide a proper error.
|
|
#[derive(Debug, Clone)]
|
|
pub enum CpuStorage {
|
|
U8(Vec<u8>),
|
|
U32(Vec<u32>),
|
|
BF16(Vec<bf16>),
|
|
F16(Vec<f16>),
|
|
F32(Vec<f32>),
|
|
F64(Vec<f64>),
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct CpuDevice;
|
|
|
|
pub trait Map1 {
|
|
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
|
|
|
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
|
match vs {
|
|
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
|
|
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
|
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
|
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
|
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
|
|
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
|
|
}
|
|
}
|
|
}
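// Illustrative sketch (not in the original file): the `Map1` pattern in miniature,
// using a hypothetical two-variant enum and trait as stand-ins for
// `CpuStorage`/`Map1`. The element-wise logic is written once, generically over
// `T`, and `map` dispatches it across the concrete element types, re-wrapping the
// result in the matching variant.
#[allow(dead_code)]
mod map1_sketch {
    pub enum Storage {
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    pub trait MapElem {
        fn f<T: Copy + std::ops::Add<Output = T>>(&self, vs: &[T]) -> Vec<T>;

        fn map(&self, vs: &Storage) -> Storage {
            match vs {
                Storage::F32(vs) => Storage::F32(self.f(vs)),
                Storage::F64(vs) => Storage::F64(self.f(vs)),
            }
        }
    }

    /// Doubles every element, whichever float width the storage holds.
    pub struct Double;
    impl MapElem for Double {
        fn f<T: Copy + std::ops::Add<Output = T>>(&self, vs: &[T]) -> Vec<T> {
            vs.iter().map(|&v| v + v).collect()
        }
    }
}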
|
|
|
|
pub trait Map1Any {
|
|
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
|
&self,
|
|
vs: &[T],
|
|
layout: &Layout,
|
|
wrap: W,
|
|
) -> Result<CpuStorage>;
|
|
|
|
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
|
match vs {
|
|
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
|
|
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
|
|
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
|
|
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
|
|
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
|
|
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
|
|
}
|
|
}
|
|
}
|
|
|
|
type C = CpuStorage;
|
|
pub trait Map2 {
|
|
const OP: &'static str;
|
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
|
|
|
fn map(
|
|
&self,
|
|
v1: &CpuStorage,
|
|
l1: &Layout,
|
|
v2: &CpuStorage,
|
|
l2: &Layout,
|
|
) -> Result<CpuStorage> {
|
|
match (v1, v2) {
|
|
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
|
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
|
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
|
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
lhs: v1.dtype(),
|
|
rhs: v2.dtype(),
|
|
op: Self::OP,
|
|
}
|
|
.bt()),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub trait Map2U8 {
|
|
const OP: &'static str;
|
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
|
|
|
fn map(
|
|
&self,
|
|
v1: &CpuStorage,
|
|
l1: &Layout,
|
|
v2: &CpuStorage,
|
|
l2: &Layout,
|
|
) -> Result<CpuStorage> {
|
|
match (v1, v2) {
|
|
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
lhs: v1.dtype(),
|
|
rhs: v2.dtype(),
|
|
op: Self::OP,
|
|
}
|
|
.bt()),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct Cmp(CmpOp);
|
|
impl Map2U8 for Cmp {
|
|
const OP: &'static str = "cmp";
|
|
#[inline(always)]
|
|
fn f<T: WithDType>(
|
|
&self,
|
|
lhs: &[T],
|
|
lhs_l: &Layout,
|
|
rhs: &[T],
|
|
rhs_l: &Layout,
|
|
) -> Result<Vec<u8>> {
|
|
let dst = match self.0 {
|
|
CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
|
|
CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
|
|
CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
|
|
CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
|
|
CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
|
|
CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
|
|
};
|
|
Ok(dst)
|
|
}
|
|
}
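// Illustrative sketch (not in the original file): a hypothetical standalone helper
// showing what the `CmpOp::Le` arm above computes for contiguous, same-shape
// inputs. Comparisons always produce a 0/1 mask stored as `u8`, whatever the input
// dtype.
#[allow(dead_code)]
fn cmp_le_mask_sketch(lhs: &[f32], rhs: &[f32]) -> Vec<u8> {
    lhs.iter()
        .zip(rhs.iter())
        .map(|(&l, &r)| u8::from(l <= r))
        .collect()
}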
|
|
|
|
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
|
|
|
|
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
|
|
const OP: &'static str = "where";
|
|
#[inline(always)]
|
|
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
|
|
let vs = match (
|
|
self.1.contiguous_offsets(),
|
|
t_l.contiguous_offsets(),
|
|
f_l.contiguous_offsets(),
|
|
) {
|
|
(Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
|
|
let pred = &self.0[o1..o2];
|
|
let t = &t[o_t1..o_t2];
|
|
let f = &f[o_f1..o_f2];
|
|
pred.iter()
|
|
.zip(t.iter().zip(f.iter()))
|
|
.map(|(p, (&t, &f))| if p.is_true() { t } else { f })
|
|
.collect::<Vec<_>>()
|
|
}
|
|
_ => self
|
|
.1
|
|
.strided_index()
|
|
.zip(t_l.strided_index().zip(f_l.strided_index()))
|
|
.map(|(i_p, (i_t, i_f))| {
|
|
if self.0[i_p].is_true() {
|
|
t[i_t]
|
|
} else {
|
|
f[i_f]
|
|
}
|
|
})
|
|
.collect::<Vec<_>>(),
|
|
};
|
|
Ok(vs)
|
|
}
|
|
}
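// Illustrative sketch (not in the original file): a hypothetical helper spelling
// out the fast path of `WCond` above, i.e. a three-way zip once the predicate and
// both branches are contiguous.
#[allow(dead_code)]
fn where_cond_sketch(pred: &[u8], on_true: &[i64], on_false: &[i64]) -> Vec<i64> {
    pred.iter()
        .zip(on_true.iter().zip(on_false.iter()))
        .map(|(&p, (&t, &f))| if p != 0 { t } else { f })
        .collect()
}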
|
|
|
|
struct ReduceIndex {
|
|
reduce_dim_index: usize,
|
|
use_min: bool,
|
|
return_index: bool,
|
|
}
|
|
|
|
impl ReduceIndex {
|
|
// The value gets replaced if f(s[current_acc], s[i]) returns true.
|
|
#[inline(always)]
|
|
fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
|
|
where
|
|
T: Clone + Copy,
|
|
U: Clone + Copy,
|
|
F: Fn(T, T) -> bool,
|
|
G: Fn(T, usize) -> U,
|
|
{
|
|
let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
|
|
let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
|
|
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
|
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
|
let dst_to_set = dst.spare_capacity_mut();
|
|
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
|
|
match src_l.contiguous_offsets() {
|
|
Some((o1, o2)) => {
|
|
let src = &src[o1..o2];
|
|
if reduce_dim_stride == 1 {
|
|
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
|
|
let start_src_i = start_src_i * reduce_dim_size;
|
|
let src = &src[start_src_i..start_src_i + reduce_dim_size];
|
|
let mut acc = 0;
|
|
let mut val = src[0];
|
|
for (src_i, &s) in src.iter().enumerate() {
|
|
if f(val, s) {
|
|
acc = src_i;
|
|
val = s
|
|
}
|
|
}
|
|
*dst_v = g(val, acc)
|
|
}
|
|
} else {
|
|
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
|
|
let (p, q) = (
|
|
start_src_i / reduce_dim_stride,
|
|
start_src_i % reduce_dim_stride,
|
|
);
|
|
// start_src_i = p * reduce_dim_stride + q
|
|
let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
|
|
let src = &src[start_src_i..];
|
|
let mut acc = 0;
|
|
let mut val = src[0];
|
|
for src_i in 0..reduce_dim_size {
|
|
let s = src[src_i * reduce_dim_stride];
|
|
if f(val, s) {
|
|
acc = src_i;
|
|
val = s
|
|
}
|
|
}
|
|
*dst_v = g(val, acc)
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
|
|
for (unstr_index, src_index) in l.strided_index().enumerate() {
|
|
let src = &src[src_index..];
|
|
let mut acc = 0;
|
|
let mut val = src[0];
|
|
for src_i in 0..reduce_dim_size {
|
|
let s = src[src_i * reduce_dim_stride];
|
|
if f(val, s) {
|
|
acc = src_i;
|
|
val = s
|
|
}
|
|
}
|
|
dst_to_set[unstr_index] = g(val, acc)
|
|
}
|
|
}
|
|
}
|
|
unsafe { dst.set_len(dst_len) };
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
impl Map1Any for ReduceIndex {
|
|
#[inline(always)]
|
|
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
|
&self,
|
|
src: &[T],
|
|
src_l: &Layout,
|
|
wrap: W,
|
|
) -> Result<CpuStorage> {
|
|
if src_l.shape().elem_count() == 0 {
|
|
Err(Error::EmptyTensor { op: "reduce" }.bt())?
|
|
}
|
|
let dst = match (self.return_index, self.use_min) {
|
|
(false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
|
|
(false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
|
|
(true, true) => {
|
|
CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
|
|
}
|
|
(true, false) => {
|
|
CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
|
|
}
|
|
};
|
|
Ok(dst)
|
|
}
|
|
}
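// Illustrative sketch (not in the original file): a hypothetical helper computing
// what `ReduceIndex` with `return_index = true` and `use_min = false` produces on a
// contiguous `[rows, cols]` buffer reduced over its last dimension, i.e. argmax.
#[allow(dead_code)]
fn argmax_last_dim_sketch(src: &[f32], rows: usize, cols: usize) -> Vec<u32> {
    assert_eq!(src.len(), rows * cols);
    (0..rows)
        .map(|r| {
            let row = &src[r * cols..(r + 1) * cols];
            let mut best = 0usize;
            for (i, &v) in row.iter().enumerate() {
                // Mirrors `f(val, s) = val < s`: replace only on strictly larger
                // values, so ties keep the first index.
                if row[best] < v {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}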
|
|
|
|
struct Reduce<'a> {
|
|
dst_shape: &'a Shape,
|
|
reduce_dims: &'a [usize],
|
|
reduce_dims_and_stride: Vec<(usize, usize)>,
|
|
}
|
|
|
|
impl<'a> Reduce<'a> {
|
|
#[inline(always)]
|
|
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
|
|
where
|
|
T: Clone + Copy,
|
|
F: Fn(T, T) -> T,
|
|
{
|
|
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
|
|
match src_l.contiguous_offsets() {
|
|
Some((o1, o2)) => {
|
|
let src = &src[o1..o2];
|
|
                // Handle the case where we reduce over the last dimensions separately as it is
                // fairly common and easy to optimize. This relies on the layout being contiguous!
                // reduce_dims is sorted, check whether it ranges from a to n-1.
|
|
let reduce_over_last_dims = self
|
|
.reduce_dims
|
|
.iter()
|
|
.rev()
|
|
.enumerate()
|
|
.all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
|
|
if reduce_over_last_dims {
|
|
let reduce_sz = self
|
|
.reduce_dims_and_stride
|
|
.iter()
|
|
.map(|(u, _)| u)
|
|
.product::<usize>();
|
|
let mut src_i = 0;
|
|
for dst_v in dst.iter_mut() {
|
|
for &s in src[src_i..src_i + reduce_sz].iter() {
|
|
*dst_v = f(*dst_v, s)
|
|
}
|
|
src_i += reduce_sz
|
|
}
|
|
return Ok(dst);
|
|
};
|
|
for (unstr_index, &src) in src.iter().enumerate() {
|
|
let mut dst_index = unstr_index;
|
|
// Set the reduce_dims indexes to 0.
|
|
for &(dim, stride) in self.reduce_dims_and_stride.iter() {
|
|
                        // The compiler is able to optimize the following into a single divmod op.
|
|
let (pre, post) = (dst_index / stride, dst_index % stride);
|
|
dst_index = (pre / dim) * stride + post;
|
|
}
|
|
dst[dst_index] = f(dst[dst_index], src);
|
|
}
|
|
}
|
|
None => {
|
|
for (unstr_index, src_index) in src_l.strided_index().enumerate() {
|
|
let mut dst_index = unstr_index;
|
|
// Set the reduce_dims indexes to 0.
|
|
for &(dim, stride) in self.reduce_dims_and_stride.iter() {
|
|
                        // The compiler is able to optimize the following into a single divmod op.
|
|
let (pre, post) = (dst_index / stride, dst_index % stride);
|
|
dst_index = (pre / dim) * stride + post;
|
|
}
|
|
dst[dst_index] = f(dst[dst_index], src[src_index]);
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
impl<'a> Map1 for Reduce<'a> {
|
|
#[inline(always)]
|
|
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
|
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
|
|
}
|
|
}
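// Illustrative sketch (not in the original file): a hypothetical helper showing the
// index remapping used by `Reduce::fold_impl`. For a contiguous source, the
// destination index is the flat source index with the coordinate of the reduced
// dimension zeroed out: with `stride` the product of the dims to its right,
// `(pre, post) = (i / stride, i % stride)` splits the index and
// `(pre / dim) * stride + post` drops that coordinate.
#[allow(dead_code)]
fn sum_over_middle_dim_sketch(src: &[f32], d0: usize, d1: usize, d2: usize) -> Vec<f32> {
    // Sums a contiguous [d0, d1, d2] buffer over its middle dimension.
    assert_eq!(src.len(), d0 * d1 * d2);
    let (dim, stride) = (d1, d2);
    let mut dst = vec![0f32; d0 * d2];
    for (i, &v) in src.iter().enumerate() {
        let (pre, post) = (i / stride, i % stride);
        let dst_index = (pre / dim) * stride + post;
        dst[dst_index] += v;
    }
    dst
}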
|
|
|
|
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
|
vs: &[T],
|
|
layout: &Layout,
|
|
mut f: F,
|
|
) -> Vec<U> {
|
|
match layout.strided_blocks() {
|
|
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
|
[start_offset..start_offset + len]
|
|
.iter()
|
|
.map(|&v| f(v))
|
|
.collect(),
|
|
crate::StridedBlocks::MultipleBlocks {
|
|
block_start_index,
|
|
block_len,
|
|
} => {
|
|
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
|
// Specialize the case where block_len is one to avoid the second loop.
|
|
if block_len == 1 {
|
|
for index in block_start_index {
|
|
let v = unsafe { vs.get_unchecked(index) };
|
|
result.push(f(*v))
|
|
}
|
|
} else {
|
|
for index in block_start_index {
|
|
for offset in 0..block_len {
|
|
let v = unsafe { vs.get_unchecked(index + offset) };
|
|
result.push(f(*v))
|
|
}
|
|
}
|
|
}
|
|
result
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
|
vs: &[T],
|
|
layout: &Layout,
|
|
mut f: F,
|
|
mut f_vec: FV,
|
|
) -> Vec<U> {
|
|
match layout.strided_blocks() {
|
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
|
let mut ys: Vec<U> = Vec::with_capacity(len);
|
|
let ys_to_set = ys.spare_capacity_mut();
|
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
|
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
|
// SAFETY: values are all set by f_vec.
|
|
unsafe { ys.set_len(len) };
|
|
ys
|
|
}
|
|
crate::StridedBlocks::MultipleBlocks {
|
|
block_start_index,
|
|
block_len,
|
|
} => {
|
|
let el_count = layout.shape().elem_count();
|
|
// Specialize the case where block_len is one to avoid the second loop.
|
|
if block_len == 1 {
|
|
let mut result = Vec::with_capacity(el_count);
|
|
for index in block_start_index {
|
|
let v = unsafe { vs.get_unchecked(index) };
|
|
result.push(f(*v))
|
|
}
|
|
result
|
|
} else {
|
|
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
|
let ys_to_set = ys.spare_capacity_mut();
|
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
|
let mut dst_index = 0;
|
|
for src_index in block_start_index {
|
|
let vs = &vs[src_index..src_index + block_len];
|
|
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
|
f_vec(vs, ys);
|
|
dst_index += block_len;
|
|
}
|
|
// SAFETY: values are all set by f_vec.
|
|
unsafe { ys.set_len(el_count) };
|
|
ys
|
|
}
|
|
}
|
|
}
|
|
}
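// Illustrative sketch (not in the original file): the
// `with_capacity` / `spare_capacity_mut` / `set_len` pattern used by the vectorized
// maps above, written against `MaybeUninit` directly rather than through a
// `transmute`. Every element must be written before `set_len`; that is the
// invariant the `// SAFETY` comments in this file rely on.
#[allow(dead_code)]
fn fill_spare_capacity_sketch(n: usize) -> Vec<f32> {
    let mut ys: Vec<f32> = Vec::with_capacity(n);
    for (i, slot) in ys.spare_capacity_mut().iter_mut().enumerate() {
        slot.write(i as f32);
    }
    // SAFETY: all `n` elements were initialized by the loop above.
    unsafe { ys.set_len(n) };
    ys
}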
|
|
|
|
// This function maps over two strided index sequences.
|
|
fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
|
lhs_l: &Layout,
|
|
rhs_l: &Layout,
|
|
lhs: &[T],
|
|
rhs: &[T],
|
|
mut f: F,
|
|
) -> Vec<U> {
|
|
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
|
.iter()
|
|
.zip(rhs[o_r1..o_r2].iter())
|
|
.map(|(&l, &r)| f(l, r))
|
|
.collect(),
|
|
(Some((o_l1, o_l2)), None) => {
|
|
// TODO: Maybe we want to avoid going through the layout twice.
|
|
match rhs_l.offsets_b() {
|
|
Some(ob) => {
|
|
let mut i_in_block = 0;
|
|
let mut i_right_broadcast = 0;
|
|
lhs[o_l1..o_l2]
|
|
.iter()
|
|
.map(|&l| {
|
|
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
|
i_right_broadcast += 1;
|
|
if i_right_broadcast >= ob.right_broadcast {
|
|
i_in_block += 1;
|
|
i_right_broadcast = 0;
|
|
}
|
|
if i_in_block >= ob.len {
|
|
i_in_block = 0
|
|
}
|
|
f(l, *r)
|
|
})
|
|
.collect()
|
|
}
|
|
None => lhs_l
|
|
.strided_index()
|
|
.zip(rhs_l.strided_index())
|
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
.collect(),
|
|
}
|
|
}
|
|
(None, Some((o_r1, o_r2))) => {
|
|
// TODO: Maybe we want to avoid going through the layout twice.
|
|
match lhs_l.offsets_b() {
|
|
Some(ob) => {
|
|
let mut i_in_block = 0;
|
|
let mut i_right_broadcast = 0;
|
|
rhs[o_r1..o_r2]
|
|
.iter()
|
|
.map(|&r| {
|
|
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
|
i_right_broadcast += 1;
|
|
if i_right_broadcast >= ob.right_broadcast {
|
|
i_in_block += 1;
|
|
i_right_broadcast = 0;
|
|
}
|
|
if i_in_block >= ob.len {
|
|
i_in_block = 0
|
|
}
|
|
f(*l, r)
|
|
})
|
|
.collect()
|
|
}
|
|
None => lhs_l
|
|
.strided_index()
|
|
.zip(rhs_l.strided_index())
|
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
.collect(),
|
|
}
|
|
}
|
|
_ => lhs_l
|
|
.strided_index()
|
|
.zip(rhs_l.strided_index())
|
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
.collect(),
|
|
}
|
|
}
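// Illustrative sketch (not in the original file): a hypothetical helper
// materializing the broadcast walk that `binary_map` performs when one operand is a
// broadcast of a small contiguous block. Each of the `len` block elements is
// repeated `right_broadcast` times and the whole block is repeated
// `left_broadcast` times, which is exactly what the
// `i_in_block` / `i_right_broadcast` counters above track.
#[allow(dead_code)]
fn broadcast_walk_sketch(block: &[f32], left_broadcast: usize, right_broadcast: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(block.len() * left_broadcast * right_broadcast);
    for _ in 0..left_broadcast {
        for &v in block {
            for _ in 0..right_broadcast {
                out.push(v);
            }
        }
    }
    out
}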
|
|
|
|
// Similar to binary_map but with vectorized variants.
|
|
fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
|
lhs_l: &Layout,
|
|
rhs_l: &Layout,
|
|
lhs: &[T],
|
|
rhs: &[T],
|
|
mut f: F,
|
|
mut f_vec: FV,
|
|
) -> Vec<T> {
|
|
let el_count = lhs_l.shape().elem_count();
|
|
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
let ys_to_set = ys.spare_capacity_mut();
|
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
|
// SAFETY: values are all set by f_vec.
|
|
unsafe { ys.set_len(el_count) };
|
|
ys
|
|
}
|
|
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
|
Some(ob) if ob.right_broadcast == 1 => {
|
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
let ys_to_set = ys.spare_capacity_mut();
|
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
let mut dst_i = 0;
|
|
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
|
f_vec(
|
|
&lhs[src_i..src_i + ob.len],
|
|
rhs,
|
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
|
);
|
|
dst_i += ob.len;
|
|
}
|
|
// SAFETY: values are all set by f_vec.
|
|
unsafe { ys.set_len(el_count) };
|
|
ys
|
|
}
|
|
Some(ob) => {
|
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
|
let mut ys = lhs[o_l1..o_l2].to_vec();
|
|
for idx_l in 0..ob.left_broadcast {
|
|
let start = idx_l * ob.len * ob.right_broadcast;
|
|
for (i, &r) in rhs.iter().enumerate() {
|
|
let start = start + i * ob.right_broadcast;
|
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
|
*v = f(*v, r)
|
|
}
|
|
}
|
|
}
|
|
ys
|
|
}
|
|
None => lhs_l
|
|
.strided_index()
|
|
.zip(rhs_l.strided_index())
|
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
.collect(),
|
|
},
|
|
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
|
Some(ob) if ob.right_broadcast == 1 => {
|
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
let ys_to_set = ys.spare_capacity_mut();
|
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
let mut dst_i = 0;
|
|
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
|
f_vec(
|
|
lhs,
|
|
&rhs[src_i..src_i + ob.len],
|
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
|
);
|
|
dst_i += ob.len;
|
|
}
|
|
// SAFETY: values are all set by f_vec.
|
|
unsafe { ys.set_len(el_count) };
|
|
ys
|
|
}
|
|
Some(ob) => {
|
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
|
let mut ys = rhs[o_r1..o_r2].to_vec();
|
|
for idx_l in 0..ob.left_broadcast {
|
|
let start = idx_l * ob.len * ob.right_broadcast;
|
|
for (i, &l) in lhs.iter().enumerate() {
|
|
let start = start + i * ob.right_broadcast;
|
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
|
*v = f(l, *v)
|
|
}
|
|
}
|
|
}
|
|
ys
|
|
}
|
|
None => lhs_l
|
|
.strided_index()
|
|
.zip(rhs_l.strided_index())
|
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
.collect(),
|
|
},
|
|
_ => lhs_l
|
|
.strided_index()
|
|
.zip(rhs_l.strided_index())
|
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
.collect(),
|
|
}
|
|
}
|
|
|
|
struct Affine(f64, f64);
|
|
|
|
impl Map1 for Affine {
|
|
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
|
let mul = T::from_f64(self.0);
|
|
let add = T::from_f64(self.1);
|
|
Ok(unary_map(vs, layout, |v| v * mul + add))
|
|
}
|
|
}
|
|
|
|
struct AvgPool2D((usize, usize), (usize, usize));
|
|
|
|
impl Map1 for AvgPool2D {
|
|
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
|
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
|
|
let (k_h, k_w) = self.0;
|
|
let (s_h, s_w) = self.1;
|
|
let (b_sz, c, h, w) = layout.shape().dims4()?;
|
|
let stride = layout.stride();
|
|
let (stride_h, stride_w) = (stride[2], stride[3]);
|
|
let h_out = (h - k_h) / s_h + 1;
|
|
let w_out = (w - k_w) / s_w + 1;
|
|
let src_index = layout.start_offset();
|
|
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
|
|
let scale = 1f64 / (k_h * k_w) as f64;
|
|
let scale = T::from_f64(scale);
|
|
for b_idx in 0..b_sz {
|
|
let dst = &mut dst[b_idx * c * h_out * w_out..];
|
|
let src_index = src_index + b_idx * stride[0];
|
|
for c_idx in 0..c {
|
|
let dst = &mut dst[c_idx * h_out * w_out..];
|
|
let src_index = src_index + c_idx * stride[1];
|
|
for h_idx in 0..h_out {
|
|
for w_idx in 0..w_out {
|
|
let mut sum = T::zero();
|
|
for m in 0..k_h {
|
|
for n in 0..k_w {
|
|
let m = s_h * h_idx + m;
|
|
let n = s_w * w_idx + n;
|
|
sum += src[src_index + m * stride_h + n * stride_w]
|
|
}
|
|
}
|
|
dst[h_idx * w_out + w_idx] = sum * scale;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
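// Illustrative sketch (not in the original file): the same pooling arithmetic as
// above on a single contiguous `h x w` channel, making the
// `h_out = (h - k_h) / s_h + 1` output-size computation concrete. The helper name
// and its single-channel restriction are purely for illustration.
#[allow(dead_code)]
fn avg_pool2d_single_channel_sketch(
    src: &[f32],
    (h, w): (usize, usize),
    (k_h, k_w): (usize, usize),
    (s_h, s_w): (usize, usize),
) -> Vec<f32> {
    let (h_out, w_out) = ((h - k_h) / s_h + 1, (w - k_w) / s_w + 1);
    let scale = 1f32 / (k_h * k_w) as f32;
    let mut dst = vec![0f32; h_out * w_out];
    for h_idx in 0..h_out {
        for w_idx in 0..w_out {
            let mut sum = 0f32;
            for m in 0..k_h {
                for n in 0..k_w {
                    sum += src[(s_h * h_idx + m) * w + s_w * w_idx + n];
                }
            }
            dst[h_idx * w_out + w_idx] = sum * scale;
        }
    }
    dst
}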
|
|
|
|
struct MaxPool2D((usize, usize), (usize, usize));
|
|
|
|
impl Map1 for MaxPool2D {
|
|
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
|
|
let (k_h, k_w) = self.0;
|
|
let (s_h, s_w) = self.1;
|
|
let (b_sz, c, h, w) = layout.shape().dims4()?;
|
|
let stride = layout.stride();
|
|
let (stride_h, stride_w) = (stride[2], stride[3]);
|
|
let h_out = (h - k_h) / s_h + 1;
|
|
let w_out = (w - k_w) / s_w + 1;
|
|
let src_index = layout.start_offset();
|
|
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
|
|
for b_idx in 0..b_sz {
|
|
let dst = &mut dst[b_idx * c * h_out * w_out..];
|
|
let src_index = src_index + b_idx * stride[0];
|
|
for c_idx in 0..c {
|
|
let dst = &mut dst[c_idx * h_out * w_out..];
|
|
let src_index = src_index + c_idx * stride[1];
|
|
for h_idx in 0..h_out {
|
|
for w_idx in 0..w_out {
|
|
let mut largest =
|
|
src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
|
|
for m in 0..k_h {
|
|
for n in 0..k_w {
|
|
let m = s_h * h_idx + m;
|
|
let n = s_w * w_idx + n;
|
|
if largest < src[src_index + m * stride_h + n * stride_w] {
|
|
largest = src[src_index + m * stride_h + n * stride_w]
|
|
}
|
|
}
|
|
}
|
|
dst[h_idx * w_out + w_idx] = largest;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
struct UpsampleNearest2D(usize, usize);
|
|
|
|
impl Map1 for UpsampleNearest2D {
|
|
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
|
// TODO: Specialized implementation for the case 2*h, 2*w?
|
|
let (dst_h, dst_w) = (self.0, self.1);
|
|
let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
|
|
let stride = layout.stride();
|
|
let (stride_h, stride_w) = (stride[2], stride[3]);
|
|
let src_index = layout.start_offset();
|
|
let scale_h = src_h as f64 / dst_h as f64;
|
|
let scale_w = src_w as f64 / dst_w as f64;
|
|
let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
|
|
        let src_h_idxs = (0..dst_h)
|
|
.map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
|
|
.collect::<Vec<_>>();
|
|
        let src_w_idxs = (0..dst_w)
|
|
.map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
|
|
.collect::<Vec<_>>();
|
|
for b_idx in 0..b_sz {
|
|
let dst = &mut dst[b_idx * c * dst_h * dst_w..];
|
|
let src_index = src_index + b_idx * stride[0];
|
|
for c_idx in 0..c {
|
|
let dst = &mut dst[c_idx * dst_h * dst_w..];
|
|
let src_index = src_index + c_idx * stride[1];
|
|
for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
|
|
for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
|
|
let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
|
|
dst[h_idx * dst_w + w_idx] = src[src_index]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
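// Illustrative sketch (not in the original file): the nearest-neighbour index
// mapping used above, as a hypothetical standalone helper. Each destination
// coordinate is scaled back into the source by `src_len / dst_len` and clamped to
// the last valid source index.
#[allow(dead_code)]
fn nearest_src_indices_sketch(src_len: usize, dst_len: usize) -> Vec<usize> {
    let scale = src_len as f64 / dst_len as f64;
    (0..dst_len)
        .map(|dst_idx| usize::min(src_len - 1, (dst_idx as f64 * scale) as usize))
        .collect()
}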
|
|
|
|
struct Gather<'a, I: IntDType> {
|
|
ids: &'a [I],
|
|
ids_l: &'a Layout,
|
|
dim: usize,
|
|
}
|
|
|
|
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
|
|
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
|
let ids = match self.ids_l.contiguous_offsets() {
|
|
Some((a, b)) => &self.ids[a..b],
|
|
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
|
};
|
|
let src = match src_l.contiguous_offsets() {
|
|
Some((a, b)) => &src[a..b],
|
|
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
|
};
|
|
let dim = self.dim;
|
|
let ids_dims = self.ids_l.dims();
|
|
let src_dims = src_l.dims();
|
|
let dst_len: usize = ids_dims.iter().product();
|
|
let dst_left_len: usize = ids_dims[..dim].iter().product();
|
|
let dst_dim_len = ids_dims[dim];
|
|
let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
|
|
|
|
let src_dim_len = src_dims[dim];
|
|
let src_right_len: usize = src_dims[dim + 1..].iter().product();
|
|
|
|
let mut dst = vec![T::zero(); dst_len];
|
|
for left_i in 0..dst_left_len {
|
|
let start_src_idx = left_i * src_right_len * src_dim_len;
|
|
let start_dst_idx = left_i * dst_right_len * dst_dim_len;
|
|
for i in 0..dst_dim_len {
|
|
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
|
for right_i in 0..dst_right_len {
|
|
let dst_idx = start_dst_idx + right_i;
|
|
let index = ids[dst_idx].as_usize();
|
|
if index >= src_dim_len {
|
|
Err(Error::InvalidIndex {
|
|
index,
|
|
size: src_dim_len,
|
|
op: "gather",
|
|
}
|
|
.bt())?
|
|
}
|
|
let src_idx = start_src_idx + index * src_right_len + right_i;
|
|
dst[dst_idx] = src[src_idx]
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
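// Illustrative sketch (not in the original file): a hypothetical helper showing
// `gather` semantics for a contiguous 2d case with `dim = 1`, matching the index
// arithmetic above: `dst[i][j] = src[i][ids[i][j]]`, with `dst` taking the shape of
// `ids`.
#[allow(dead_code)]
fn gather_dim1_sketch(src: &[f32], src_cols: usize, ids: &[u32], ids_cols: usize) -> Vec<f32> {
    let rows = ids.len() / ids_cols;
    assert_eq!(src.len(), rows * src_cols);
    let mut dst = vec![0f32; ids.len()];
    for i in 0..rows {
        for j in 0..ids_cols {
            let index = ids[i * ids_cols + j] as usize;
            dst[i * ids_cols + j] = src[i * src_cols + index];
        }
    }
    dst
}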
|
|
|
|
struct IndexSelect<'a, T: IntDType> {
|
|
ids: &'a [T],
|
|
ids_l: &'a Layout,
|
|
dim: usize,
|
|
}
|
|
|
|
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
|
|
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
|
let src = match layout.contiguous_offsets() {
|
|
Some((a, b)) => &src[a..b],
|
|
None => Err(Error::RequiresContiguous { op: "index-select" })?,
|
|
};
|
|
let dim = self.dim;
|
|
let n_ids = match self.ids_l.dims() {
|
|
[n_ids] => *n_ids,
|
|
d => Err(Error::UnexpectedNumberOfDims {
|
|
expected: 1,
|
|
got: d.len(),
|
|
shape: self.ids_l.shape().clone(),
|
|
}
|
|
.bt())?,
|
|
};
|
|
let stride_ids = self.ids_l.stride()[0];
|
|
let mut dst_dims = layout.dims().to_vec();
|
|
let src_dim = dst_dims[dim];
|
|
dst_dims[dim] = n_ids;
|
|
let dst_len: usize = dst_dims.iter().product();
|
|
let left_len: usize = dst_dims[..dim].iter().product();
|
|
let right_len: usize = dst_dims[dim + 1..].iter().product();
|
|
let mut dst = vec![T::zero(); dst_len];
|
|
for left_i in 0..left_len {
|
|
let start_src_idx = left_i * right_len * src_dim;
|
|
let start_dst_idx = left_i * right_len * n_ids;
|
|
for i in 0..n_ids {
|
|
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
|
|
if index >= src_dim {
|
|
Err(Error::InvalidIndex {
|
|
index,
|
|
size: src_dim,
|
|
op: "index-select",
|
|
}
|
|
.bt())?
|
|
}
|
|
let start_src_idx = start_src_idx + index * right_len;
|
|
let start_dst_idx = start_dst_idx + i * right_len;
|
|
dst[start_dst_idx..start_dst_idx + right_len]
|
|
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
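// Illustrative sketch (not in the original file): `index_select` on the first
// dimension of a contiguous `[rows, cols]` buffer copies whole rows, which is the
// `copy_from_slice` fast path used above. The helper name is hypothetical.
#[allow(dead_code)]
fn index_select_dim0_sketch(src: &[f32], cols: usize, ids: &[u32]) -> Vec<f32> {
    let mut dst = Vec::with_capacity(ids.len() * cols);
    for &id in ids {
        let start = id as usize * cols;
        dst.extend_from_slice(&src[start..start + cols]);
    }
    dst
}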
|
|
|
|
struct ScatterAdd<'a, I: IntDType> {
|
|
ids: &'a [I],
|
|
ids_l: &'a Layout,
|
|
dim: usize,
|
|
}
|
|
|
|
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
|
const OP: &'static str = "scatter-add";
|
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
|
let dst_len = l1.shape().elem_count();
|
|
let mut dst = vec![T::zero(); dst_len];
|
|
copy_strided_src_(v1, &mut dst, 0, l1);
|
|
let src = match src_l.contiguous_offsets() {
|
|
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
|
|
Some((o1, o2)) => &src[o1..o2],
|
|
};
|
|
|
|
let dim = self.dim;
|
|
let ids_dims = self.ids_l.dims();
|
|
let dst_dims = l1.dims();
|
|
let dst_dim_len = dst_dims[dim];
|
|
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
|
|
|
let ids_left_len: usize = ids_dims[..dim].iter().product();
|
|
let ids_dim_len = ids_dims[dim];
|
|
let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
|
|
|
|
let ids = match self.ids_l.contiguous_offsets() {
|
|
Some((a, b)) => &self.ids[a..b],
|
|
            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
|
|
};
|
|
for left_i in 0..ids_left_len {
|
|
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
|
|
let start_dst_idx = left_i * dst_right_len * dst_dim_len;
|
|
for i in 0..ids_dim_len {
|
|
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
|
for right_i in 0..dst_right_len {
|
|
let ids_idx = start_ids_idx + right_i;
|
|
let index = ids[ids_idx].as_usize();
|
|
if index >= dst_dim_len {
|
|
Err(Error::InvalidIndex {
|
|
index,
|
|
size: dst_dim_len,
|
|
                            op: "scatter-add",
|
|
}
|
|
.bt())?
|
|
}
|
|
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
|
|
dst[dst_idx] += src[ids_idx]
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(dst)
|
|
}
|
|
}
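// Illustrative sketch (not in the original file): a hypothetical helper spelling
// out `scatter_add` semantics for a contiguous 2d case with `dim = 0`:
// `dst[ids[i][j]][j] += src[i][j]`, with `ids` and `src` sharing a shape.
#[allow(dead_code)]
fn scatter_add_dim0_sketch(dst: &mut [f32], dst_cols: usize, ids: &[u32], src: &[f32]) {
    assert_eq!(ids.len(), src.len());
    for (flat, (&id, &s)) in ids.iter().zip(src.iter()).enumerate() {
        let j = flat % dst_cols;
        dst[id as usize * dst_cols + j] += s;
    }
}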
|
|
|
|
struct IndexAdd<'a, I: IntDType> {
|
|
ids: &'a [I],
|
|
dim: usize,
|
|
}
|
|
|
|
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
|
const OP: &'static str = "index-add";
|
|
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
|
|
// v1, l1 -> self
|
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
|
let dst_len = l1.shape().elem_count();
|
|
let mut dst = vec![T::zero(); dst_len];
|
|
copy_strided_src_(v1, &mut dst, 0, l1);
|
|
let src = match src_l.contiguous_offsets() {
|
|
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
|
Some((o1, o2)) => &src[o1..o2],
|
|
};
|
|
let dim = self.dim;
|
|
let max_idx = l1.dims()[dim];
|
|
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
|
|
let src_dim_sz = src_l.dims()[dim];
|
|
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
|
if dim == 0 {
|
|
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
|
let dst_idx = dst_idx.as_usize();
|
|
if dst_idx >= max_idx {
|
|
Err(Error::InvalidIndex {
|
|
index: dst_idx,
|
|
op: "index-add",
|
|
size: max_idx,
|
|
})?
|
|
}
|
|
let src_idx = src_idx * post_dim;
|
|
let dst_idx = dst_idx * post_dim;
|
|
let src = &src[src_idx..src_idx + post_dim];
|
|
let dst = &mut dst[dst_idx..dst_idx + post_dim];
|
|
for (d, &s) in dst.iter_mut().zip(src.iter()) {
|
|
*d += s
|
|
}
|
|
}
|
|
} else {
|
|
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
|
let dst_idx = dst_idx.as_usize();
|
|
if dst_idx >= max_idx {
|
|
Err(Error::InvalidIndex {
|
|
index: dst_idx,
|
|
op: "index-add",
|
|
size: max_idx,
|
|
})?
|
|
}
|
|
for pre_i in 0..pre_dim {
|
|
let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
|
|
let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
|
|
let src = &src[pre_src_i..pre_src_i + post_dim];
|
|
let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
|
|
for (d, &s) in dst.iter_mut().zip(src.iter()) {
|
|
*d += s
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
|
match src_l.strided_blocks() {
|
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
|
let to_copy = (dst.len() - dst_offset).min(len);
|
|
dst[dst_offset..dst_offset + to_copy]
|
|
.copy_from_slice(&src[start_offset..start_offset + to_copy])
|
|
}
|
|
crate::StridedBlocks::MultipleBlocks {
|
|
block_start_index,
|
|
block_len: 1,
|
|
} => {
|
|
for (dst_index, src_index) in block_start_index.enumerate() {
|
|
let dst_index = dst_index + dst_offset;
|
|
if dst_index >= dst.len() {
|
|
break;
|
|
}
|
|
dst[dst_index] = src[src_index]
|
|
}
|
|
}
|
|
crate::StridedBlocks::MultipleBlocks {
|
|
block_start_index,
|
|
block_len,
|
|
} => {
|
|
let mut dst_index = dst_offset;
|
|
for src_index in block_start_index {
|
|
let next_dst_index = dst_index + block_len;
|
|
if dst_index >= dst.len() {
|
|
break;
|
|
}
|
|
let to_copy = usize::min(block_len, dst.len() - dst_index);
|
|
dst[dst_index..dst_index + to_copy]
|
|
.copy_from_slice(&src[src_index..src_index + to_copy]);
|
|
dst_index = next_dst_index
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
|
|
|
impl<'a> Map2 for Conv1D<'a> {
|
|
const OP: &'static str = "conv1d";
|
|
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
|
&self,
|
|
inp: &[T],
|
|
inp_l: &Layout,
|
|
k: &[T],
|
|
k_l: &Layout,
|
|
) -> Result<Vec<T>> {
|
|
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
|
|
let p = self.0;
|
|
let inp = &inp[inp_l.start_offset()..];
|
|
let k = &k[k_l.start_offset()..];
|
|
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
|
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
|
let l_out = p.l_out();
|
|
let dst_elems = p.c_out * l_out * p.b_size;
|
|
let mut dst = vec![T::zero(); dst_elems];
|
|
// The output shape is [b_size, c_out, l_out]
|
|
for b_idx in 0..p.b_size {
|
|
let inp_idx = b_idx * inp_s0;
|
|
let dst_idx = b_idx * p.c_out * l_out;
|
|
for dst_c_idx in 0..p.c_out {
|
|
let dst_idx = dst_idx + dst_c_idx * l_out;
|
|
for dst_l in 0..l_out {
|
|
let dst_idx = dst_idx + dst_l;
|
|
let mut d = T::zero();
|
|
for offset in 0..p.k_size {
|
|
let src_l = (p.stride * dst_l + offset)
|
|
.saturating_sub(p.padding)
|
|
.min(p.l_in - 1);
|
|
for src_c_idx in 0..p.c_in {
|
|
let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
|
|
let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
|
|
d += inp[inp_idx] * k[k_idx]
|
|
}
|
|
}
|
|
dst[dst_idx] = d
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
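// Illustrative sketch (not in the original file): the same direct (naive)
// convolution loop as above, reduced to a single batch, a single input/output
// channel, stride 1 and no padding, where the output length is
// `l_in - k_size + 1`. Like the code above this computes a cross-correlation, which
// is the usual deep-learning convention.
#[allow(dead_code)]
fn conv1d_single_channel_sketch(inp: &[f32], k: &[f32]) -> Vec<f32> {
    let l_out = inp.len() - k.len() + 1;
    (0..l_out)
        .map(|i| {
            k.iter()
                .enumerate()
                .map(|(offset, &kv)| inp[i + offset] * kv)
                .sum::<f32>()
        })
        .collect()
}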
|
|
|
|
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
|
|
|
impl<'a> Map2 for Conv2D<'a> {
|
|
const OP: &'static str = "conv2d";
|
|
fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>(
|
|
&self,
|
|
inp: &[T],
|
|
inp_l: &Layout,
|
|
k: &[T],
|
|
k_l: &Layout,
|
|
) -> Result<Vec<T>> {
|
|
let p = self.0;
|
|
let inp = &inp[inp_l.start_offset()..];
|
|
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
|
let k = &k[k_l.start_offset()..];
|
|
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
|
|
let (out_h, out_w) = (p.out_h(), p.out_w());
|
|
|
|
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
|
for b_idx in 0..p.b_size {
|
|
let inp_idx = b_idx * inp_s0;
|
|
let dst_idx = b_idx * p.c_out * out_h * out_w;
|
|
for dst_c_idx in 0..p.c_out {
|
|
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
|
|
for dst_h in 0..out_h {
|
|
let dst_idx = dst_idx + dst_h * out_w;
|
|
for dst_w in 0..out_w {
|
|
let dst_idx = dst_idx + dst_w;
|
|
let mut d = T::zero();
|
|
for offset_h in 0..p.k_h {
|
|
let src_h = (p.stride * dst_h + offset_h)
|
|
.saturating_sub(p.padding)
|
|
.min(p.i_h - 1);
|
|
for offset_w in 0..p.k_w {
|
|
let src_w = (p.stride * dst_w + offset_w)
|
|
.saturating_sub(p.padding)
|
|
.min(p.i_w - 1);
|
|
for src_c_idx in 0..p.c_in {
|
|
let inp_idx = inp_idx
|
|
+ src_c_idx * inp_s1
|
|
+ src_h * inp_s2
|
|
+ src_w * inp_s3;
|
|
let k_idx = dst_c_idx * k_s0
|
|
+ src_c_idx * k_s1
|
|
+ offset_h * k_s2
|
|
+ offset_w * k_s3;
|
|
d += inp[inp_idx] * k[k_idx]
|
|
}
|
|
}
|
|
}
|
|
dst[dst_idx] = d
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
struct MatMul((usize, usize, usize, usize));
|
|
|
|
impl MatMul {
|
|
fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
|
|
Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
|
|
lhs_l: lhs_l.clone(),
|
|
rhs_l: rhs_l.clone(),
|
|
bmnk: self.0,
|
|
msg,
|
|
}))
|
|
.bt()
|
|
}
|
|
}
|
|
|
|
impl Map2 for MatMul {
|
|
const OP: &'static str = "mat_mul";
|
|
|
|
#[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
|
|
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
|
&self,
|
|
lhs: &[T],
|
|
lhs_l: &Layout,
|
|
rhs: &[T],
|
|
rhs_l: &Layout,
|
|
) -> Result<Vec<T>> {
|
|
use gemm::{gemm, Parallelism};
|
|
let (b, m, n, k) = self.0;
|
|
let lhs = &lhs[lhs_l.start_offset()..];
|
|
let rhs = &rhs[rhs_l.start_offset()..];
|
|
|
|
let lhs_stride = lhs_l.stride();
|
|
let rhs_stride = rhs_l.stride();
|
|
let rank = lhs_stride.len();
|
|
let lhs_cs = lhs_stride[rank - 1];
|
|
let lhs_rs = lhs_stride[rank - 2];
|
|
|
|
let rhs_cs = rhs_stride[rank - 1];
|
|
let rhs_rs = rhs_stride[rank - 2];
|
|
|
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
[stride] => stride,
|
|
[] => m * k,
|
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
|
};
|
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
[stride] => stride,
|
|
[] => n * k,
|
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
|
};
|
|
let c_skip: usize = m * n;
|
|
|
|
let dst_shape: Shape = (m, n).into();
|
|
let dst_strides = dst_shape.stride_contiguous();
|
|
let dst_rs = dst_strides[0];
|
|
let dst_cs = dst_strides[1];
|
|
|
|
let mut dst = vec![T::zero(); b * m * n];
|
|
let num_threads = crate::utils::get_num_threads();
|
|
let parallelism = if num_threads > 1 {
|
|
Parallelism::Rayon(num_threads)
|
|
} else {
|
|
Parallelism::None
|
|
};
|
|
for step in 0..b {
|
|
let lhs_p = &lhs[step * a_skip..];
|
|
let rhs_p = &rhs[step * b_skip..];
|
|
let dst_p = &mut dst[step * c_skip..];
|
|
unsafe {
|
|
gemm(
|
|
/* m: usize = */ m,
|
|
/* n: usize = */ n,
|
|
/* k: usize = */ k,
|
|
/* dst: *mut T = */ dst_p.as_mut_ptr(),
|
|
/* dst_cs: isize = */ dst_cs as isize,
|
|
/* dst_rs: isize = */ dst_rs as isize,
|
|
/* read_dst: bool = */ false,
|
|
/* lhs: *const T = */ lhs_p.as_ptr(),
|
|
/* lhs_cs: isize = */ lhs_cs as isize,
|
|
/* lhs_rs: isize = */ lhs_rs as isize,
|
|
/* rhs: *const T = */ rhs_p.as_ptr(),
|
|
/* rhs_cs: isize = */ rhs_cs as isize,
|
|
/* rhs_rs: isize = */ rhs_rs as isize,
|
|
/* alpha: T = */ T::zero(),
|
|
/* beta: T = */ T::one(),
|
|
/* conj_dst: bool = */ false,
|
|
/* conj_lhs: bool = */ false,
|
|
/* conj_rhs: bool = */ false,
|
|
parallelism,
|
|
)
|
|
}
|
|
}
|
|
Ok(dst)
|
|
}
|
|
|
|
#[cfg(feature = "accelerate")]
|
|
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
|
&self,
|
|
lhs: &[T],
|
|
lhs_l: &Layout,
|
|
rhs: &[T],
|
|
rhs_l: &Layout,
|
|
) -> Result<Vec<T>> {
|
|
let (b, m, n, k) = self.0;
|
|
let lhs = &lhs[lhs_l.start_offset()..];
|
|
let rhs = &rhs[rhs_l.start_offset()..];
|
|
|
|
let lhs_stride = lhs_l.stride();
|
|
let rhs_stride = rhs_l.stride();
|
|
let rank = lhs_stride.len();
|
|
|
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
[stride] => stride,
|
|
[] => m * k,
|
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
|
};
|
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
[stride] => stride,
|
|
[] => n * k,
|
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
|
};
|
|
let c_skip: usize = m * n;
|
|
|
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
|
|
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
|
(n as i32, b'N')
|
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
|
(k as i32, b'T')
|
|
} else {
|
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
|
};
|
|
        // The BLAS `b` operand is the lhs tensor, with dims (batching, m, k).
|
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
|
(k as i32, b'N')
|
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
|
(m as i32, b'T')
|
|
} else {
|
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
|
|
};
|
|
|
|
let mut dst = vec![T::zero(); b * m * n];
|
|
match T::DTYPE {
|
|
DType::F16 => {
|
|
crate::bail!("the accelerate backend does not support f16 matmul")
|
|
}
|
|
DType::F32 => {
|
|
for step in 0..b {
|
|
let lhs_p = &lhs[step * a_skip..];
|
|
let rhs_p = &rhs[step * b_skip..];
|
|
let dst_p = &mut dst[step * c_skip..];
|
|
unsafe {
|
|
let a = rhs_p.as_ptr() as *const f32;
|
|
let b = lhs_p.as_ptr() as *const f32;
|
|
let c = dst_p.as_mut_ptr() as *mut f32;
|
|
let a = std::slice::from_raw_parts(a, a_skip);
|
|
let b = std::slice::from_raw_parts(b, b_skip);
|
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
|
crate::accelerate::sgemm(
|
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
|
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
DType::F64 => {
|
|
for step in 0..b {
|
|
let lhs_p = &lhs[step * a_skip..];
|
|
let rhs_p = &rhs[step * b_skip..];
|
|
let dst_p = &mut dst[step * c_skip..];
|
|
unsafe {
|
|
let a = rhs_p.as_ptr() as *const f64;
|
|
let b = lhs_p.as_ptr() as *const f64;
|
|
let c = dst_p.as_mut_ptr() as *mut f64;
|
|
let a = std::slice::from_raw_parts(a, a_skip);
|
|
let b = std::slice::from_raw_parts(b, b_skip);
|
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
|
crate::accelerate::dgemm(
|
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
|
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
|
|
}
|
|
Ok(dst)
|
|
}
|
|
|
|
#[cfg(feature = "mkl")]
|
|
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
|
&self,
|
|
lhs: &[T],
|
|
lhs_l: &Layout,
|
|
rhs: &[T],
|
|
rhs_l: &Layout,
|
|
) -> Result<Vec<T>> {
|
|
let (b, m, n, k) = self.0;
|
|
let lhs = &lhs[lhs_l.start_offset()..];
|
|
let rhs = &rhs[rhs_l.start_offset()..];
|
|
|
|
let lhs_stride = lhs_l.stride();
|
|
let rhs_stride = rhs_l.stride();
|
|
let rank = lhs_stride.len();
|
|
|
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
[stride] => stride,
|
|
[] => m * k,
|
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
|
};
|
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
[stride] => stride,
|
|
[] => n * k,
|
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
|
};
|
|
let c_skip: usize = m * n;
|
|
|
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
|
|
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
|
(n as i32, b'N')
|
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
|
(k as i32, b'T')
|
|
} else {
|
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
|
};
|
|
        // The BLAS `b` operand is the lhs tensor, with dims (batching, m, k).
|
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
|
(k as i32, b'N')
|
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
|
(m as i32, b'T')
|
|
} else {
|
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
|
|
};
|
|
|
|
let mut dst = vec![T::zero(); b * m * n];
|
|
match T::DTYPE {
|
|
DType::F16 => {
|
|
for step in 0..b {
|
|
let lhs_p = &lhs[step * a_skip..];
|
|
let rhs_p = &rhs[step * b_skip..];
|
|
let dst_p = &mut dst[step * c_skip..];
|
|
unsafe {
|
|
let a = rhs_p.as_ptr() as *const f16;
|
|
let b = lhs_p.as_ptr() as *const f16;
|
|
let c = dst_p.as_mut_ptr() as *mut f16;
|
|
let a = std::slice::from_raw_parts(a, a_skip);
|
|
let b = std::slice::from_raw_parts(b, b_skip);
|
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
|
crate::mkl::hgemm(
|
|
transa,
|
|
transb,
|
|
/* m= */ n as i32,
|
|
/* n= */ m as i32,
|
|
/* k= */ k as i32,
|
|
/* alpha= */ f16::ONE,
|
|
/* a= */ a,
|
|
/* lda= */ lda,
|
|
/* b= */ b,
|
|
/* ldb= */ ldb,
|
|
/* beta= */ f16::ZERO,
|
|
/* c= */ c,
|
|
/* ldc= */ n as i32,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
DType::F32 => {
|
|
for step in 0..b {
|
|
let lhs_p = &lhs[step * a_skip..];
|
|
let rhs_p = &rhs[step * b_skip..];
|
|
let dst_p = &mut dst[step * c_skip..];
|
|
unsafe {
|
|
let a = rhs_p.as_ptr() as *const f32;
|
|
let b = lhs_p.as_ptr() as *const f32;
|
|
let c = dst_p.as_mut_ptr() as *mut f32;
|
|
let a = std::slice::from_raw_parts(a, a_skip);
|
|
let b = std::slice::from_raw_parts(b, b_skip);
|
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
|
crate::mkl::sgemm(
|
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
|
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
DType::F64 => {
|
|
for step in 0..b {
|
|
let lhs_p = &lhs[step * a_skip..];
|
|
let rhs_p = &rhs[step * b_skip..];
|
|
let dst_p = &mut dst[step * c_skip..];
|
|
unsafe {
|
|
let a = rhs_p.as_ptr() as *const f64;
|
|
let b = lhs_p.as_ptr() as *const f64;
|
|
let c = dst_p.as_mut_ptr() as *mut f64;
|
|
let a = std::slice::from_raw_parts(a, a_skip);
|
|
let b = std::slice::from_raw_parts(b, b_skip);
|
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
|
crate::mkl::dgemm(
|
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
|
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
|
)
|
|
}
|
|
}
|
|
}
|
|
dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
|
|
}
|
|
Ok(dst)
|
|
}
|
|
}
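// Illustrative sketch (not in the original file): a hypothetical reference
// implementation of the computation the gemm/BLAS calls above perform, written
// naively for contiguous row-major buffers with the same `(b, m, n, k)` convention.
// The BLAS paths pass the rhs as `a` and the lhs as `b` with `m` and `n` swapped:
// for a column-major BLAS this evaluates `C^T = B^T * A^T`, i.e. the row-major
// `C = A * B` wanted here.
#[allow(dead_code)]
fn matmul_reference_sketch(
    lhs: &[f32],
    rhs: &[f32],
    (b, m, n, k): (usize, usize, usize, usize),
) -> Vec<f32> {
    let mut dst = vec![0f32; b * m * n];
    for step in 0..b {
        let (lhs, rhs) = (&lhs[step * m * k..], &rhs[step * k * n..]);
        let dst = &mut dst[step * m * n..];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for p in 0..k {
                    acc += lhs[i * k + p] * rhs[p * n + j];
                }
                dst[i * n + j] = acc;
            }
        }
    }
    dst
}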
|
|
|
|
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
|
|
if v.is_sign_positive() {
|
|
v
|
|
} else {
|
|
(v.exp() - T::one()) * alpha
|
|
}
|
|
}
|
|
|
|
impl CpuStorage {
|
|
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
|
|
D::cpu_storage_as_slice(self)
|
|
}
|
|
}
|
|
|
|
impl BackendStorage for CpuStorage {
|
|
type Device = CpuDevice;
|
|
|
|
fn dtype(&self) -> DType {
|
|
match self {
|
|
Self::U8(_) => DType::U8,
|
|
Self::U32(_) => DType::U32,
|
|
Self::BF16(_) => DType::BF16,
|
|
Self::F16(_) => DType::F16,
|
|
Self::F32(_) => DType::F32,
|
|
Self::F64(_) => DType::F64,
|
|
}
|
|
}
|
|
|
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
|
// TODO: find a way around the quadratic number of cases below.
|
|
match (self, dtype) {
|
|
(Self::U8(storage), DType::BF16) => {
|
|
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::U32(storage), DType::BF16) => {
|
|
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::BF16(storage), DType::BF16) => {
|
|
let data = unary_map(storage, layout, |v| v);
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::F16(storage), DType::BF16) => {
|
|
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::F32(storage), DType::BF16) => {
|
|
let data = unary_map(storage, layout, bf16::from_f32);
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::F64(storage), DType::BF16) => {
|
|
let data = unary_map(storage, layout, bf16::from_f64);
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::U8(storage), DType::F16) => {
|
|
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::U32(storage), DType::F16) => {
|
|
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::BF16(storage), DType::F16) => {
|
|
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::F16(storage), DType::F16) => {
|
|
let data = unary_map(storage, layout, |v| v);
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::F32(storage), DType::F16) => {
|
|
let data = unary_map(storage, layout, f16::from_f32);
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::F64(storage), DType::F16) => {
|
|
let data = unary_map(storage, layout, f16::from_f64);
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::U8(storage), DType::F32) => {
|
|
let data = unary_map(storage, layout, |v| v as f32);
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::U32(storage), DType::F32) => {
|
|
let data = unary_map(storage, layout, |v| v as f32);
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::BF16(storage), DType::F32) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f32());
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::F16(storage), DType::F32) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f32());
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::F32(storage), DType::F32) => {
|
|
let data = unary_map(storage, layout, |v| v);
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::F64(storage), DType::F32) => {
|
|
let data = unary_map(storage, layout, |v| v as f32);
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::U8(storage), DType::U8) => {
|
|
let data = unary_map(storage, layout, |v| v);
|
|
Ok(Self::U8(data))
|
|
}
|
|
(Self::BF16(storage), DType::U8) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
|
|
Ok(Self::U8(data))
|
|
}
|
|
(Self::F16(storage), DType::U8) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
|
|
Ok(Self::U8(data))
|
|
}
|
|
(Self::F32(storage), DType::U8) => {
|
|
let data = unary_map(storage, layout, |v| v as u8);
|
|
Ok(Self::U8(data))
|
|
}
|
|
(Self::F64(storage), DType::U8) => {
|
|
let data = unary_map(storage, layout, |v| v as u8);
|
|
Ok(Self::U8(data))
|
|
}
|
|
(Self::U8(storage), DType::U32) => {
|
|
let data = unary_map(storage, layout, |v| v as u32);
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::U32(storage), DType::U8) => {
|
|
let data = unary_map(storage, layout, |v| v as u8);
|
|
Ok(Self::U8(data))
|
|
}
|
|
(Self::U32(storage), DType::U32) => {
|
|
let data = unary_map(storage, layout, |v| v);
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::BF16(storage), DType::U32) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::F16(storage), DType::U32) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::F32(storage), DType::U32) => {
|
|
let data = unary_map(storage, layout, |v| v as u32);
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::F64(storage), DType::U32) => {
|
|
let data = unary_map(storage, layout, |v| v as u32);
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::U8(storage), DType::F64) => {
|
|
let data = unary_map(storage, layout, |v| v as f64);
|
|
Ok(Self::F64(data))
|
|
}
|
|
(Self::U32(storage), DType::F64) => {
|
|
let data = unary_map(storage, layout, |v| v as f64);
|
|
Ok(Self::F64(data))
|
|
}
|
|
(Self::BF16(storage), DType::F64) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f64());
|
|
Ok(Self::F64(data))
|
|
}
|
|
(Self::F16(storage), DType::F64) => {
|
|
let data = unary_map(storage, layout, |v| v.to_f64());
|
|
Ok(Self::F64(data))
|
|
}
|
|
(Self::F32(storage), DType::F64) => {
|
|
let data = unary_map(storage, layout, |v| v as f64);
|
|
Ok(Self::F64(data))
|
|
}
|
|
(Self::F64(storage), DType::F64) => {
|
|
let data = unary_map(storage, layout, |v| v);
|
|
Ok(Self::F64(data))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
|
|
match op {
|
|
ReduceOp::Sum => {
|
|
let src_dims = layout.dims();
|
|
let mut dst_dims = src_dims.to_vec();
|
|
for &dim in reduce_dims.iter() {
|
|
dst_dims[dim] = 1;
|
|
}
|
|
let dst_shape = Shape::from(dst_dims);
|
|
let mut reduce_dims = reduce_dims.to_vec();
|
|
// Sort the reduce_dims as they have to be processed from left to right when converting the
|
|
// indexes.
|
|
reduce_dims.sort();
|
|
let reduce_dims_and_stride: Vec<_> = reduce_dims
|
|
.iter()
|
|
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
|
|
.collect();
|
|
Reduce {
|
|
dst_shape: &dst_shape,
|
|
reduce_dims: &reduce_dims,
|
|
reduce_dims_and_stride,
|
|
}
|
|
.map(self, layout)
|
|
}
|
|
ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
|
|
let reduce_dim_index = match reduce_dims {
|
|
[reduce_dim_index] => *reduce_dim_index,
|
|
_ => {
|
|
let op = match op {
|
|
ReduceOp::Min => "min",
|
|
ReduceOp::ArgMin => "argmin",
|
|
ReduceOp::Max => "max",
|
|
ReduceOp::ArgMax => "argmax",
|
|
_ => unreachable!(),
|
|
};
|
|
let dims = reduce_dims.to_vec();
|
|
Err(Error::OnlySingleDimension { op, dims })?
|
|
}
|
|
};
|
|
let (use_min, return_index) = match op {
|
|
ReduceOp::Min => (true, false),
|
|
ReduceOp::ArgMin => (true, true),
|
|
ReduceOp::Max => (false, false),
|
|
ReduceOp::ArgMax => (false, true),
|
|
_ => unreachable!(),
|
|
};
|
|
ReduceIndex {
|
|
reduce_dim_index,
|
|
use_min,
|
|
return_index,
|
|
}
|
|
.map(self, layout)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
|
Cmp(op).map(self, lhs_l, rhs, rhs_l)
|
|
}
|
|
|
|
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
|
Affine(mul, add).map(self, layout)
|
|
}
|
|
|
|
fn avg_pool2d(
|
|
&self,
|
|
layout: &Layout,
|
|
kernel_size: (usize, usize),
|
|
stride: (usize, usize),
|
|
) -> Result<Self> {
|
|
AvgPool2D(kernel_size, stride).map(self, layout)
|
|
}
|
|
|
|
fn max_pool2d(
|
|
&self,
|
|
layout: &Layout,
|
|
kernel_size: (usize, usize),
|
|
stride: (usize, usize),
|
|
) -> Result<Self> {
|
|
MaxPool2D(kernel_size, stride).map(self, layout)
|
|
}
|
|
|
|
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
|
|
UpsampleNearest2D(h, w).map(self, layout)
|
|
}
|
|
|
|
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
|
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
|
match self {
|
|
Self::BF16(storage) => {
|
|
let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
|
|
Ok(Self::BF16(data))
|
|
}
|
|
Self::F16(storage) => {
|
|
let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
|
|
Ok(Self::F16(data))
|
|
}
|
|
Self::F32(storage) => {
|
|
let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
|
|
Ok(Self::F32(data))
|
|
}
|
|
Self::F64(storage) => {
|
|
let data = unary_map(storage, layout, |v| elu(v, alpha));
|
|
Ok(Self::F64(data))
|
|
}
|
|
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
|
|
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
|
|
}
|
|
}
|
|
|
|
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
|
match self {
|
|
Self::BF16(storage) => {
|
|
if B::BF16_VEC {
|
|
let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
|
|
Ok(Self::BF16(data))
|
|
} else {
|
|
let data = unary_map(storage, layout, B::bf16);
|
|
Ok(Self::BF16(data))
|
|
}
|
|
}
|
|
Self::F16(storage) => {
|
|
if B::F16_VEC {
|
|
let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
|
|
Ok(Self::F16(data))
|
|
} else {
|
|
let data = unary_map(storage, layout, B::f16);
|
|
Ok(Self::F16(data))
|
|
}
|
|
}
|
|
Self::F32(storage) => {
|
|
if B::F32_VEC {
|
|
let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
|
|
Ok(Self::F32(data))
|
|
} else {
|
|
let data = unary_map(storage, layout, B::f32);
|
|
Ok(Self::F32(data))
|
|
}
|
|
}
|
|
Self::F64(storage) => {
|
|
if B::F64_VEC {
|
|
let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
|
|
Ok(Self::F64(data))
|
|
} else {
|
|
let data = unary_map(storage, layout, B::f64);
|
|
Ok(Self::F64(data))
|
|
}
|
|
}
|
|
Self::U8(storage) => {
|
|
let data = unary_map(storage, layout, B::u8);
|
|
Ok(Self::U8(data))
|
|
}
|
|
Self::U32(storage) => {
|
|
let data = unary_map(storage, layout, B::u32);
|
|
Ok(Self::U32(data))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn binary_impl<B: BinaryOpT>(
|
|
&self,
|
|
rhs: &Self,
|
|
lhs_l: &Layout,
|
|
rhs_l: &Layout,
|
|
) -> Result<Self> {
|
|
match (self, rhs) {
|
|
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
|
let data = if B::BF16_VEC {
|
|
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
|
|
} else {
|
|
binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
|
|
};
|
|
Ok(Self::BF16(data))
|
|
}
|
|
(Self::F16(lhs), Self::F16(rhs)) => {
|
|
let data = if B::F16_VEC {
|
|
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
|
|
} else {
|
|
binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
|
|
};
|
|
Ok(Self::F16(data))
|
|
}
|
|
(Self::F32(lhs), Self::F32(rhs)) => {
|
|
let data = if B::F32_VEC {
|
|
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
|
|
} else {
|
|
binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
|
|
};
|
|
Ok(Self::F32(data))
|
|
}
|
|
(Self::F64(lhs), Self::F64(rhs)) => {
|
|
let data = if B::F64_VEC {
|
|
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
|
|
} else {
|
|
binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
|
|
};
|
|
Ok(Self::F64(data))
|
|
}
|
|
(Self::U32(lhs), Self::U32(rhs)) => {
|
|
let data = if B::U32_VEC {
|
|
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
|
|
} else {
|
|
binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
|
|
};
|
|
Ok(Self::U32(data))
|
|
}
|
|
(Self::U8(lhs), Self::U8(rhs)) => {
|
|
let data = if B::U8_VEC {
|
|
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
|
|
} else {
|
|
binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
|
|
};
|
|
Ok(Self::U8(data))
|
|
}
|
|
_ => {
|
|
// This should be covered by the dtype check above.
|
|
Err(Error::DTypeMismatchBinaryOp {
|
|
lhs: self.dtype(),
|
|
rhs: rhs.dtype(),
|
|
op: B::NAME,
|
|
}
|
|
.bt())
|
|
}
|
|
}
|
|
}
|
|
|
|
    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (_, dst) => {
                // This should be covered by the dtype check above.
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy_strided",
                }
                .bt());
            }
        }
        Ok(())
    }

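    /// Element-wise select between `t` and `f` using `self` as the condition mask; only u8 and
    /// u32 masks are supported.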
    fn where_cond(
        &self,
        layout: &Layout,
        t: &Self,
        t_l: &Layout,
        f: &Self,
        f_l: &Layout,
    ) -> Result<Self> {
        match self {
            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
        }
    }

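    /// 1d convolution over this storage, delegating to the `Conv1D` kernel configured by `params`.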
    fn conv1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        Conv1D(params).map(self, l, kernel, kernel_l)
    }

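    /// 2d convolution over this storage, delegating to the `Conv2D` kernel configured by `params`.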
    fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Conv2D(params).map(self, l, kernel, kernel_l)
    }

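    /// Select entries of `self` along `dim` using the integer indices in `ids` (u8 or u32).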
    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
        }
    }

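    /// Gather values of `self` along `dim` at the positions given by `ids` (u8 or u32).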
    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
        }
    }

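    /// Scatter-add: return a copy of `self` with the values of `src` accumulated along `dim` at
    /// the positions given by `ids` (u8 or u32).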
    fn scatter_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
        }
    }

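    /// Index-add: return a copy of `self` with slices of `src` accumulated along `dim` at the
    /// indices in `ids`; the index storage (u8 or u32) must be contiguous.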
    fn index_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
        }
    }

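    /// Batched matrix multiplication with `bmnk = (batch, m, n, k)`, delegating to the `MatMul`
    /// kernel.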
    fn matmul(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }

    fn device(&self) -> &Self::Device {
        &CpuDevice
    }

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }
}

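// The CPU device is a stateless unit struct: moving storage "onto" it is just a clone, and all
// CPU devices compare equal.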
impl BackendDevice for CpuDevice {
    type Storage = CpuStorage;

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cpu
    }

    fn same_device(&self, _: &Self) -> bool {
        true
    }

    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
        Ok(s.clone())
    }

    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }

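    /// Create a storage of `shape.elem_count()` samples drawn uniformly from `[min, max)`;
    /// integer dtypes are not supported.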
    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(uniform))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(uniform))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distributions::Uniform::new(min, max);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

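    /// Create a storage of `shape.elem_count()` samples computed as `sample * std + mean`, where
    /// `sample` is drawn from `rand::distributions::Standard`. Note that `Standard` produces
    /// floats uniformly in `[0, 1)`, so this yields a scaled uniform draw rather than a true
    /// Gaussian; something like `rand_distr::Normal` would be needed for genuinely normal
    /// samples.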
    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let std = bf16::from_f64(std);
                let mean = bf16::from_f64(mean);
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let std = f16::from_f64(std);
                let mean = f16::from_f64(mean);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let std = std as f32;
                let mean = mean as f32;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

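    /// Allocate a storage of the given shape filled with ones of the requested dtype.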
    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
        };
        Ok(storage)
    }

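    /// Allocate a storage of the given shape filled with zeros of the requested dtype.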
    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
        };
        Ok(storage)
    }
}

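/// Dispatch `$fn` over a `CpuStorage`, rebuilding a storage of the same dtype for each dtype in
/// the parenthesized list and propagating an `UnsupportedDTypeForOp` error (tagged with `$name`)
/// for any other dtype. A hypothetical use, assuming a dtype-generic `cos_vec(Vec<T>) -> Vec<T>`
/// helper: `map_dtype!("cos", storage, cos_vec, (BF16, F16, F32, F64))`.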
#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}