From 8ad47907f3a3d2eaff3850ebc3bf9f4c0cdbae31 Mon Sep 17 00:00:00 2001 From: laurent Date: Fri, 30 Jun 2023 10:26:56 +0100 Subject: [PATCH] Add the kernels. --- candle-core/src/cpu_backend.rs | 18 ++++++++++++- candle-core/src/cuda_backend.rs | 44 ++++++++++++++++++++++++++++++++ candle-core/src/display.rs | 7 +++++ candle-core/src/dtype.rs | 4 +++ candle-core/src/npy.rs | 14 ++++++++-- candle-core/src/op.rs | 14 ++++++++++ candle-kernels/src/affine.cu | 1 + candle-kernels/src/binary.cu | 4 +++ candle-kernels/src/cast.cu | 12 +++++++++ candle-kernels/src/embeddings.cu | 1 + candle-kernels/src/ternary.cu | 1 + 11 files changed, 117 insertions(+), 3 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 1425d92f..4105d0de 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -7,6 +7,7 @@ use half::{bf16, f16}; // intercept the oom errors to avoid panicking and provide a proper error. #[derive(Debug, Clone)] pub enum CpuStorage { + U8(Vec), U32(Vec), BF16(Vec), F16(Vec), @@ -19,6 +20,7 @@ trait Map1 { fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { match vs { + CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)), CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)), CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)), CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)), @@ -41,6 +43,7 @@ trait Map2 { l2: &Layout, ) -> Result { match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), @@ -302,6 +305,7 @@ fn divide_by_sum_over_dim(s: &mut [T], shape: &Shape, dim: usize) impl CpuStorage { pub fn dtype(&self) -> DType { match self { + Self::U8(_) => DType::U8, Self::U32(_) => DType::U32, Self::BF16(_) => DType::BF16, Self::F16(_) => DType::F16, @@ -417,6 +421,7 @@ impl CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } + _ => todo!("implement cast for {:?} {dtype:?}", self.dtype()), } } @@ -449,7 +454,7 @@ impl CpuStorage { Self::F16(s) => divide_by_sum_over_dim(s, shape, dim), Self::F32(s) => divide_by_sum_over_dim(s, shape, dim), Self::F64(s) => divide_by_sum_over_dim(s, shape, dim), - Self::U32(_) => Ok(()), + Self::U8(_) | Self::U32(_) => Ok(()), } } @@ -475,6 +480,10 @@ impl CpuStorage { let data = unary_map(storage, layout, B::f64); Ok(Self::F64(data)) } + Self::U8(storage) => { + let data = unary_map(storage, layout, B::u8); + Ok(Self::U8(data)) + } Self::U32(storage) => { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) @@ -509,6 +518,10 @@ impl CpuStorage { let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32); Ok(Self::U32(data)) } + (Self::U8(lhs), Self::U8(rhs)) => { + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8); + Ok(Self::U8(data)) + } _ => { // This should be covered by the dtype check above. Err(Error::DTypeMismatchBinaryOp { @@ -527,6 +540,7 @@ impl CpuStorage { src_l: &Layout, ) -> Result<()> { match (self, dst) { + (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -582,6 +596,7 @@ impl CpuStorage { pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { let elem_count = shape.elem_count(); match dtype { + DType::U8 => Self::U8(vec![1u8; elem_count]), DType::U32 => Self::U32(vec![1u32; elem_count]), DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]), DType::F16 => Self::F16(vec![f16::ONE; elem_count]), @@ -593,6 +608,7 @@ impl CpuStorage { pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self { let elem_count = shape.elem_count(); match dtype { + DType::U8 => Self::U8(vec![0u8; elem_count]), DType::U32 => Self::U32(vec![0u32; elem_count]), DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => Self::F16(vec![f16::ZERO; elem_count]), diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 40b7e67f..641efd7f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -105,6 +105,10 @@ impl CudaDevice { pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { let elem_count = shape.elem_count(); let slice = match dtype { + DType::U8 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::U8(data) + } DType::U32 => { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::U32(data) @@ -136,6 +140,14 @@ impl CudaDevice { let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let slice = match dtype { + DType::U8 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }?; + let func = self.get_or_load_func("fill_u8", kernels::FILL)?; + let params = (&data, v as u8, elem_count); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::U8(data) + } DType::U32 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }?; @@ -189,6 +201,10 @@ impl CudaDevice { pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result { let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::U8(data) + } CpuStorage::U32(storage) => { let data = self.htod_sync_copy(storage)?; CudaStorageSlice::U32(data) @@ -238,6 +254,7 @@ impl CudaDevice { #[derive(Debug)] enum CudaStorageSlice { + U8(CudaSlice), U32(CudaSlice), BF16(CudaSlice), F16(CudaSlice), @@ -256,6 +273,7 @@ trait Map1 { fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { let out = match s { + S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), @@ -278,6 +296,7 @@ trait Map2 { fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), @@ -596,6 +615,7 @@ impl CudaStorage { pub fn dtype(&self) -> DType { match self.slice { + CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, @@ -621,6 +641,7 @@ impl CudaStorage { // lifetime issue and is safe as long as self.slice does not go out of scope before inp // is used. let inp = match &self.slice { + CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), @@ -632,6 +653,12 @@ impl CudaStorage { let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; let slice = match dtype { + DType::U8 => { + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::U8(out) + } DType::U32 => { let out = unsafe { dev.alloc::(el) }?; let params = (el, dims.len(), &ds, *inp, &out); @@ -706,6 +733,11 @@ impl CudaStorage { pub(crate) fn to_cpu_storage(&self) -> Result { match &self.slice { + CudaStorageSlice::U8(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice)?; + Ok(CpuStorage::U8(cpu_storage)) + } CudaStorageSlice::U32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice)?; @@ -857,6 +889,18 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }? } } + (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 81ca3c98..60907bb3 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -43,6 +43,7 @@ impl Tensor { impl std::fmt::Debug for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self.dtype() { + DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), DType::F16 => self.fmt_dt::(f), @@ -415,6 +416,12 @@ impl std::fmt::Display for Tensor { self.clone() }; match self.dtype() { + DType::U8 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } DType::U32 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 9a51635d..e6785491 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -2,6 +2,7 @@ use crate::{CpuStorage, Error, Result}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + U8, U32, BF16, F16, @@ -12,6 +13,7 @@ pub enum DType { impl DType { pub fn as_str(&self) -> &'static str { match self { + Self::U8 => "u8", Self::U32 => "u32", Self::BF16 => "bf16", Self::F16 => "f16", @@ -22,6 +24,7 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { + Self::U8 => 4, Self::U32 => 4, Self::BF16 => 2, Self::F16 => 2, @@ -89,6 +92,7 @@ macro_rules! with_dtype { } use half::{bf16, f16}; +with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 7e157c8f..c0608519 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -86,6 +86,7 @@ impl Header { DType::F32 => "f4", DType::F64 => "f8", DType::U32 => "u4", + DType::U8 => "u1", }; if !shape.is_empty() { shape.push(',') @@ -162,9 +163,9 @@ impl Header { // "q" | "i8" => DType::S64, // "h" | "i2" => DType::S16, // "b" | "i1" => DType::S8, - // "B" | "u1" => DType::U8, + "B" | "u1" => DType::U8, "I" | "u4" => DType::U32, - // "?" | "b1" => DType::Pred, + "?" | "b1" => DType::U8, // "F" | "F4" => DType::C64, // "D" | "F8" => DType::C128, descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))), @@ -218,6 +219,11 @@ impl Tensor { reader.read_f64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::U8 => { + let mut data_t = vec![0u8; elem_count]; + reader.read_exact(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::U32 => { let mut data_t = vec![0u32; elem_count]; reader.read_u32_into::(&mut data_t)?; @@ -331,6 +337,10 @@ impl Tensor { f.write_u32::(v)? } } + DType::U8 => { + let data = self.reshape(elem_count)?.to_vec1::()?; + f.write_all(&data)?; + } } Ok(()) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index db6ef87f..860be0b3 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -49,6 +49,7 @@ pub(crate) trait UnaryOp { fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; + fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; } @@ -60,6 +61,7 @@ pub(crate) trait BinaryOp { fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; + fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; } @@ -96,6 +98,9 @@ macro_rules! bin_op { fn f64(v1: f64, v2: f64) -> f64 { $e(v1, v2) } + fn u8(v1: u8, v2: u8) -> u8 { + $e(v1, v2) + } fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) } @@ -126,6 +131,9 @@ macro_rules! unary_op { fn f64($a: f64) -> f64 { $e } + fn u8(_: u8) -> u8 { + todo!("no unary function for u8") + } fn u32(_: u32) -> u32 { todo!("no unary function for u32") } @@ -177,6 +185,9 @@ impl UnaryOp for Gelu { * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } + fn u8(_: u8) -> u8 { + 0 + } fn u32(_: u32) -> u32 { 0 } @@ -199,6 +210,9 @@ impl UnaryOp for Relu { fn f64(v: f64) -> f64 { v.max(0f64) } + fn u8(v: u8) -> u8 { + v + } fn u32(v: u32) -> u32 { v } diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index a52dd639..a02ce7a6 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -38,4 +38,5 @@ AFFINE_OP(__half, affine_f16) AFFINE_OP(float, affine_f32) AFFINE_OP(double, affine_f64) +AFFINE_OP(uint8_t, affine_u8) AFFINE_OP(uint32_t, affine_u32) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index 65f24db1..c99d96fd 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -17,13 +17,17 @@ BINARY_OP(__half, bsub_f16, x - y) BINARY_OP(float, badd_f32, x + y) BINARY_OP(double, badd_f64, x + y); +BINARY_OP(uint8_t, badd_u8, x + y); BINARY_OP(uint32_t, badd_u32, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); +BINARY_OP(uint8_t, bdiv_u8, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); +BINARY_OP(uint8_t, bmul_u8, x * y); BINARY_OP(uint32_t, bmul_u32, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); +BINARY_OP(uint8_t, bsub_u8, x - y); BINARY_OP(uint32_t, bsub_u32, x - y); diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index a3b10da2..42d04e80 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -27,10 +27,12 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16) +CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8) CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32) // CAST_OP(__nv_bfloat16, __half, cast_bf16_f16) CAST_OP(__nv_bfloat16, float, cast_bf16_f32) CAST_OP(__nv_bfloat16, double, cast_bf16_f64) +CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16) CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16) // CAST_OP(__half, __nv_bfloat16, cast_f16_bf16) CAST_OP(float, __nv_bfloat16, cast_f32_bf16) @@ -40,22 +42,32 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) #if __CUDA_ARCH__ >= 530 CAST_OP(__half, __half, cast_f16_f16) +// CAST_OP(__half, uint8_t, cast_f16_u8 ) CAST_OP(__half, uint32_t, cast_f16_u32) CAST_OP(__half, float, cast_f16_f32) CAST_OP(__half, double, cast_f16_f64) +CAST_OP(uint8_t, __half, cast_u8_f16 ) CAST_OP(uint32_t, __half, cast_u32_f16) CAST_OP(float, __half, cast_f32_f16) CAST_OP(double, __half, cast_f64_f16) #endif CAST_OP(uint32_t, uint32_t, cast_u32_u32) +CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) +CAST_OP(uint8_t, uint32_t, cast_u8_u32) +CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, float, cast_u8_f32) +CAST_OP(uint8_t, double, cast_u8_f64) + +CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) CAST_OP(float, float, cast_f32_f32) CAST_OP(float, double, cast_f32_f64) +CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) CAST_OP(double, float, cast_f64_f32) CAST_OP(double, double, cast_f64_f64) diff --git a/candle-kernels/src/embeddings.cu b/candle-kernels/src/embeddings.cu index 18fe5dfb..8425b16b 100644 --- a/candle-kernels/src/embeddings.cu +++ b/candle-kernels/src/embeddings.cu @@ -39,4 +39,5 @@ EMB_OP(__half, emb_f16) EMB_OP(float, emb_f32) EMB_OP(double, emb_f64) +EMB_OP(uint8_t, emb_u8) EMB_OP(uint32_t, emb_u32) diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index d08f9e10..c064f6e5 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -42,4 +42,5 @@ WHERE_OP(__half, where_f16) WHERE_OP(float, where_f32) WHERE_OP(double, where_f64) +WHERE_OP(uint8_t, where_u8) WHERE_OP(uint32_t, where_u32)