Add the kernels.

This commit is contained in:
laurent
2023-06-30 10:26:56 +01:00
parent a7b16cbb98
commit 8ad47907f3
11 changed files with 117 additions and 3 deletions

View File

@ -7,6 +7,7 @@ use half::{bf16, f16};
// intercept the oom errors to avoid panicking and provide a proper error. // intercept the oom errors to avoid panicking and provide a proper error.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum CpuStorage { pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>), U32(Vec<u32>),
BF16(Vec<bf16>), BF16(Vec<bf16>),
F16(Vec<f16>), F16(Vec<f16>),
@ -19,6 +20,7 @@ trait Map1 {
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> { fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs { match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)), CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)), CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)), CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
@ -41,6 +43,7 @@ trait Map2 {
l2: &Layout, l2: &Layout,
) -> Result<CpuStorage> { ) -> Result<CpuStorage> {
match (v1, v2) { match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
@ -302,6 +305,7 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
impl CpuStorage { impl CpuStorage {
pub fn dtype(&self) -> DType { pub fn dtype(&self) -> DType {
match self { match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32, Self::U32(_) => DType::U32,
Self::BF16(_) => DType::BF16, Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16, Self::F16(_) => DType::F16,
@ -417,6 +421,7 @@ impl CpuStorage {
let data = unary_map(storage, layout, |v| v); let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data)) Ok(Self::F64(data))
} }
_ => todo!("implement cast for {:?} {dtype:?}", self.dtype()),
} }
} }
@ -449,7 +454,7 @@ impl CpuStorage {
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim), Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim), Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim), Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
Self::U32(_) => Ok(()), Self::U8(_) | Self::U32(_) => Ok(()),
} }
} }
@ -475,6 +480,10 @@ impl CpuStorage {
let data = unary_map(storage, layout, B::f64); let data = unary_map(storage, layout, B::f64);
Ok(Self::F64(data)) Ok(Self::F64(data))
} }
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
Ok(Self::U8(data))
}
Self::U32(storage) => { Self::U32(storage) => {
let data = unary_map(storage, layout, B::u32); let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data)) Ok(Self::U32(data))
@ -509,6 +518,10 @@ impl CpuStorage {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32); let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
Ok(Self::U32(data)) Ok(Self::U32(data))
} }
(Self::U8(lhs), Self::U8(rhs)) => {
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u8);
Ok(Self::U8(data))
}
_ => { _ => {
// This should be covered by the dtype check above. // This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp { Err(Error::DTypeMismatchBinaryOp {
@ -527,6 +540,7 @@ impl CpuStorage {
src_l: &Layout, src_l: &Layout,
) -> Result<()> { ) -> Result<()> {
match (self, dst) { match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@ -582,6 +596,7 @@ impl CpuStorage {
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
match dtype { match dtype {
DType::U8 => Self::U8(vec![1u8; elem_count]),
DType::U32 => Self::U32(vec![1u32; elem_count]), DType::U32 => Self::U32(vec![1u32; elem_count]),
DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]), DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => Self::F16(vec![f16::ONE; elem_count]), DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
@ -593,6 +608,7 @@ impl CpuStorage {
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self { pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
match dtype { match dtype {
DType::U8 => Self::U8(vec![0u8; elem_count]),
DType::U32 => Self::U32(vec![0u32; elem_count]), DType::U32 => Self::U32(vec![0u32; elem_count]),
DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]), DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => Self::F16(vec![f16::ZERO; elem_count]), DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),

View File

@ -105,6 +105,10 @@ impl CudaDevice {
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> { pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
let slice = match dtype { let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => { DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count)?; let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data) CudaStorageSlice::U32(data)
@ -136,6 +140,14 @@ impl CudaDevice {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32); let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype { let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U8(data)
}
DType::U32 => { DType::U32 => {
// SAFETY: Set later by running the fill kernel. // SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }?; let data = unsafe { self.alloc::<u32>(elem_count) }?;
@ -189,6 +201,10 @@ impl CudaDevice {
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> { pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage { let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => { CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage)?; let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::U32(data) CudaStorageSlice::U32(data)
@ -238,6 +254,7 @@ impl CudaDevice {
#[derive(Debug)] #[derive(Debug)]
enum CudaStorageSlice { enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>), U32(CudaSlice<u32>),
BF16(CudaSlice<bf16>), BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>), F16(CudaSlice<f16>),
@ -256,6 +273,7 @@ trait Map1 {
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> { fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s { let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?),
@ -278,6 +296,7 @@ trait Map2 {
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> { fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) { let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
@ -596,6 +615,7 @@ impl CudaStorage {
pub fn dtype(&self) -> DType { pub fn dtype(&self) -> DType {
match self.slice { match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32, CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F16(_) => DType::F16,
@ -621,6 +641,7 @@ impl CudaStorage {
// lifetime issue and is safe as long as self.slice does not go out of scope before inp // lifetime issue and is safe as long as self.slice does not go out of scope before inp
// is used. // is used.
let inp = match &self.slice { let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
@ -632,6 +653,12 @@ impl CudaStorage {
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
let slice = match dtype { let slice = match dtype {
DType::U8 => {
let out = unsafe { dev.alloc::<u8>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U8(out)
}
DType::U32 => { DType::U32 => {
let out = unsafe { dev.alloc::<u32>(el) }?; let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out); let params = (el, dims.len(), &ds, *inp, &out);
@ -706,6 +733,11 @@ impl CudaStorage {
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> { pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
match &self.slice { match &self.slice {
CudaStorageSlice::U8(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::U8(cpu_storage))
}
CudaStorageSlice::U32(slice) => { CudaStorageSlice::U32(slice) => {
let dev = slice.device(); let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?; let cpu_storage = dev.dtoh_sync_copy(slice)?;
@ -857,6 +889,18 @@ impl CudaStorage {
unsafe { func.launch(cfg, params) }? unsafe { func.launch(cfg, params) }?
} }
} }
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() { if src_l.is_contiguous() {

View File

@ -43,6 +43,7 @@ impl Tensor {
impl std::fmt::Debug for Tensor { impl std::fmt::Debug for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self.dtype() { match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f), DType::U32 => self.fmt_dt::<u32>(f),
DType::BF16 => self.fmt_dt::<bf16>(f), DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f), DType::F16 => self.fmt_dt::<f16>(f),
@ -415,6 +416,12 @@ impl std::fmt::Display for Tensor {
self.clone() self.clone()
}; };
match self.dtype() { match self.dtype() {
DType::U8 => {
let tf: IntFormatter<u8> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::U32 => { DType::U32 => {
let tf: IntFormatter<u32> = IntFormatter::new(); let tf: IntFormatter<u32> = IntFormatter::new();
let max_w = tf.max_width(&to_display); let max_w = tf.max_width(&to_display);

View File

@ -2,6 +2,7 @@ use crate::{CpuStorage, Error, Result};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType { pub enum DType {
U8,
U32, U32,
BF16, BF16,
F16, F16,
@ -12,6 +13,7 @@ pub enum DType {
impl DType { impl DType {
pub fn as_str(&self) -> &'static str { pub fn as_str(&self) -> &'static str {
match self { match self {
Self::U8 => "u8",
Self::U32 => "u32", Self::U32 => "u32",
Self::BF16 => "bf16", Self::BF16 => "bf16",
Self::F16 => "f16", Self::F16 => "f16",
@ -22,6 +24,7 @@ impl DType {
pub fn size_in_bytes(&self) -> usize { pub fn size_in_bytes(&self) -> usize {
match self { match self {
// A `u8` element occupies exactly one byte (it is stored as `Vec<u8>` and
// serialized with the npy descr "u1"), not four like `u32`.
Self::U8 => 1,
Self::U32 => 4, Self::U32 => 4,
Self::BF16 => 2, Self::BF16 => 2,
Self::F16 => 2, Self::F16 => 2,
@ -89,6 +92,7 @@ macro_rules! with_dtype {
} }
use half::{bf16, f16}; use half::{bf16, f16};
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);

View File

@ -86,6 +86,7 @@ impl Header {
DType::F32 => "f4", DType::F32 => "f4",
DType::F64 => "f8", DType::F64 => "f8",
DType::U32 => "u4", DType::U32 => "u4",
DType::U8 => "u1",
}; };
if !shape.is_empty() { if !shape.is_empty() {
shape.push(',') shape.push(',')
@ -162,9 +163,9 @@ impl Header {
// "q" | "i8" => DType::S64, // "q" | "i8" => DType::S64,
// "h" | "i2" => DType::S16, // "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8, // "b" | "i1" => DType::S8,
// "B" | "u1" => DType::U8, "B" | "u1" => DType::U8,
"I" | "u4" => DType::U32, "I" | "u4" => DType::U32,
// "?" | "b1" => DType::Pred, "?" | "b1" => DType::U8,
// "F" | "F4" => DType::C64, // "F" | "F4" => DType::C64,
// "D" | "F8" => DType::C128, // "D" | "F8" => DType::C128,
descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))), descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
@ -218,6 +219,11 @@ impl Tensor {
reader.read_f64_into::<LittleEndian>(&mut data_t)?; reader.read_f64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu) Tensor::from_vec(data_t, shape, &Device::Cpu)
} }
DType::U8 => {
let mut data_t = vec![0u8; elem_count];
reader.read_exact(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::U32 => { DType::U32 => {
let mut data_t = vec![0u32; elem_count]; let mut data_t = vec![0u32; elem_count];
reader.read_u32_into::<LittleEndian>(&mut data_t)?; reader.read_u32_into::<LittleEndian>(&mut data_t)?;
@ -331,6 +337,10 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)? f.write_u32::<LittleEndian>(v)?
} }
} }
DType::U8 => {
let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
f.write_all(&data)?;
}
} }
Ok(()) Ok(())
} }

View File

@ -49,6 +49,7 @@ pub(crate) trait UnaryOp {
fn f16(v1: f16) -> f16; fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32; fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64; fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32; fn u32(v1: u32) -> u32;
} }
@ -60,6 +61,7 @@ pub(crate) trait BinaryOp {
fn f16(v1: f16, v2: f16) -> f16; fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32; fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64; fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32; fn u32(v1: u32, v2: u32) -> u32;
} }
@ -96,6 +98,9 @@ macro_rules! bin_op {
fn f64(v1: f64, v2: f64) -> f64 { fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2) $e(v1, v2)
} }
// u8 variant of the binary op: forwards both operands to the macro-supplied
// expression `$e`, the same way the f64 and u32 variants of this macro do.
// NOTE(review): overflow behaviour for u8 depends on what `$e` expands to — confirm.
fn u8(v1: u8, v2: u8) -> u8 {
$e(v1, v2)
}
fn u32(v1: u32, v2: u32) -> u32 { fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2) $e(v1, v2)
} }
@ -126,6 +131,9 @@ macro_rules! unary_op {
fn f64($a: f64) -> f64 { fn f64($a: f64) -> f64 {
$e $e
} }
// Ops generated by this macro have no u8 implementation: calling this
// panics through `todo!`, the same way the u32 variant does.
fn u8(_: u8) -> u8 {
todo!("no unary function for u8")
}
fn u32(_: u32) -> u32 { fn u32(_: u32) -> u32 {
todo!("no unary function for u32") todo!("no unary function for u32")
} }
@ -177,6 +185,9 @@ impl UnaryOp for Gelu {
* (1.0 * (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
} }
// Placeholder for u8: the argument is ignored and 0 is always returned,
// the same as the u32 variant of this impl.
fn u8(_: u8) -> u8 {
0
}
fn u32(_: u32) -> u32 { fn u32(_: u32) -> u32 {
0 0
} }
@ -199,6 +210,9 @@ impl UnaryOp for Relu {
fn f64(v: f64) -> f64 { fn f64(v: f64) -> f64 {
v.max(0f64) v.max(0f64)
} }
// For u8 this is the identity: an unsigned value is never negative, so
// there is nothing to clamp (the float variants use `v.max(0)` instead).
fn u8(v: u8) -> u8 {
v
}
fn u32(v: u32) -> u32 { fn u32(v: u32) -> u32 {
v v
} }

View File

@ -38,4 +38,5 @@ AFFINE_OP(__half, affine_f16)
AFFINE_OP(float, affine_f32) AFFINE_OP(float, affine_f32)
AFFINE_OP(double, affine_f64) AFFINE_OP(double, affine_f64)
AFFINE_OP(uint8_t, affine_u8)
AFFINE_OP(uint32_t, affine_u32) AFFINE_OP(uint32_t, affine_u32)

View File

@ -17,13 +17,17 @@ BINARY_OP(__half, bsub_f16, x - y)
BINARY_OP(float, badd_f32, x + y) BINARY_OP(float, badd_f32, x + y)
BINARY_OP(double, badd_f64, x + y); BINARY_OP(double, badd_f64, x + y);
BINARY_OP(uint8_t, badd_u8, x + y);
BINARY_OP(uint32_t, badd_u32, x + y); BINARY_OP(uint32_t, badd_u32, x + y);
BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(float, bdiv_f32, x / y)
BINARY_OP(double, bdiv_f64, x / y); BINARY_OP(double, bdiv_f64, x / y);
BINARY_OP(uint8_t, bdiv_u8, x / y);
BINARY_OP(uint32_t, bdiv_u32, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y);
BINARY_OP(float, bmul_f32, x * y) BINARY_OP(float, bmul_f32, x * y)
BINARY_OP(double, bmul_f64, x * y); BINARY_OP(double, bmul_f64, x * y);
BINARY_OP(uint8_t, bmul_u8, x * y);
BINARY_OP(uint32_t, bmul_u32, x * y); BINARY_OP(uint32_t, bmul_u32, x * y);
BINARY_OP(float, bsub_f32, x - y) BINARY_OP(float, bsub_f32, x - y)
BINARY_OP(double, bsub_f64, x - y); BINARY_OP(double, bsub_f64, x - y);
BINARY_OP(uint8_t, bsub_u8, x - y);
BINARY_OP(uint32_t, bsub_u32, x - y); BINARY_OP(uint32_t, bsub_u32, x - y);

View File

@ -27,10 +27,12 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16) CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8)
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32) CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
// CAST_OP(__nv_bfloat16, __half, cast_bf16_f16) // CAST_OP(__nv_bfloat16, __half, cast_bf16_f16)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32) CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64) CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16) CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
// CAST_OP(__half, __nv_bfloat16, cast_f16_bf16) // CAST_OP(__half, __nv_bfloat16, cast_f16_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16) CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
@ -40,22 +42,32 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16) CAST_OP(__half, __half, cast_f16_f16)
// CAST_OP(__half, uint8_t, cast_f16_u8 )
CAST_OP(__half, uint32_t, cast_f16_u32) CAST_OP(__half, uint32_t, cast_f16_u32)
CAST_OP(__half, float, cast_f16_f32) CAST_OP(__half, float, cast_f16_f32)
CAST_OP(__half, double, cast_f16_f64) CAST_OP(__half, double, cast_f16_f64)
CAST_OP(uint8_t, __half, cast_u8_f16 )
CAST_OP(uint32_t, __half, cast_u32_f16) CAST_OP(uint32_t, __half, cast_u32_f16)
CAST_OP(float, __half, cast_f32_f16) CAST_OP(float, __half, cast_f32_f16)
CAST_OP(double, __half, cast_f64_f16) CAST_OP(double, __half, cast_f64_f16)
#endif #endif
CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint32_t, cast_u32_u32)
CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, float, cast_u32_f32)
CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint32_t, double, cast_u32_f64)
CAST_OP(uint8_t, uint32_t, cast_u8_u32)
CAST_OP(uint8_t, uint8_t, cast_u8_u8 )
CAST_OP(uint8_t, float, cast_u8_f32)
CAST_OP(uint8_t, double, cast_u8_f64)
CAST_OP(float, uint8_t, cast_f32_u8 )
CAST_OP(float, uint32_t, cast_f32_u32) CAST_OP(float, uint32_t, cast_f32_u32)
CAST_OP(float, float, cast_f32_f32) CAST_OP(float, float, cast_f32_f32)
CAST_OP(float, double, cast_f32_f64) CAST_OP(float, double, cast_f32_f64)
CAST_OP(double, uint8_t, cast_f64_u8 )
CAST_OP(double, uint32_t, cast_f64_u32) CAST_OP(double, uint32_t, cast_f64_u32)
CAST_OP(double, float, cast_f64_f32) CAST_OP(double, float, cast_f64_f32)
CAST_OP(double, double, cast_f64_f64) CAST_OP(double, double, cast_f64_f64)

View File

@ -39,4 +39,5 @@ EMB_OP(__half, emb_f16)
EMB_OP(float, emb_f32) EMB_OP(float, emb_f32)
EMB_OP(double, emb_f64) EMB_OP(double, emb_f64)
EMB_OP(uint8_t, emb_u8)
EMB_OP(uint32_t, emb_u32) EMB_OP(uint32_t, emb_u32)

View File

@ -42,4 +42,5 @@ WHERE_OP(__half, where_f16)
WHERE_OP(float, where_f32) WHERE_OP(float, where_f32)
WHERE_OP(double, where_f64) WHERE_OP(double, where_f64)
WHERE_OP(uint8_t, where_u8)
WHERE_OP(uint32_t, where_u32) WHERE_OP(uint32_t, where_u32)