From 22da2c7e02d8b0364b3b142dfdf3114781c20590 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 20:52:01 +0100 Subject: [PATCH] More f16 and bf16 support. --- Cargo.toml | 4 +- src/cpu_backend.rs | 244 ++++++++++++++++++++++++++++++++++---------- src/cuda_backend.rs | 201 ++++++++++++++++++++++++++++++++++++ 3 files changed, 392 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9ccd9381..de6b80f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ candle-kernels = { path = "kernels", optional = true } gemm = "0.15.4" zip = { version = "0.6.6", default-features=false } byteorder = "1.4.3" -half = "2.3.1" +half = { version = "2.3.1", features = ["num-traits"] } num-traits = "0.2.15" [dev-dependencies] @@ -33,5 +33,5 @@ rand = "0.8.5" tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } [features] -default = [] +default = ["cuda"] cuda = ["dep:cudarc", "dep:candle-kernels"] diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 0ed665a8..03e1e785 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -46,10 +46,37 @@ fn wcond( } } +macro_rules! map1 { + ($v: expr, $fn: ident, $( $args:expr ),*) => {{ + let v = match $v { + CpuStorage::BF16(__s) => CpuStorage::BF16($fn::(__s, $($args),*)?), + CpuStorage::F16(__s) => CpuStorage::F16($fn::(__s, $($args),*)?), + CpuStorage::F32(__s) => CpuStorage::F32($fn::(__s, $($args),*)?), + CpuStorage::F64(__s) => CpuStorage::F64($fn::(__s, $($args),*)?), + CpuStorage::U32(__s) => CpuStorage::U32($fn::(__s, $($args),*)?), + }; + Ok(v) + }}; +} + +fn sum_impl1( + src: &[T], + dst_shape: &Shape, + src_dims: &[usize], + stride: &[usize], + to_dst_index: impl Fn(usize) -> usize, +) -> Result> { + let mut dst = vec![T::zero(); dst_shape.elem_count()]; + for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { + dst[to_dst_index(unstr_index)] += src[src_index]; + } + Ok(dst) +} + fn unary_map U>( + vs: &[T], shape: &Shape, stride: &[usize], - vs: &[T], mut f: F, ) -> Vec { if shape.is_contiguous(stride) { @@ -83,11 +110,11 @@ fn binary_map T>( } } -fn take( +fn take_impl1( + vs: &[T], ids: &[u32], shape: &Shape, stride: &[usize], - vs: &[T], vocab_size: usize, hidden_size: usize, ) -> Result> { @@ -153,40 +180,104 @@ impl CpuStorage { pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. match (self, dtype) { + (Self::U32(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::BF16(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, |v| v); + Ok(Self::BF16(data)) + } + (Self::F16(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::F32(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, bf16::from_f32); + Ok(Self::BF16(data)) + } + (Self::F64(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, bf16::from_f64); + Ok(Self::BF16(data)) + } + (Self::U32(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::BF16(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } + (Self::F16(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, |v| v); + Ok(Self::F16(data)) + } + (Self::F32(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, f16::from_f32); + Ok(Self::F16(data)) + } + (Self::F64(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, f16::from_f64); + Ok(Self::F16(data)) + } (Self::U32(storage), DType::F32) => { - let data = unary_map(shape, stride, storage, |v| v as f32); + let data = unary_map(storage, shape, stride, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::BF16(storage), DType::F32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::F16(storage), DType::F32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32()); Ok(Self::F32(data)) } (Self::F32(storage), DType::F32) => { - let data = unary_map(shape, stride, storage, |v| v); + let data = unary_map(storage, shape, stride, |v| v); Ok(Self::F32(data)) } (Self::F64(storage), DType::F32) => { - let data = unary_map(shape, stride, storage, |v| v as f32); + let data = unary_map(storage, shape, stride, |v| v as f32); Ok(Self::F32(data)) } (Self::U32(storage), DType::U32) => { - let data = unary_map(shape, stride, storage, |v| v); + let data = unary_map(storage, shape, stride, |v| v); + Ok(Self::U32(data)) + } + (Self::BF16(storage), DType::U32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::F16(storage), DType::U32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); Ok(Self::U32(data)) } (Self::F32(storage), DType::U32) => { - let data = unary_map(shape, stride, storage, |v| v as u32); + let data = unary_map(storage, shape, stride, |v| v as u32); Ok(Self::U32(data)) } (Self::F64(storage), DType::U32) => { - let data = unary_map(shape, stride, storage, |v| v as u32); + let data = unary_map(storage, shape, stride, |v| v as u32); Ok(Self::U32(data)) } (Self::U32(storage), DType::F64) => { - let data = unary_map(shape, stride, storage, |v| v as f64); + let data = unary_map(storage, shape, stride, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::BF16(storage), DType::F64) => { + let data = unary_map(storage, shape, stride, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::F16(storage), DType::F64) => { + let data = unary_map(storage, shape, stride, |v| v.to_f64()); Ok(Self::F64(data)) } (Self::F32(storage), DType::F64) => { - let data = unary_map(shape, stride, storage, |v| v as f64); + let data = unary_map(storage, shape, stride, |v| v as f64); Ok(Self::F64(data)) } (Self::F64(storage), DType::F64) => { - let data = unary_map(shape, stride, storage, |v| v); + let data = unary_map(storage, shape, stride, |v| v); Ok(Self::F64(data)) } } @@ -219,29 +310,7 @@ impl CpuStorage { dst_index }; // TODO: Maybe provide an implementation with higher precision accumulators? - match self { - Self::F32(src) => { - let mut dst = vec![0f32; dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { - dst[to_dst_index(unstr_index)] += src[src_index]; - } - Ok(Self::F32(dst)) - } - Self::F64(src) => { - let mut dst = vec![0f64; dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { - dst[to_dst_index(unstr_index)] += src[src_index]; - } - Ok(Self::F64(dst)) - } - Self::U32(src) => { - let mut dst = vec![0u32; dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { - dst[to_dst_index(unstr_index)] += src[src_index]; - } - Ok(Self::U32(dst)) - } - } + map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index) } pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { @@ -251,6 +320,42 @@ impl CpuStorage { let prod_pre_dim = dims[..dim].iter().product(); let prod_post_dim = dims[dim + 1..].iter().product(); match self { + Self::BF16(storage) => { + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += storage[idx].to_f64(); + idx += prod_post_dim + } + let sum = bf16::from_f64(sum); + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + storage[idx] /= sum; + idx += prod_post_dim + } + } + } + } + Self::F16(storage) => { + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += storage[idx].to_f64(); + idx += prod_post_dim + } + let sum = f16::from_f64(sum); + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + storage[idx] /= sum; + idx += prod_post_dim + } + } + } + } Self::F32(storage) => { for pre_idx in 0..prod_pre_dim { for post_idx in 0..prod_post_dim { @@ -302,17 +407,29 @@ impl CpuStorage { Self::U32(storage) => { let mul = mul as u32; let add = add as u32; - let data = unary_map(shape, stride, storage, |v| v * mul + add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); Ok(Self::U32(data)) } + Self::BF16(storage) => { + let mul = bf16::from_f64(mul); + let add = bf16::from_f64(add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); + Ok(Self::BF16(data)) + } + Self::F16(storage) => { + let mul = f16::from_f64(mul); + let add = f16::from_f64(add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); + Ok(Self::F16(data)) + } Self::F32(storage) => { let mul = mul as f32; let add = add as f32; - let data = unary_map(shape, stride, storage, |v| v * mul + add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(shape, stride, storage, |v| v * mul + add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); Ok(Self::F64(data)) } } @@ -320,16 +437,24 @@ impl CpuStorage { pub(crate) fn unary_impl(&self, shape: &Shape, stride: &[usize]) -> Result { match self { + Self::BF16(storage) => { + let data = unary_map(storage, shape, stride, B::bf16); + Ok(Self::BF16(data)) + } + Self::F16(storage) => { + let data = unary_map(storage, shape, stride, B::f16); + Ok(Self::F16(data)) + } Self::F32(storage) => { - let data = unary_map(shape, stride, storage, B::f32); + let data = unary_map(storage, shape, stride, B::f32); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(shape, stride, storage, B::f64); + let data = unary_map(storage, shape, stride, B::f64); Ok(Self::F64(data)) } Self::U32(storage) => { - let data = unary_map(shape, stride, storage, B::u32); + let data = unary_map(storage, shape, stride, B::u32); Ok(Self::U32(data)) } } @@ -343,6 +468,14 @@ impl CpuStorage { rhs_stride: &[usize], ) -> Result { match (self, rhs) { + (Self::BF16(lhs), Self::BF16(rhs)) => { + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16); + Ok(Self::BF16(data)) + } + (Self::F16(lhs), Self::F16(rhs)) => { + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16); + Ok(Self::F16(data)) + } (Self::F32(lhs), Self::F32(rhs)) => { let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32); Ok(Self::F32(data)) @@ -381,6 +514,12 @@ impl CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) } + (Self::BF16(src), Self::BF16(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) + } + (Self::F16(src), Self::F16(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) + } (Self::F32(src), Self::F32(dst)) => { copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) } @@ -411,6 +550,14 @@ impl CpuStorage { // TODO: Support types that could be casted to a boolean. let pred = self.as_slice::()?; match (t, f) { + (Self::BF16(t), Self::BF16(f)) => { + let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + Ok(Self::BF16(data)) + } + (Self::F16(t), Self::F16(f)) => { + let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + Ok(Self::F16(data)) + } (Self::F32(t), Self::F32(f)) => { let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); Ok(Self::F32(data)) @@ -440,20 +587,7 @@ impl CpuStorage { vocab_size: usize, ) -> Result { let ids = self.as_slice::()?; - match vs { - CpuStorage::F32(vs) => { - let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; - Ok(CpuStorage::F32(storage)) - } - CpuStorage::F64(vs) => { - let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; - Ok(CpuStorage::F64(storage)) - } - CpuStorage::U32(vs) => { - let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; - Ok(CpuStorage::U32(storage)) - } - } + map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size) } pub(crate) fn matmul_impl( diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 8fadccdd..90cb0f72 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -133,6 +133,22 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(data) } + DType::BF16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }?; + let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; + let params = (&data, bf16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }?; + let func = self.get_or_load_func("fill_f16", kernels::FILL)?; + let params = (&data, f16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(data) + } DType::F32 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }?; @@ -166,6 +182,14 @@ impl CudaDevice { let data = self.htod_sync_copy(storage)?; CudaStorageSlice::U32(data) } + CpuStorage::BF16(storage) => { + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::F16(data) + } CpuStorage::F32(storage) => { let data = self.htod_sync_copy(storage)?; CudaStorageSlice::F32(data) @@ -325,6 +349,40 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(out) } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = ( + el_count, + dims.len(), + &ds, + arg, + &out, + bf16::from_f64(mul), + bf16::from_f64(add), + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = ( + el_count, + dims.len(), + &ds, + arg, + &out, + f16::from_f64(mul), + f16::from_f64(add), + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. @@ -376,6 +434,22 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(out) } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; @@ -417,6 +491,24 @@ impl CudaStorage { CudaStorageSlice::U32(_arg) => { todo!("No unary kernels for u32"); } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. @@ -453,6 +545,24 @@ impl CudaStorage { let dev = self.device(); let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?; let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; // SAFETY: Set later by running the kernel. @@ -494,6 +604,16 @@ impl CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice)?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::BF16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice)?; + Ok(CpuStorage::BF16(cpu_storage)) + } + CudaStorageSlice::F16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice)?; + Ok(CpuStorage::F16(cpu_storage)) + } CudaStorageSlice::F32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice)?; @@ -530,6 +650,24 @@ impl CudaStorage { let dev = self.device(); let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?; let slice = match (&t.slice, &f.slice) { + (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { + let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { + let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. @@ -596,6 +734,24 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(out) } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el * h_size) }?; + let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el * h_size) }?; + let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. @@ -629,6 +785,12 @@ impl CudaStorage { let elem_count = b * m * n; let dev = &self.device; let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => { + todo!("bf16") + } + (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => { + todo!("f16") + } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; let mut out = unsafe { dev.alloc::(elem_count) }?; @@ -672,6 +834,32 @@ impl CudaStorage { let dev = &self.device; let ds = dev.htod_copy([dims, src_stride].concat())?; match (&self.slice, &mut dst.slice) { + (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { + let src = src.slice(src_offset..); + let mut dst = dst.slice_mut(dst_offset..); + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } + (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { + let src = src.slice(src_offset..); + let mut dst = dst.slice_mut(dst_offset..); + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { let src = src.slice(src_offset..); let mut dst = dst.slice_mut(dst_offset..); @@ -685,6 +873,19 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }? } } + (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { + let src = src.slice(src_offset..); + let mut dst = dst.slice_mut(dst_offset..); + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { let src = src.slice(src_offset..); let mut dst = dst.slice_mut(dst_offset..);