From a31411fd91256d50e9f3384a752641865f6a614d Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 19:37:47 +0100 Subject: [PATCH 01/12] Start adding f16/bf16 support. --- Cargo.toml | 4 +- src/cpu_backend.rs | 21 +++ src/cuda_backend.rs | 15 +++ src/dtype.rs | 6 + src/npy.rs | 16 ++- src/op.rs | 310 ++++++++++++++++---------------------------- 6 files changed, 169 insertions(+), 203 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f95a5044..9ccd9381 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,13 @@ members = [ [dependencies] safetensors = "0.3.1" thiserror = "1" -cudarc = { version = "0.9.9", optional = true } +cudarc = { version = "0.9.9", optional = true, features = ["f16"] } candle-kernels = { path = "kernels", optional = true } gemm = "0.15.4" zip = { version = "0.6.6", default-features=false } byteorder = "1.4.3" +half = "2.3.1" +num-traits = "0.2.15" [dev-dependencies] anyhow = "1" diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 61ffcb28..0ed665a8 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -1,6 +1,7 @@ use crate::op::{BinaryOp, UnaryOp}; use crate::{DType, Error, Result, Shape, StridedIndex}; use gemm::{gemm, Parallelism}; +use half::{bf16, f16}; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. @@ -9,6 +10,8 @@ use gemm::{gemm, Parallelism}; #[derive(Debug, Clone)] pub enum CpuStorage { U32(Vec), + BF16(Vec), + F16(Vec), F32(Vec), F64(Vec), } @@ -132,6 +135,8 @@ impl CpuStorage { pub fn dtype(&self) -> DType { match self { Self::U32(_) => DType::U32, + Self::BF16(_) => DType::BF16, + Self::F16(_) => DType::F16, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, } @@ -545,6 +550,14 @@ impl CpuStorage { let data = vec![1u32; elem_count]; Self::U32(data) } + DType::BF16 => { + let data = vec![bf16::ONE; elem_count]; + Self::BF16(data) + } + DType::F16 => { + let data = vec![f16::ONE; elem_count]; + Self::F16(data) + } DType::F32 => { let data = vec![1f32; elem_count]; Self::F32(data) @@ -563,6 +576,14 @@ impl CpuStorage { let data = vec![0u32; elem_count]; Self::U32(data) } + DType::BF16 => { + let data = vec![bf16::ZERO; elem_count]; + Self::BF16(data) + } + DType::F16 => { + let data = vec![f16::ZERO; elem_count]; + Self::F16(data) + } DType::F32 => { let data = vec![0f32; elem_count]; Self::F32(data) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 56fa1684..8fadccdd 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -2,6 +2,7 @@ use crate::{CpuStorage, DType, Shape}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig}; +use half::{bf16, f16}; use std::sync::Arc; /// cudarc related errors @@ -97,6 +98,14 @@ impl CudaDevice { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::U32(data) } + DType::BF16 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::F16(data) + } DType::F32 => { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::F32(data) @@ -190,6 +199,8 @@ impl CudaDevice { #[derive(Debug)] enum CudaStorageSlice { U32(CudaSlice), + BF16(CudaSlice), + F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), } @@ -265,6 +276,8 @@ impl CudaStorage { pub fn try_clone(&self) -> Result { let slice = match &self.slice { CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?), + 
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?), + CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?), CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?), CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?), }; @@ -275,6 +288,8 @@ impl CudaStorage { pub fn dtype(&self) -> DType { match self.slice { CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::BF16(_) => DType::BF16, + CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, } diff --git a/src/dtype.rs b/src/dtype.rs index 471f415c..53b61ce0 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -3,6 +3,8 @@ use crate::{CpuStorage, Error, Result}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { U32, + BF16, + F16, F32, F64, } @@ -11,6 +13,8 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { Self::U32 => 4, + Self::BF16 => 2, + Self::F16 => 2, Self::F32 => 4, Self::F64 => 8, } @@ -76,5 +80,7 @@ macro_rules! with_dtype { }; } with_dtype!(u32, U32); +with_dtype!(half::f16, F16); +with_dtype!(half::bf16, BF16); with_dtype!(f32, F32); with_dtype!(f64, F64); diff --git a/src/npy.rs b/src/npy.rs index 43a6cb1c..3eb4e7c1 100644 --- a/src/npy.rs +++ b/src/npy.rs @@ -80,6 +80,8 @@ impl Header { .collect::>() .join(","); let descr = match self.descr { + DType::BF16 => todo!("bf16"), + DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", DType::U32 => "u4", @@ -152,7 +154,7 @@ impl Header { // int64, int32, int16, int8, // uint8, and bool. match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') { - // "e" | "f2" => DType::F16, + "e" | "f2" => DType::F16, "f" | "f4" => DType::F32, "d" | "f8" => DType::F64, // "i" | "i4" => DType::S32, @@ -194,6 +196,12 @@ impl Tensor { fn from_reader(shape: Shape, dtype: DType, reader: &mut R) -> Result { let elem_count = shape.elem_count(); match dtype { + DType::BF16 => { + todo!("bf16") + } + DType::F16 => { + todo!("f16") + } DType::F32 => { let mut data_t = vec![0f32; elem_count]; reader.read_f32_into::(&mut data_t)?; @@ -289,6 +297,12 @@ impl Tensor { f.write_all(header.as_bytes())?; let elem_count = self.elem_count(); match self.dtype() { + DType::BF16 => { + todo!("bf16") + } + DType::F16 => { + todo!("f16") + } DType::F32 => { // TODO: Avoid using a buffer when data is already on the CPU. for v in self.reshape(elem_count)?.to_vec1::()? { diff --git a/src/op.rs b/src/op.rs index 82068dfd..5eab6b1f 100644 --- a/src/op.rs +++ b/src/op.rs @@ -1,4 +1,6 @@ use crate::Tensor; +use half::{bf16, f16}; +use num_traits::float::Float; #[derive(Clone)] pub(crate) enum Op { @@ -40,10 +42,13 @@ pub(crate) enum Op { pub(crate) trait UnaryOp { const NAME: &'static str; - // TODO: These kernels are compatible with arbitrary strides. We should also consider the - // contiguous case separately as it's easy to optimize things out there. + const KERNEL_BF16: &'static str; + const KERNEL_F16: &'static str; const KERNEL_F32: &'static str; const KERNEL_F64: &'static str; + const KERNEL_U32: &'static str; + fn bf16(v1: bf16) -> bf16; + fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; fn u32(v1: u32) -> u32; @@ -51,11 +56,13 @@ pub(crate) trait UnaryOp { pub(crate) trait BinaryOp { const NAME: &'static str; - // TODO: These kernels are compatible with arbitrary strides. We should also consider the - // contiguous case separately as it's easy to optimize things out there. 
+ const KERNEL_BF16: &'static str; + const KERNEL_F16: &'static str; const KERNEL_F32: &'static str; const KERNEL_F64: &'static str; const KERNEL_U32: &'static str; + fn bf16(v1: bf16, v2: bf16) -> bf16; + fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; fn u32(v1: u32, v2: u32) -> u32; @@ -75,215 +82,116 @@ pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; -impl BinaryOp for Add { - const NAME: &'static str = "add"; - const KERNEL_F32: &'static str = "badd_f32"; - const KERNEL_F64: &'static str = "badd_f64"; - const KERNEL_U32: &'static str = "badd_u32"; - fn f32(v1: f32, v2: f32) -> f32 { - v1 + v2 - } - fn f64(v1: f64, v2: f64) -> f64 { - v1 + v2 - } - fn u32(v1: u32, v2: u32) -> u32 { - v1 + v2 - } +macro_rules! bin_op { + ($op:ident, $name: literal, $e: expr) => { + impl BinaryOp for $op { + const NAME: &'static str = $name; + const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16"); + const KERNEL_F16: &'static str = concat!("b", $name, "_f16"); + const KERNEL_F32: &'static str = concat!("b", $name, "_f32"); + const KERNEL_F64: &'static str = concat!("b", $name, "_f64"); + const KERNEL_U32: &'static str = concat!("b", $name, "_u32"); + fn bf16(v1: bf16, v2: bf16) -> bf16 { + $e(v1, v2) + } + fn f16(v1: f16, v2: f16) -> f16 { + $e(v1, v2) + } + fn f32(v1: f32, v2: f32) -> f32 { + $e(v1, v2) + } + fn f64(v1: f64, v2: f64) -> f64 { + $e(v1, v2) + } + fn u32(v1: u32, v2: u32) -> u32 { + $e(v1, v2) + } + } + }; } -impl BinaryOp for Sub { - const NAME: &'static str = "sub"; - const KERNEL_F32: &'static str = "bsub_f32"; - const KERNEL_F64: &'static str = "bsub_f64"; - const KERNEL_U32: &'static str = "bsub_u32"; - fn f32(v1: f32, v2: f32) -> f32 { - v1 - v2 - } - fn f64(v1: f64, v2: f64) -> f64 { - v1 - v2 - } - fn u32(v1: u32, v2: u32) -> u32 { - v1 - v2 - } +bin_op!(Add, "add", |v1, v2| v1 + v2); +bin_op!(Sub, "sub", |v1, v2| v1 - v2); +bin_op!(Mul, "mul", |v1, v2| v1 * v2); +bin_op!(Div, "div", |v1, v2| v1 / v2); + +macro_rules! 
unary_op { + ($op: ident, $name: literal, $a: ident, $e: expr) => { + impl UnaryOp for $op { + const NAME: &'static str = $name; + const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16"); + const KERNEL_F16: &'static str = concat!("u", $name, "_f16"); + const KERNEL_F32: &'static str = concat!("u", $name, "_f32"); + const KERNEL_F64: &'static str = concat!("u", $name, "_f64"); + const KERNEL_U32: &'static str = concat!("u", $name, "_u32"); + fn bf16($a: bf16) -> bf16 { + $e + } + fn f16($a: f16) -> f16 { + $e + } + fn f32($a: f32) -> f32 { + $e + } + fn f64($a: f64) -> f64 { + $e + } + fn u32(_: u32) -> u32 { + todo!("no unary function for u32") + } + } + }; } -impl BinaryOp for Mul { - const NAME: &'static str = "mul"; - const KERNEL_F32: &'static str = "bmul_f32"; - const KERNEL_F64: &'static str = "bmul_f64"; - const KERNEL_U32: &'static str = "bmul_u32"; - fn f32(v1: f32, v2: f32) -> f32 { - v1 * v2 - } - fn f64(v1: f64, v2: f64) -> f64 { - v1 * v2 - } - fn u32(v1: u32, v2: u32) -> u32 { - v1 * v2 - } -} - -impl BinaryOp for Div { - const NAME: &'static str = "div"; - const KERNEL_F32: &'static str = "bdiv_f32"; - const KERNEL_F64: &'static str = "bdiv_f64"; - const KERNEL_U32: &'static str = "bdiv_u32"; - fn f32(v1: f32, v2: f32) -> f32 { - v1 / v2 - } - fn f64(v1: f64, v2: f64) -> f64 { - v1 / v2 - } - fn u32(v1: u32, v2: u32) -> u32 { - v1 / v2 - } -} - -impl UnaryOp for Exp { - const NAME: &'static str = "exp"; - fn f32(v1: f32) -> f32 { - v1.exp() - } - fn f64(v1: f64) -> f64 { - v1.exp() - } - fn u32(v1: u32) -> u32 { - (v1 as f64).exp() as u32 - } - const KERNEL_F32: &'static str = "uexp_f32"; - const KERNEL_F64: &'static str = "uexp_f64"; -} - -impl UnaryOp for Log { - const NAME: &'static str = "log"; - fn f32(v1: f32) -> f32 { - v1.ln() - } - fn f64(v1: f64) -> f64 { - v1.ln() - } - fn u32(v1: u32) -> u32 { - (v1 as f64).ln() as u32 - } - const KERNEL_F32: &'static str = "ulog_f32"; - const KERNEL_F64: &'static str = "ulog_f64"; -} - -impl UnaryOp for Sin { - const NAME: &'static str = "sin"; - fn f32(v1: f32) -> f32 { - v1.sin() - } - fn f64(v1: f64) -> f64 { - v1.sin() - } - fn u32(_: u32) -> u32 { - 0 - } - const KERNEL_F32: &'static str = "usin_f32"; - const KERNEL_F64: &'static str = "usin_f64"; -} - -impl UnaryOp for Cos { - const NAME: &'static str = "cos"; - fn f32(v1: f32) -> f32 { - v1.cos() - } - fn f64(v1: f64) -> f64 { - v1.cos() - } - fn u32(_: u32) -> u32 { - 0 - } - const KERNEL_F32: &'static str = "ucos_f32"; - const KERNEL_F64: &'static str = "ucos_f64"; -} - -impl UnaryOp for Abs { - const NAME: &'static str = "abs"; - fn f32(v1: f32) -> f32 { - v1.abs() - } - fn f64(v1: f64) -> f64 { - v1.abs() - } - fn u32(v1: u32) -> u32 { - v1 - } - const KERNEL_F32: &'static str = "uabs_f32"; - const KERNEL_F64: &'static str = "uabs_f64"; -} - -impl UnaryOp for Neg { - const NAME: &'static str = "neg"; - fn f32(v1: f32) -> f32 { - -v1 - } - fn f64(v1: f64) -> f64 { - -v1 - } - fn u32(_: u32) -> u32 { - 0 - } - const KERNEL_F32: &'static str = "uneg_f32"; - const KERNEL_F64: &'static str = "uneg_f64"; -} - -impl UnaryOp for Sqr { - const NAME: &'static str = "sqr"; - fn f32(v1: f32) -> f32 { - v1 * v1 - } - fn f64(v1: f64) -> f64 { - v1 * v1 - } - fn u32(v: u32) -> u32 { - v * v - } - const KERNEL_F32: &'static str = "usqr_f32"; - const KERNEL_F64: &'static str = "usqr_f64"; -} - -impl UnaryOp for Sqrt { - const NAME: &'static str = "sqrt"; - fn f32(v1: f32) -> f32 { - v1.sqrt() - } - fn f64(v1: f64) -> f64 { - v1.sqrt() - } - fn u32(v: u32) -> u32 { - (v as 
f64).sqrt() as u32 - } - const KERNEL_F32: &'static str = "usqrt_f32"; - const KERNEL_F64: &'static str = "usqrt_f64"; -} +unary_op!(Exp, "exp", v, v.exp()); +unary_op!(Log, "log", v, v.ln()); +unary_op!(Sin, "sin", v, v.sin()); +unary_op!(Cos, "cos", v, v.cos()); +unary_op!(Abs, "abs", v, v.abs()); +unary_op!(Neg, "neg", v, -v); +unary_op!(Sqr, "sqr", v, v * v); +unary_op!(Sqrt, "sqrt", v, v.sqrt()); /// `gelu` operation /// -#[inline] -pub fn gelu_f32(v: f32) -> f32 { - 0.5 * v - * (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) -} -/// `gelu` operation -/// -#[inline] -pub fn gelu_f64(v: f64) -> f64 { - 0.5 * v - * (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) -} impl UnaryOp for Gelu { const NAME: &'static str = "gelu"; - fn f32(v1: f32) -> f32 { - gelu_f32(v1) + fn bf16(v: bf16) -> bf16 { + bf16::from_f32_const(0.5) + * v + * (bf16::ONE + + bf16::tanh( + (bf16::from_f32_const(2.0) / bf16::PI).sqrt() + * v + * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), + )) } - fn f64(v1: f64) -> f64 { - gelu_f64(v1) + fn f16(v: f16) -> f16 { + f16::from_f32_const(0.5) + * v + * (f16::ONE + + f16::tanh( + (f16::from_f32_const(2.0) / f16::PI).sqrt() + * v + * (f16::ONE + f16::from_f32_const(0.044715) * v * v), + )) } - fn u32(v1: u32) -> u32 { - gelu_f64(v1 as f64) as u32 + fn f32(v: f32) -> f32 { + 0.5 * v + * (1.0 + + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) } + fn f64(v: f64) -> f64 { + 0.5 * v + * (1.0 + + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) + } + fn u32(_: u32) -> u32 { + 0 + } + const KERNEL_BF16: &'static str = "gelu_bf16"; + const KERNEL_F16: &'static str = "gelu_f16"; const KERNEL_F32: &'static str = "gelu_f32"; const KERNEL_F64: &'static str = "gelu_f64"; + const KERNEL_U32: &'static str = "gelu_u32"; } From 22da2c7e02d8b0364b3b142dfdf3114781c20590 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 20:52:01 +0100 Subject: [PATCH 02/12] More f16 and bf16 support. --- Cargo.toml | 4 +- src/cpu_backend.rs | 244 ++++++++++++++++++++++++++++++++++---------- src/cuda_backend.rs | 201 ++++++++++++++++++++++++++++++++++++ 3 files changed, 392 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9ccd9381..de6b80f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ candle-kernels = { path = "kernels", optional = true } gemm = "0.15.4" zip = { version = "0.6.6", default-features=false } byteorder = "1.4.3" -half = "2.3.1" +half = { version = "2.3.1", features = ["num-traits"] } num-traits = "0.2.15" [dev-dependencies] @@ -33,5 +33,5 @@ rand = "0.8.5" tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } [features] -default = [] +default = ["cuda"] cuda = ["dep:cudarc", "dep:candle-kernels"] diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 0ed665a8..03e1e785 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -46,10 +46,37 @@ fn wcond( } } +macro_rules! 
map1 { + ($v: expr, $fn: ident, $( $args:expr ),*) => {{ + let v = match $v { + CpuStorage::BF16(__s) => CpuStorage::BF16($fn::(__s, $($args),*)?), + CpuStorage::F16(__s) => CpuStorage::F16($fn::(__s, $($args),*)?), + CpuStorage::F32(__s) => CpuStorage::F32($fn::(__s, $($args),*)?), + CpuStorage::F64(__s) => CpuStorage::F64($fn::(__s, $($args),*)?), + CpuStorage::U32(__s) => CpuStorage::U32($fn::(__s, $($args),*)?), + }; + Ok(v) + }}; +} + +fn sum_impl1( + src: &[T], + dst_shape: &Shape, + src_dims: &[usize], + stride: &[usize], + to_dst_index: impl Fn(usize) -> usize, +) -> Result> { + let mut dst = vec![T::zero(); dst_shape.elem_count()]; + for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { + dst[to_dst_index(unstr_index)] += src[src_index]; + } + Ok(dst) +} + fn unary_map U>( + vs: &[T], shape: &Shape, stride: &[usize], - vs: &[T], mut f: F, ) -> Vec { if shape.is_contiguous(stride) { @@ -83,11 +110,11 @@ fn binary_map T>( } } -fn take( +fn take_impl1( + vs: &[T], ids: &[u32], shape: &Shape, stride: &[usize], - vs: &[T], vocab_size: usize, hidden_size: usize, ) -> Result> { @@ -153,40 +180,104 @@ impl CpuStorage { pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. match (self, dtype) { + (Self::U32(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::BF16(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, |v| v); + Ok(Self::BF16(data)) + } + (Self::F16(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::F32(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, bf16::from_f32); + Ok(Self::BF16(data)) + } + (Self::F64(storage), DType::BF16) => { + let data = unary_map(storage, shape, stride, bf16::from_f64); + Ok(Self::BF16(data)) + } + (Self::U32(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::BF16(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } + (Self::F16(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, |v| v); + Ok(Self::F16(data)) + } + (Self::F32(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, f16::from_f32); + Ok(Self::F16(data)) + } + (Self::F64(storage), DType::F16) => { + let data = unary_map(storage, shape, stride, f16::from_f64); + Ok(Self::F16(data)) + } (Self::U32(storage), DType::F32) => { - let data = unary_map(shape, stride, storage, |v| v as f32); + let data = unary_map(storage, shape, stride, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::BF16(storage), DType::F32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::F16(storage), DType::F32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32()); Ok(Self::F32(data)) } (Self::F32(storage), DType::F32) => { - let data = unary_map(shape, stride, storage, |v| v); + let data = unary_map(storage, shape, stride, |v| v); Ok(Self::F32(data)) } (Self::F64(storage), DType::F32) => { - let data = unary_map(shape, stride, storage, |v| v as f32); + let data = unary_map(storage, shape, stride, |v| v as f32); Ok(Self::F32(data)) } (Self::U32(storage), DType::U32) => { - let data = 
unary_map(shape, stride, storage, |v| v); + let data = unary_map(storage, shape, stride, |v| v); + Ok(Self::U32(data)) + } + (Self::BF16(storage), DType::U32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::F16(storage), DType::U32) => { + let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); Ok(Self::U32(data)) } (Self::F32(storage), DType::U32) => { - let data = unary_map(shape, stride, storage, |v| v as u32); + let data = unary_map(storage, shape, stride, |v| v as u32); Ok(Self::U32(data)) } (Self::F64(storage), DType::U32) => { - let data = unary_map(shape, stride, storage, |v| v as u32); + let data = unary_map(storage, shape, stride, |v| v as u32); Ok(Self::U32(data)) } (Self::U32(storage), DType::F64) => { - let data = unary_map(shape, stride, storage, |v| v as f64); + let data = unary_map(storage, shape, stride, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::BF16(storage), DType::F64) => { + let data = unary_map(storage, shape, stride, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::F16(storage), DType::F64) => { + let data = unary_map(storage, shape, stride, |v| v.to_f64()); Ok(Self::F64(data)) } (Self::F32(storage), DType::F64) => { - let data = unary_map(shape, stride, storage, |v| v as f64); + let data = unary_map(storage, shape, stride, |v| v as f64); Ok(Self::F64(data)) } (Self::F64(storage), DType::F64) => { - let data = unary_map(shape, stride, storage, |v| v); + let data = unary_map(storage, shape, stride, |v| v); Ok(Self::F64(data)) } } @@ -219,29 +310,7 @@ impl CpuStorage { dst_index }; // TODO: Maybe provide an implementation with higher precision accumulators? - match self { - Self::F32(src) => { - let mut dst = vec![0f32; dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { - dst[to_dst_index(unstr_index)] += src[src_index]; - } - Ok(Self::F32(dst)) - } - Self::F64(src) => { - let mut dst = vec![0f64; dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { - dst[to_dst_index(unstr_index)] += src[src_index]; - } - Ok(Self::F64(dst)) - } - Self::U32(src) => { - let mut dst = vec![0u32; dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { - dst[to_dst_index(unstr_index)] += src[src_index]; - } - Ok(Self::U32(dst)) - } - } + map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index) } pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { @@ -251,6 +320,42 @@ impl CpuStorage { let prod_pre_dim = dims[..dim].iter().product(); let prod_post_dim = dims[dim + 1..].iter().product(); match self { + Self::BF16(storage) => { + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += storage[idx].to_f64(); + idx += prod_post_dim + } + let sum = bf16::from_f64(sum); + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + storage[idx] /= sum; + idx += prod_post_dim + } + } + } + } + Self::F16(storage) => { + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += storage[idx].to_f64(); + idx += prod_post_dim + } + let sum = f16::from_f64(sum); + let mut idx = pre_idx 
* prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + storage[idx] /= sum; + idx += prod_post_dim + } + } + } + } Self::F32(storage) => { for pre_idx in 0..prod_pre_dim { for post_idx in 0..prod_post_dim { @@ -302,17 +407,29 @@ impl CpuStorage { Self::U32(storage) => { let mul = mul as u32; let add = add as u32; - let data = unary_map(shape, stride, storage, |v| v * mul + add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); Ok(Self::U32(data)) } + Self::BF16(storage) => { + let mul = bf16::from_f64(mul); + let add = bf16::from_f64(add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); + Ok(Self::BF16(data)) + } + Self::F16(storage) => { + let mul = f16::from_f64(mul); + let add = f16::from_f64(add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); + Ok(Self::F16(data)) + } Self::F32(storage) => { let mul = mul as f32; let add = add as f32; - let data = unary_map(shape, stride, storage, |v| v * mul + add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(shape, stride, storage, |v| v * mul + add); + let data = unary_map(storage, shape, stride, |v| v * mul + add); Ok(Self::F64(data)) } } @@ -320,16 +437,24 @@ impl CpuStorage { pub(crate) fn unary_impl(&self, shape: &Shape, stride: &[usize]) -> Result { match self { + Self::BF16(storage) => { + let data = unary_map(storage, shape, stride, B::bf16); + Ok(Self::BF16(data)) + } + Self::F16(storage) => { + let data = unary_map(storage, shape, stride, B::f16); + Ok(Self::F16(data)) + } Self::F32(storage) => { - let data = unary_map(shape, stride, storage, B::f32); + let data = unary_map(storage, shape, stride, B::f32); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(shape, stride, storage, B::f64); + let data = unary_map(storage, shape, stride, B::f64); Ok(Self::F64(data)) } Self::U32(storage) => { - let data = unary_map(shape, stride, storage, B::u32); + let data = unary_map(storage, shape, stride, B::u32); Ok(Self::U32(data)) } } @@ -343,6 +468,14 @@ impl CpuStorage { rhs_stride: &[usize], ) -> Result { match (self, rhs) { + (Self::BF16(lhs), Self::BF16(rhs)) => { + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16); + Ok(Self::BF16(data)) + } + (Self::F16(lhs), Self::F16(rhs)) => { + let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16); + Ok(Self::F16(data)) + } (Self::F32(lhs), Self::F32(rhs)) => { let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32); Ok(Self::F32(data)) @@ -381,6 +514,12 @@ impl CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) } + (Self::BF16(src), Self::BF16(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) + } + (Self::F16(src), Self::F16(dst)) => { + copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) + } (Self::F32(src), Self::F32(dst)) => { copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) } @@ -411,6 +550,14 @@ impl CpuStorage { // TODO: Support types that could be casted to a boolean. 
let pred = self.as_slice::()?; match (t, f) { + (Self::BF16(t), Self::BF16(f)) => { + let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + Ok(Self::BF16(data)) + } + (Self::F16(t), Self::F16(f)) => { + let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + Ok(Self::F16(data)) + } (Self::F32(t), Self::F32(f)) => { let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); Ok(Self::F32(data)) @@ -440,20 +587,7 @@ impl CpuStorage { vocab_size: usize, ) -> Result { let ids = self.as_slice::()?; - match vs { - CpuStorage::F32(vs) => { - let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; - Ok(CpuStorage::F32(storage)) - } - CpuStorage::F64(vs) => { - let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; - Ok(CpuStorage::F64(storage)) - } - CpuStorage::U32(vs) => { - let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; - Ok(CpuStorage::U32(storage)) - } - } + map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size) } pub(crate) fn matmul_impl( diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 8fadccdd..90cb0f72 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -133,6 +133,22 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(data) } + DType::BF16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }?; + let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; + let params = (&data, bf16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }?; + let func = self.get_or_load_func("fill_f16", kernels::FILL)?; + let params = (&data, f16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(data) + } DType::F32 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }?; @@ -166,6 +182,14 @@ impl CudaDevice { let data = self.htod_sync_copy(storage)?; CudaStorageSlice::U32(data) } + CpuStorage::BF16(storage) => { + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::F16(data) + } CpuStorage::F32(storage) => { let data = self.htod_sync_copy(storage)?; CudaStorageSlice::F32(data) @@ -325,6 +349,40 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(out) } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = ( + el_count, + dims.len(), + &ds, + arg, + &out, + bf16::from_f64(mul), + bf16::from_f64(add), + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = ( + el_count, + dims.len(), + &ds, + arg, + &out, + f16::from_f64(mul), + f16::from_f64(add), + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. 
@@ -376,6 +434,22 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(out) } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; @@ -417,6 +491,24 @@ impl CudaStorage { CudaStorageSlice::U32(_arg) => { todo!("No unary kernels for u32"); } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, arg, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. @@ -453,6 +545,24 @@ impl CudaStorage { let dev = self.device(); let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?; let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; // SAFETY: Set later by running the kernel. 
@@ -494,6 +604,16 @@ impl CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice)?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::BF16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice)?; + Ok(CpuStorage::BF16(cpu_storage)) + } + CudaStorageSlice::F16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice)?; + Ok(CpuStorage::F16(cpu_storage)) + } CudaStorageSlice::F32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice)?; @@ -530,6 +650,24 @@ impl CudaStorage { let dev = self.device(); let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?; let slice = match (&t.slice, &f.slice) { + (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { + let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { + let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. @@ -596,6 +734,24 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }?; CudaStorageSlice::U32(out) } + CudaStorageSlice::BF16(arg) => { + let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el * h_size) }?; + let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::BF16(out) + } + CudaStorageSlice::F16(arg) => { + let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el * h_size) }?; + let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::F16(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. @@ -629,6 +785,12 @@ impl CudaStorage { let elem_count = b * m * n; let dev = &self.device; let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => { + todo!("bf16") + } + (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => { + todo!("f16") + } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; let mut out = unsafe { dev.alloc::(elem_count) }?; @@ -672,6 +834,32 @@ impl CudaStorage { let dev = &self.device; let ds = dev.htod_copy([dims, src_stride].concat())?; match (&self.slice, &mut dst.slice) { + (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { + let src = src.slice(src_offset..); + let mut dst = dst.slice_mut(dst_offset..); + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(&src, &mut dst)? 
+ } else { + let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } + (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { + let src = src.slice(src_offset..); + let mut dst = dst.slice_mut(dst_offset..); + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { let src = src.slice(src_offset..); let mut dst = dst.slice_mut(dst_offset..); @@ -685,6 +873,19 @@ impl CudaStorage { unsafe { func.launch(cfg, params) }? } } + (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { + let src = src.slice(src_offset..); + let mut dst = dst.slice_mut(dst_offset..); + if src_shape.is_contiguous(src_stride) { + dev.dtod_copy(&src, &mut dst)? + } else { + let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }? + } + } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { let src = src.slice(src_offset..); let mut dst = dst.slice_mut(dst_offset..); From de1f612645aa2c7a411fc6659bbfe0c6482766c1 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 20:56:13 +0100 Subject: [PATCH 03/12] Remove the default features from the CI as cuda is not available. --- .github/workflows/rust-ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index bfc68412..9e0ec89b 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -20,6 +20,7 @@ jobs: - uses: actions-rs/cargo@v1 with: command: check + args: --no-default-features test: name: Test Suite @@ -38,6 +39,7 @@ jobs: - uses: actions-rs/cargo@v1 with: command: test + args: --no-default-features fmt: name: Rustfmt @@ -69,4 +71,4 @@ jobs: - uses: actions-rs/cargo@v1 with: command: clippy - args: --tests --examples -- -D warnings + args: --no-default-features --tests --examples -- -D warnings From 7cfa4c307c00917a24b2e94796c2cb0196ab96d5 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 21:10:03 +0100 Subject: [PATCH 04/12] Handle f16/bf16 in npy. --- src/npy.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/npy.rs b/src/npy.rs index 3eb4e7c1..7e157c8f 100644 --- a/src/npy.rs +++ b/src/npy.rs @@ -27,6 +27,7 @@ //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; @@ -80,7 +81,7 @@ impl Header { .collect::>() .join(","); let descr = match self.descr { - DType::BF16 => todo!("bf16"), + DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?, DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", @@ -193,14 +194,19 @@ impl Header { } impl Tensor { + // TODO: Add the possibility to read directly to a device? 
fn from_reader(shape: Shape, dtype: DType, reader: &mut R) -> Result { let elem_count = shape.elem_count(); match dtype { DType::BF16 => { - todo!("bf16") + let mut data_t = vec![bf16::ZERO; elem_count]; + reader.read_u16_into::(data_t.reinterpret_cast_mut())?; + Tensor::from_vec(data_t, shape, &Device::Cpu) } DType::F16 => { - todo!("f16") + let mut data_t = vec![f16::ZERO; elem_count]; + reader.read_u16_into::(data_t.reinterpret_cast_mut())?; + Tensor::from_vec(data_t, shape, &Device::Cpu) } DType::F32 => { let mut data_t = vec![0f32; elem_count]; @@ -298,10 +304,16 @@ impl Tensor { let elem_count = self.elem_count(); match self.dtype() { DType::BF16 => { - todo!("bf16") + let vs = self.reshape(elem_count)?.to_vec1::()?; + for &v in vs.reinterpret_cast() { + f.write_u16::(v)? + } } DType::F16 => { - todo!("f16") + let vs = self.reshape(elem_count)?.to_vec1::()?; + for &v in vs.reinterpret_cast() { + f.write_u16::(v)? + } } DType::F32 => { // TODO: Avoid using a buffer when data is already on the CPU. From becb822ce098848d257b8fe4a0602a01505a2378 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 21:37:41 +0100 Subject: [PATCH 05/12] Support more types in the cpu matmul. --- src/cpu_backend.rs | 213 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 163 insertions(+), 50 deletions(-) diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 03e1e785..72599afc 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -618,63 +618,176 @@ impl CpuStorage { } } - let mut dst = vec![0.0; b * m * n]; - let dst_shape: Shape = (m, n).into(); let dst_strides = dst_shape.stride_contiguous(); let dst_rs = dst_strides[0]; let dst_cs = dst_strides[1]; - for step in 0..b { - let lhs_p = &self.as_slice::()?[step * a_skip..]; - let rhs_p = &rhs.as_slice::()?[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; - unsafe { - gemm( - // m: usize, - m, - // n: usize, - n, - // k: usize, - k, - // dst: *mut T, - dst_p.as_mut_ptr(), - // dst_cs: isize, - dst_cs as isize, - // dst_rs: isize, - dst_rs as isize, - // read_dst: bool, - false, - // lhs: *const T, - lhs_p.as_ptr(), - // lhs_cs: isize, - lhs_cs as isize, - // lhs_rs: isize, - lhs_rs as isize, - // rhs: *const T, - rhs_p.as_ptr(), - // rhs_cs: isize, - rhs_cs as isize, - // rhs_rs: isize, - rhs_rs as isize, - // alpha: T, - 1.0, - // beta: T, - 1.0, - // conj_dst: bool, - false, - // conj_lhs: bool, - false, - // conj_rhs: bool, - true, - // parallelism: Parallelism - Parallelism::None, - ) + match (self, rhs) { + (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => { + let mut dst = vec![f16::ZERO; b * m * n]; + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + // m: usize, + m, + // n: usize, + n, + // k: usize, + k, + // dst: *mut T, + dst_p.as_mut_ptr(), + // dst_cs: isize, + dst_cs as isize, + // dst_rs: isize, + dst_rs as isize, + // read_dst: bool, + false, + // lhs: *const T, + lhs_p.as_ptr(), + // lhs_cs: isize, + lhs_cs as isize, + // lhs_rs: isize, + lhs_rs as isize, + // rhs: *const T, + rhs_p.as_ptr(), + // rhs_cs: isize, + rhs_cs as isize, + // rhs_rs: isize, + rhs_rs as isize, + // alpha: T, + f16::ONE, + // beta: T, + f16::ONE, + // conj_dst: bool, + false, + // conj_lhs: bool, + false, + // conj_rhs: bool, + true, + // parallelism: Parallelism + Parallelism::None, + ) + } + } + + Ok(Self::F16(dst)) + } + (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { + let mut dst = vec![0f32; b * m * n]; + for step in 
0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + // m: usize, + m, + // n: usize, + n, + // k: usize, + k, + // dst: *mut T, + dst_p.as_mut_ptr(), + // dst_cs: isize, + dst_cs as isize, + // dst_rs: isize, + dst_rs as isize, + // read_dst: bool, + false, + // lhs: *const T, + lhs_p.as_ptr(), + // lhs_cs: isize, + lhs_cs as isize, + // lhs_rs: isize, + lhs_rs as isize, + // rhs: *const T, + rhs_p.as_ptr(), + // rhs_cs: isize, + rhs_cs as isize, + // rhs_rs: isize, + rhs_rs as isize, + // alpha: T, + 1f32, + // beta: T, + 1f32, + // conj_dst: bool, + false, + // conj_lhs: bool, + false, + // conj_rhs: bool, + true, + // parallelism: Parallelism + Parallelism::None, + ) + } + } + + Ok(Self::F32(dst)) + } + (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { + let mut dst = vec![0f64; b * m * n]; + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + // m: usize, + m, + // n: usize, + n, + // k: usize, + k, + // dst: *mut T, + dst_p.as_mut_ptr(), + // dst_cs: isize, + dst_cs as isize, + // dst_rs: isize, + dst_rs as isize, + // read_dst: bool, + false, + // lhs: *const T, + lhs_p.as_ptr(), + // lhs_cs: isize, + lhs_cs as isize, + // lhs_rs: isize, + lhs_rs as isize, + // rhs: *const T, + rhs_p.as_ptr(), + // rhs_cs: isize, + rhs_cs as isize, + // rhs_rs: isize, + rhs_rs as isize, + // alpha: T, + 1f64, + // beta: T, + 1f64, + // conj_dst: bool, + false, + // conj_lhs: bool, + false, + // conj_rhs: bool, + true, + // parallelism: Parallelism + Parallelism::None, + ) + } + } + Ok(Self::F64(dst)) + } + _ => { + // This should be covered by the dtype check above. + Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: rhs.dtype(), + op: "matmul", + }) } } - - let c = Self::F32(dst); - Ok(c) } pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { From d204f1c7c0ee328e511ca85bf15324c9a35171f9 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 21:58:15 +0100 Subject: [PATCH 06/12] Cuda support for embedding f16. --- kernels/src/embeddings.cu | 4 ++++ src/cuda_backend.rs | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/kernels/src/embeddings.cu b/kernels/src/embeddings.cu index 79bd85a4..1dd12cf1 100644 --- a/kernels/src/embeddings.cu +++ b/kernels/src/embeddings.cu @@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +#if __CUDA_ARCH__ >= 530 +EMB_OP(__half, emb_f16) +#endif + EMB_OP(float, emb_f32) EMB_OP(double, emb_f64) EMB_OP(uint32_t, emb_u32) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 90cb0f72..d5be8bf6 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -726,7 +726,7 @@ impl CudaStorage { let slice = match &rhs.slice { // The kernels below assume that rhs is contiguous. CudaStorageSlice::U32(arg) => { - let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; + let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); From 93e24f29f4984ad98a8cedb2a2ffa633a5a48ec6 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 22:01:29 +0100 Subject: [PATCH 07/12] Add the f16 sum kernel. 
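Like the embedding kernel in the previous commit, the f16 sum kernel is only emitted behind an `#if __CUDA_ARCH__ >= 530` guard: device-side `__half` arithmetic is only available from compute capability 5.3 onwards, so older architectures keep building only the f32/f64/u32 variants.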
--- kernels/src/reduce.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kernels/src/reduce.cu b/kernels/src/reduce.cu index d12d6b22..76dbbdc3 100644 --- a/kernels/src/reduce.cu +++ b/kernels/src/reduce.cu @@ -43,6 +43,10 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +#if __CUDA_ARCH__ >= 530 +SUM_OP(float, sum_f32) +#endif + SUM_OP(float, sum_f32) SUM_OP(double, sum_f64) SUM_OP(uint32_t, sum_u32) From 53fdbda683ea42947a51addffebbe369474159e8 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 22:02:22 +0100 Subject: [PATCH 08/12] Add the f16 sum kernel (fix). --- kernels/src/reduce.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/src/reduce.cu b/kernels/src/reduce.cu index 76dbbdc3..e1ed57ab 100644 --- a/kernels/src/reduce.cu +++ b/kernels/src/reduce.cu @@ -44,7 +44,7 @@ extern "C" __global__ void FN_NAME( \ } \ #if __CUDA_ARCH__ >= 530 -SUM_OP(float, sum_f32) +SUM_OP(__half, sum_f16) #endif SUM_OP(float, sum_f32) From 36a4749e952ca6511cc5066e637b86b1422ba63d Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 22:05:31 +0100 Subject: [PATCH 09/12] Add the f16 affine kernel. --- kernels/src/affine.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kernels/src/affine.cu b/kernels/src/affine.cu index 986c94a5..3ce5b8a7 100644 --- a/kernels/src/affine.cu +++ b/kernels/src/affine.cu @@ -28,6 +28,10 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +#if __CUDA_ARCH__ >= 530 +AFFINE_OP(__half, affine_f16) +#endif + AFFINE_OP(float, affine_f32) AFFINE_OP(double, affine_f64) AFFINE_OP(uint32_t, affine_u32) From a6a7477bea8721332b160d571894a480e7298376 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 22:08:22 +0100 Subject: [PATCH 10/12] Matmul cublas support for f16. --- src/cuda_backend.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index d5be8bf6..12790125 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -788,8 +788,15 @@ impl CudaStorage { (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => { todo!("bf16") } - (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => { - todo!("f16") + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?; + let mut out = unsafe { dev.alloc::(elem_count) }?; + unsafe { + self.device + .blas + .gemm_strided_batched(cfg, rhs, lhs, &mut out) + }?; + CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; From 4d19889acc0bc72d9680aff19602e0ea812c9a57 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 22:14:32 +0100 Subject: [PATCH 11/12] where_cond for f16. --- kernels/src/ternary.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kernels/src/ternary.cu b/kernels/src/ternary.cu index 8f8b3ac5..8f51526b 100644 --- a/kernels/src/ternary.cu +++ b/kernels/src/ternary.cu @@ -32,6 +32,10 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +#if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, where_f16) +#endif + WHERE_OP(float, where_f32) WHERE_OP(double, where_f64) WHERE_OP(uint32_t, where_u32) From e152c1273d1569f8fb748da37388d49dc68d6be0 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 05:56:19 +0100 Subject: [PATCH 12/12] Add more context for missing cuda kernels. 
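Wrap PTX-loading failures in a dedicated CudaError::Load variant so the error carries the name of the kernel module that could not be loaded, instead of only surfacing the raw cudarc driver error.

Roughly the pattern, as a self-contained sketch (the DriverError stand-in, the error code, and the simplified signatures below are made up for illustration; the real code wraps cudarc::driver::DriverError inside CudaDevice::get_or_load_func):

    use thiserror::Error;

    // Stand-in for cudarc::driver::DriverError so the sketch compiles on its own.
    #[derive(Debug, Error)]
    #[error("driver error {0}")]
    struct DriverError(u32);

    #[derive(Debug, Error)]
    enum CudaError {
        // Same shape as the variant added by this patch: the underlying
        // driver error plus the name of the module that failed to load.
        #[error("{cuda} when loading {module_name}")]
        Load {
            cuda: DriverError,
            module_name: &'static str,
        },
    }

    // Pretend PTX loader that always fails, standing in for load_ptx.
    fn load_ptx(_module_name: &'static str) -> Result<(), DriverError> {
        Err(DriverError(719))
    }

    fn get_or_load_func(module_name: &'static str) -> Result<(), CudaError> {
        // map_err attaches the module name, so the report says which kernel
        // was involved rather than only the bare driver failure.
        load_ptx(module_name).map_err(|cuda| CudaError::Load { cuda, module_name })
    }

    fn main() {
        // Prints: driver error 719 when loading sum_f16
        println!("{}", get_or_load_func("sum_f16").unwrap_err());
    }

With the real types this turns a bare driver error into a message that names the missing kernel (for instance one of the f16 kernels added earlier in this series), which is much easier to act on.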
---
 src/cuda_backend.rs | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 12790125..0739b6b3 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -39,6 +39,12 @@ pub enum CudaError {
         expected: DType,
         got: DType,
     },
+
+    #[error("{cuda} when loading {module_name}")]
+    Load {
+        cuda: cudarc::driver::DriverError,
+        module_name: &'static str,
+    },
 }
 
 type Result<T> = std::result::Result<T, CudaError>;
 
@@ -211,7 +217,8 @@ impl CudaDevice {
         ptx: &'static str,
     ) -> Result<CudaFunction> {
         if !self.has_func(module_name, module_name) {
-            self.load_ptx(ptx.into(), module_name, &[module_name])?;
+            self.load_ptx(ptx.into(), module_name, &[module_name])
+                .map_err(|cuda| CudaError::Load { cuda, module_name })?;
         }
         self.get_func(module_name, module_name)
         // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is