Add the kernels.

This commit is contained in:
laurent
2023-06-30 10:26:56 +01:00
parent a7b16cbb98
commit 8ad47907f3
11 changed files with 117 additions and 3 deletions

View File

@ -7,6 +7,7 @@ use half::{bf16, f16};
// intercept the OOM errors to avoid panicking and provide a proper error.
#[derive(Debug, Clone)]
pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>),
BF16(Vec<bf16>),
F16(Vec<f16>),
@ -19,6 +20,7 @@ trait Map1 {
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
@ -41,6 +43,7 @@ trait Map2 {
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
@ -302,6 +305,7 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
@ -417,6 +421,7 @@ impl CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data))
}
_ => todo!("implement cast for {:?} {dtype:?}", self.dtype()),
}
}
@ -449,7 +454,7 @@ impl CpuStorage {
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
Self::U32(_) => Ok(()),
Self::U8(_) | Self::U32(_) => Ok(()),
}
}
@ -475,6 +480,10 @@ impl CpuStorage {
let data = unary_map(storage, layout, B::f64);
Ok(Self::F64(data))
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
Ok(Self::U8(data))
}
Self::U32(storage) => {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
@ -509,6 +518,10 @@ impl CpuStorage {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
Ok(Self::U32(data))
}
(Self::U8(lhs), Self::U8(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8);
Ok(Self::U8(data))
}
_ => {
// This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp {
@ -527,6 +540,7 @@ impl CpuStorage {
src_l: &Layout,
) -> Result<()> {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@ -582,6 +596,7 @@ impl CpuStorage {
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
DType::U8 => Self::U8(vec![1u8; elem_count]),
DType::U32 => Self::U32(vec![1u32; elem_count]),
DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
@ -593,6 +608,7 @@ impl CpuStorage {
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
DType::U8 => Self::U8(vec![0u8; elem_count]),
DType::U32 => Self::U32(vec![0u32; elem_count]),
DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),