Start adding f16/bf16 support.

This commit is contained in:
laurent
2023-06-26 19:37:47 +01:00
parent 36a1a48ba0
commit a31411fd91
6 changed files with 169 additions and 203 deletions

View File

@ -18,11 +18,13 @@ members = [
[dependencies]
safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true }
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
candle-kernels = { path = "kernels", optional = true }
gemm = "0.15.4"
zip = { version = "0.6.6", default-features=false }
byteorder = "1.4.3"
half = "2.3.1"
num-traits = "0.2.15"
[dev-dependencies]
anyhow = "1"

View File

@ -1,6 +1,7 @@
use crate::op::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};
use half::{bf16, f16};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@ -9,6 +10,8 @@ use gemm::{gemm, Parallelism};
#[derive(Debug, Clone)]
pub enum CpuStorage {
U32(Vec<u32>),
BF16(Vec<bf16>),
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
}
@ -132,6 +135,8 @@ impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
Self::U32(_) => DType::U32,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
}
@ -545,6 +550,14 @@ impl CpuStorage {
let data = vec![1u32; elem_count];
Self::U32(data)
}
DType::BF16 => {
let data = vec![bf16::ONE; elem_count];
Self::BF16(data)
}
DType::F16 => {
let data = vec![f16::ONE; elem_count];
Self::F16(data)
}
DType::F32 => {
let data = vec![1f32; elem_count];
Self::F32(data)
@ -563,6 +576,14 @@ impl CpuStorage {
let data = vec![0u32; elem_count];
Self::U32(data)
}
DType::BF16 => {
let data = vec![bf16::ZERO; elem_count];
Self::BF16(data)
}
DType::F16 => {
let data = vec![f16::ZERO; elem_count];
Self::F16(data)
}
DType::F32 => {
let data = vec![0f32; elem_count];
Self::F32(data)

View File

@ -2,6 +2,7 @@ use crate::{CpuStorage, DType, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::Arc;
/// cudarc related errors
@ -97,6 +98,14 @@ impl CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
@ -190,6 +199,8 @@ impl CudaDevice {
#[derive(Debug)]
enum CudaStorageSlice {
U32(CudaSlice<u32>),
BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>),
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
@ -265,6 +276,8 @@ impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
let slice = match &self.slice {
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
};
@ -275,6 +288,8 @@ impl CudaStorage {
pub fn dtype(&self) -> DType {
match self.slice {
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
CudaStorageSlice::F64(_) => DType::F64,
}

View File

@ -3,6 +3,8 @@ use crate::{CpuStorage, Error, Result};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
U32,
BF16,
F16,
F32,
F64,
}
@ -11,6 +13,8 @@ impl DType {
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U32 => 4,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
@ -76,5 +80,7 @@ macro_rules! with_dtype {
};
}
with_dtype!(u32, U32);
with_dtype!(half::f16, F16);
with_dtype!(half::bf16, BF16);
with_dtype!(f32, F32);
with_dtype!(f64, F64);

View File

@ -80,6 +80,8 @@ impl Header {
.collect::<Vec<_>>()
.join(",");
let descr = match self.descr {
DType::BF16 => todo!("bf16"),
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::U32 => "u4",
@ -152,7 +154,7 @@ impl Header {
// int64, int32, int16, int8,
// uint8, and bool.
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
// "e" | "f2" => DType::F16,
"e" | "f2" => DType::F16,
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
@ -194,6 +196,12 @@ impl Tensor {
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
todo!("bf16")
}
DType::F16 => {
todo!("f16")
}
DType::F32 => {
let mut data_t = vec![0f32; elem_count];
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
@ -289,6 +297,12 @@ impl Tensor {
f.write_all(header.as_bytes())?;
let elem_count = self.elem_count();
match self.dtype() {
DType::BF16 => {
todo!("bf16")
}
DType::F16 => {
todo!("f16")
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {

310
src/op.rs
View File

@ -1,4 +1,6 @@
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
#[derive(Clone)]
pub(crate) enum Op {
@ -40,10 +42,13 @@ pub(crate) enum Op {
pub(crate) trait UnaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
fn u32(v1: u32) -> u32;
@ -51,11 +56,13 @@ pub(crate) trait UnaryOp {
pub(crate) trait BinaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn u32(v1: u32, v2: u32) -> u32;
@ -75,215 +82,116 @@ pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
impl BinaryOp for Add {
const NAME: &'static str = "add";
const KERNEL_F32: &'static str = "badd_f32";
const KERNEL_F64: &'static str = "badd_f64";
const KERNEL_U32: &'static str = "badd_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 + v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 + v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 + v2
}
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => {
impl BinaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
fn f16(v1: f16, v2: f16) -> f16 {
$e(v1, v2)
}
fn f32(v1: f32, v2: f32) -> f32 {
$e(v1, v2)
}
fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2)
}
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
}
};
}
impl BinaryOp for Sub {
const NAME: &'static str = "sub";
const KERNEL_F32: &'static str = "bsub_f32";
const KERNEL_F64: &'static str = "bsub_f64";
const KERNEL_U32: &'static str = "bsub_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 - v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 - v2
}
bin_op!(Add, "add", |v1, v2| v1 + v2);
bin_op!(Sub, "sub", |v1, v2| v1 - v2);
bin_op!(Mul, "mul", |v1, v2| v1 * v2);
bin_op!(Div, "div", |v1, v2| v1 / v2);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
fn bf16($a: bf16) -> bf16 {
$e
}
fn f16($a: f16) -> f16 {
$e
}
fn f32($a: f32) -> f32 {
$e
}
fn f64($a: f64) -> f64 {
$e
}
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
}
};
}
impl BinaryOp for Mul {
const NAME: &'static str = "mul";
const KERNEL_F32: &'static str = "bmul_f32";
const KERNEL_F64: &'static str = "bmul_f64";
const KERNEL_U32: &'static str = "bmul_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 * v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 * v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 * v2
}
}
impl BinaryOp for Div {
const NAME: &'static str = "div";
const KERNEL_F32: &'static str = "bdiv_f32";
const KERNEL_F64: &'static str = "bdiv_f64";
const KERNEL_U32: &'static str = "bdiv_u32";
fn f32(v1: f32, v2: f32) -> f32 {
v1 / v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
fn u32(v1: u32, v2: u32) -> u32 {
v1 / v2
}
}
impl UnaryOp for Exp {
const NAME: &'static str = "exp";
fn f32(v1: f32) -> f32 {
v1.exp()
}
fn f64(v1: f64) -> f64 {
v1.exp()
}
fn u32(v1: u32) -> u32 {
(v1 as f64).exp() as u32
}
const KERNEL_F32: &'static str = "uexp_f32";
const KERNEL_F64: &'static str = "uexp_f64";
}
impl UnaryOp for Log {
const NAME: &'static str = "log";
fn f32(v1: f32) -> f32 {
v1.ln()
}
fn f64(v1: f64) -> f64 {
v1.ln()
}
fn u32(v1: u32) -> u32 {
(v1 as f64).ln() as u32
}
const KERNEL_F32: &'static str = "ulog_f32";
const KERNEL_F64: &'static str = "ulog_f64";
}
impl UnaryOp for Sin {
const NAME: &'static str = "sin";
fn f32(v1: f32) -> f32 {
v1.sin()
}
fn f64(v1: f64) -> f64 {
v1.sin()
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "usin_f32";
const KERNEL_F64: &'static str = "usin_f64";
}
impl UnaryOp for Cos {
const NAME: &'static str = "cos";
fn f32(v1: f32) -> f32 {
v1.cos()
}
fn f64(v1: f64) -> f64 {
v1.cos()
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "ucos_f32";
const KERNEL_F64: &'static str = "ucos_f64";
}
impl UnaryOp for Abs {
const NAME: &'static str = "abs";
fn f32(v1: f32) -> f32 {
v1.abs()
}
fn f64(v1: f64) -> f64 {
v1.abs()
}
fn u32(v1: u32) -> u32 {
v1
}
const KERNEL_F32: &'static str = "uabs_f32";
const KERNEL_F64: &'static str = "uabs_f64";
}
impl UnaryOp for Neg {
const NAME: &'static str = "neg";
fn f32(v1: f32) -> f32 {
-v1
}
fn f64(v1: f64) -> f64 {
-v1
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_F32: &'static str = "uneg_f32";
const KERNEL_F64: &'static str = "uneg_f64";
}
impl UnaryOp for Sqr {
const NAME: &'static str = "sqr";
fn f32(v1: f32) -> f32 {
v1 * v1
}
fn f64(v1: f64) -> f64 {
v1 * v1
}
fn u32(v: u32) -> u32 {
v * v
}
const KERNEL_F32: &'static str = "usqr_f32";
const KERNEL_F64: &'static str = "usqr_f64";
}
impl UnaryOp for Sqrt {
const NAME: &'static str = "sqrt";
fn f32(v1: f32) -> f32 {
v1.sqrt()
}
fn f64(v1: f64) -> f64 {
v1.sqrt()
}
fn u32(v: u32) -> u32 {
(v as f64).sqrt() as u32
}
const KERNEL_F32: &'static str = "usqrt_f32";
const KERNEL_F64: &'static str = "usqrt_f64";
}
unary_op!(Exp, "exp", v, v.exp());
unary_op!(Log, "log", v, v.ln());
unary_op!(Sin, "sin", v, v.sin());
unary_op!(Cos, "cos", v, v.cos());
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Sqr, "sqr", v, v * v);
unary_op!(Sqrt, "sqrt", v, v.sqrt());
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f32(v: f32) -> f32 {
0.5 * v
* (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
#[inline]
pub fn gelu_f64(v: f64) -> f64 {
0.5 * v
* (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
fn f32(v1: f32) -> f32 {
gelu_f32(v1)
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
* (bf16::ONE
+ bf16::tanh(
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
}
fn f64(v1: f64) -> f64 {
gelu_f64(v1)
fn f16(v: f16) -> f16 {
f16::from_f32_const(0.5)
* v
* (f16::ONE
+ f16::tanh(
(f16::from_f32_const(2.0) / f16::PI).sqrt()
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
fn u32(v1: u32) -> u32 {
gelu_f64(v1 as f64) as u32
fn f32(v: f32) -> f32 {
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn f64(v: f64) -> f64 {
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_BF16: &'static str = "gelu_bf16";
const KERNEL_F16: &'static str = "gelu_f16";
const KERNEL_F32: &'static str = "gelu_f32";
const KERNEL_F64: &'static str = "gelu_f64";
const KERNEL_U32: &'static str = "gelu_u32";
}