mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add the kernels.
This commit is contained in:
@ -7,6 +7,7 @@ use half::{bf16, f16};
|
||||
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CpuStorage {
|
||||
U8(Vec<u8>),
|
||||
U32(Vec<u32>),
|
||||
BF16(Vec<bf16>),
|
||||
F16(Vec<f16>),
|
||||
@ -19,6 +20,7 @@ trait Map1 {
|
||||
|
||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||
match vs {
|
||||
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
|
||||
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
||||
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
||||
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
||||
@ -41,6 +43,7 @@ trait Map2 {
|
||||
l2: &Layout,
|
||||
) -> Result<CpuStorage> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||
@ -302,6 +305,7 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
|
||||
impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::U8(_) => DType::U8,
|
||||
Self::U32(_) => DType::U32,
|
||||
Self::BF16(_) => DType::BF16,
|
||||
Self::F16(_) => DType::F16,
|
||||
@ -417,6 +421,7 @@ impl CpuStorage {
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
_ => todo!("implement cast for {:?} {dtype:?}", self.dtype()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -449,7 +454,7 @@ impl CpuStorage {
|
||||
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||
Self::U32(_) => Ok(()),
|
||||
Self::U8(_) | Self::U32(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -475,6 +480,10 @@ impl CpuStorage {
|
||||
let data = unary_map(storage, layout, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
Self::U8(storage) => {
|
||||
let data = unary_map(storage, layout, B::u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
Self::U32(storage) => {
|
||||
let data = unary_map(storage, layout, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
@ -509,6 +518,10 @@ impl CpuStorage {
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U8(lhs), Self::U8(rhs)) => {
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8);
|
||||
Ok(Self::U8(data))
|
||||
}
|
||||
_ => {
|
||||
// This should be covered by the dtype check above.
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
@ -527,6 +540,7 @@ impl CpuStorage {
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
match (self, dst) {
|
||||
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
@ -582,6 +596,7 @@ impl CpuStorage {
|
||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U8 => Self::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => Self::U32(vec![1u32; elem_count]),
|
||||
DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
|
||||
@ -593,6 +608,7 @@ impl CpuStorage {
|
||||
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U8 => Self::U8(vec![0u8; elem_count]),
|
||||
DType::U32 => Self::U32(vec![0u32; elem_count]),
|
||||
DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
|
||||
DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),
|
||||
|
Reference in New Issue
Block a user