mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Merge pull request #10 from LaurentMazare/f16
Add support for f16 and bf16
This commit is contained in:
4
.github/workflows/rust-ci.yml
vendored
4
.github/workflows/rust-ci.yml
vendored
@ -20,6 +20,7 @@ jobs:
|
|||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: check
|
command: check
|
||||||
|
args: --no-default-features
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test Suite
|
name: Test Suite
|
||||||
@ -38,6 +39,7 @@ jobs:
|
|||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: test
|
command: test
|
||||||
|
args: --no-default-features
|
||||||
|
|
||||||
fmt:
|
fmt:
|
||||||
name: Rustfmt
|
name: Rustfmt
|
||||||
@ -69,4 +71,4 @@ jobs:
|
|||||||
- uses: actions-rs/cargo@v1
|
- uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: clippy
|
command: clippy
|
||||||
args: --tests --examples -- -D warnings
|
args: --no-default-features --tests --examples -- -D warnings
|
||||||
|
@ -18,11 +18,13 @@ members = [
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
safetensors = "0.3.1"
|
safetensors = "0.3.1"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
cudarc = { version = "0.9.9", optional = true }
|
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
|
||||||
candle-kernels = { path = "kernels", optional = true }
|
candle-kernels = { path = "kernels", optional = true }
|
||||||
gemm = "0.15.4"
|
gemm = "0.15.4"
|
||||||
zip = { version = "0.6.6", default-features=false }
|
zip = { version = "0.6.6", default-features=false }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
num-traits = "0.2.15"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
@ -31,5 +33,5 @@ rand = "0.8.5"
|
|||||||
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = ["cuda"]
|
||||||
cuda = ["dep:cudarc", "dep:candle-kernels"]
|
cuda = ["dep:cudarc", "dep:candle-kernels"]
|
||||||
|
@ -28,6 +28,10 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
|
AFFINE_OP(__half, affine_f16)
|
||||||
|
#endif
|
||||||
|
|
||||||
AFFINE_OP(float, affine_f32)
|
AFFINE_OP(float, affine_f32)
|
||||||
AFFINE_OP(double, affine_f64)
|
AFFINE_OP(double, affine_f64)
|
||||||
AFFINE_OP(uint32_t, affine_u32)
|
AFFINE_OP(uint32_t, affine_u32)
|
||||||
|
@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
|
EMB_OP(__half, emb_f16)
|
||||||
|
#endif
|
||||||
|
|
||||||
EMB_OP(float, emb_f32)
|
EMB_OP(float, emb_f32)
|
||||||
EMB_OP(double, emb_f64)
|
EMB_OP(double, emb_f64)
|
||||||
EMB_OP(uint32_t, emb_u32)
|
EMB_OP(uint32_t, emb_u32)
|
||||||
|
@ -43,6 +43,10 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
|
SUM_OP(__half, sum_f16)
|
||||||
|
#endif
|
||||||
|
|
||||||
SUM_OP(float, sum_f32)
|
SUM_OP(float, sum_f32)
|
||||||
SUM_OP(double, sum_f64)
|
SUM_OP(double, sum_f64)
|
||||||
SUM_OP(uint32_t, sum_u32)
|
SUM_OP(uint32_t, sum_u32)
|
||||||
|
@ -32,6 +32,10 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
|
WHERE_OP(__half, where_f16)
|
||||||
|
#endif
|
||||||
|
|
||||||
WHERE_OP(float, where_f32)
|
WHERE_OP(float, where_f32)
|
||||||
WHERE_OP(double, where_f64)
|
WHERE_OP(double, where_f64)
|
||||||
WHERE_OP(uint32_t, where_u32)
|
WHERE_OP(uint32_t, where_u32)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use crate::op::{BinaryOp, UnaryOp};
|
use crate::op::{BinaryOp, UnaryOp};
|
||||||
use crate::{DType, Error, Result, Shape, StridedIndex};
|
use crate::{DType, Error, Result, Shape, StridedIndex};
|
||||||
use gemm::{gemm, Parallelism};
|
use gemm::{gemm, Parallelism};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
|
||||||
// TODO: Think about whether we would be better off with a dtype and
|
// TODO: Think about whether we would be better off with a dtype and
|
||||||
// a buffer as an owned slice of bytes.
|
// a buffer as an owned slice of bytes.
|
||||||
@ -9,6 +10,8 @@ use gemm::{gemm, Parallelism};
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum CpuStorage {
|
pub enum CpuStorage {
|
||||||
U32(Vec<u32>),
|
U32(Vec<u32>),
|
||||||
|
BF16(Vec<bf16>),
|
||||||
|
F16(Vec<f16>),
|
||||||
F32(Vec<f32>),
|
F32(Vec<f32>),
|
||||||
F64(Vec<f64>),
|
F64(Vec<f64>),
|
||||||
}
|
}
|
||||||
@ -43,10 +46,37 @@ fn wcond<T: Copy>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
macro_rules! map1 {
|
||||||
|
($v: expr, $fn: ident, $( $args:expr ),*) => {{
|
||||||
|
let v = match $v {
|
||||||
|
CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
|
||||||
|
CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
|
||||||
|
CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
|
||||||
|
CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
|
||||||
|
CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
|
||||||
|
};
|
||||||
|
Ok(v)
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_impl1<T: Copy + num_traits::NumAssign>(
|
||||||
|
src: &[T],
|
||||||
|
dst_shape: &Shape,
|
||||||
|
src_dims: &[usize],
|
||||||
|
stride: &[usize],
|
||||||
|
to_dst_index: impl Fn(usize) -> usize,
|
||||||
|
) -> Result<Vec<T>> {
|
||||||
|
let mut dst = vec![T::zero(); dst_shape.elem_count()];
|
||||||
|
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
|
||||||
|
dst[to_dst_index(unstr_index)] += src[src_index];
|
||||||
|
}
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
|
||||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||||
|
vs: &[T],
|
||||||
shape: &Shape,
|
shape: &Shape,
|
||||||
stride: &[usize],
|
stride: &[usize],
|
||||||
vs: &[T],
|
|
||||||
mut f: F,
|
mut f: F,
|
||||||
) -> Vec<U> {
|
) -> Vec<U> {
|
||||||
if shape.is_contiguous(stride) {
|
if shape.is_contiguous(stride) {
|
||||||
@ -80,11 +110,11 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn take<T: Copy>(
|
fn take_impl1<T: Copy>(
|
||||||
|
vs: &[T],
|
||||||
ids: &[u32],
|
ids: &[u32],
|
||||||
shape: &Shape,
|
shape: &Shape,
|
||||||
stride: &[usize],
|
stride: &[usize],
|
||||||
vs: &[T],
|
|
||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
) -> Result<Vec<T>> {
|
) -> Result<Vec<T>> {
|
||||||
@ -132,6 +162,8 @@ impl CpuStorage {
|
|||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
Self::U32(_) => DType::U32,
|
Self::U32(_) => DType::U32,
|
||||||
|
Self::BF16(_) => DType::BF16,
|
||||||
|
Self::F16(_) => DType::F16,
|
||||||
Self::F32(_) => DType::F32,
|
Self::F32(_) => DType::F32,
|
||||||
Self::F64(_) => DType::F64,
|
Self::F64(_) => DType::F64,
|
||||||
}
|
}
|
||||||
@ -148,40 +180,104 @@ impl CpuStorage {
|
|||||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||||
// TODO: find a way around the quadratic number of cases below.
|
// TODO: find a way around the quadratic number of cases below.
|
||||||
match (self, dtype) {
|
match (self, dtype) {
|
||||||
|
(Self::U32(storage), DType::BF16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::BF16(storage), DType::BF16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::F16(storage), DType::BF16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::F32(storage), DType::BF16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, bf16::from_f32);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::F64(storage), DType::BF16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, bf16::from_f64);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::U32(storage), DType::F16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
|
(Self::BF16(storage), DType::F16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
|
(Self::F16(storage), DType::F16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
|
(Self::F32(storage), DType::F16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, f16::from_f32);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
|
(Self::F64(storage), DType::F16) => {
|
||||||
|
let data = unary_map(storage, shape, stride, f16::from_f64);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
(Self::U32(storage), DType::F32) => {
|
(Self::U32(storage), DType::F32) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v as f32);
|
let data = unary_map(storage, shape, stride, |v| v as f32);
|
||||||
|
Ok(Self::F32(data))
|
||||||
|
}
|
||||||
|
(Self::BF16(storage), DType::F32) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v.to_f32());
|
||||||
|
Ok(Self::F32(data))
|
||||||
|
}
|
||||||
|
(Self::F16(storage), DType::F32) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v.to_f32());
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
(Self::F32(storage), DType::F32) => {
|
(Self::F32(storage), DType::F32) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v);
|
let data = unary_map(storage, shape, stride, |v| v);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
(Self::F64(storage), DType::F32) => {
|
(Self::F64(storage), DType::F32) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v as f32);
|
let data = unary_map(storage, shape, stride, |v| v as f32);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
(Self::U32(storage), DType::U32) => {
|
(Self::U32(storage), DType::U32) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v);
|
let data = unary_map(storage, shape, stride, |v| v);
|
||||||
|
Ok(Self::U32(data))
|
||||||
|
}
|
||||||
|
(Self::BF16(storage), DType::U32) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
|
||||||
|
Ok(Self::U32(data))
|
||||||
|
}
|
||||||
|
(Self::F16(storage), DType::U32) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
(Self::F32(storage), DType::U32) => {
|
(Self::F32(storage), DType::U32) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v as u32);
|
let data = unary_map(storage, shape, stride, |v| v as u32);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
(Self::F64(storage), DType::U32) => {
|
(Self::F64(storage), DType::U32) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v as u32);
|
let data = unary_map(storage, shape, stride, |v| v as u32);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
(Self::U32(storage), DType::F64) => {
|
(Self::U32(storage), DType::F64) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v as f64);
|
let data = unary_map(storage, shape, stride, |v| v as f64);
|
||||||
|
Ok(Self::F64(data))
|
||||||
|
}
|
||||||
|
(Self::BF16(storage), DType::F64) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v.to_f64());
|
||||||
|
Ok(Self::F64(data))
|
||||||
|
}
|
||||||
|
(Self::F16(storage), DType::F64) => {
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v.to_f64());
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
(Self::F32(storage), DType::F64) => {
|
(Self::F32(storage), DType::F64) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v as f64);
|
let data = unary_map(storage, shape, stride, |v| v as f64);
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
(Self::F64(storage), DType::F64) => {
|
(Self::F64(storage), DType::F64) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v);
|
let data = unary_map(storage, shape, stride, |v| v);
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,29 +310,7 @@ impl CpuStorage {
|
|||||||
dst_index
|
dst_index
|
||||||
};
|
};
|
||||||
// TODO: Maybe provide an implementation with higher precision accumulators?
|
// TODO: Maybe provide an implementation with higher precision accumulators?
|
||||||
match self {
|
map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
|
||||||
Self::F32(src) => {
|
|
||||||
let mut dst = vec![0f32; dst_shape.elem_count()];
|
|
||||||
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
|
|
||||||
dst[to_dst_index(unstr_index)] += src[src_index];
|
|
||||||
}
|
|
||||||
Ok(Self::F32(dst))
|
|
||||||
}
|
|
||||||
Self::F64(src) => {
|
|
||||||
let mut dst = vec![0f64; dst_shape.elem_count()];
|
|
||||||
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
|
|
||||||
dst[to_dst_index(unstr_index)] += src[src_index];
|
|
||||||
}
|
|
||||||
Ok(Self::F64(dst))
|
|
||||||
}
|
|
||||||
Self::U32(src) => {
|
|
||||||
let mut dst = vec![0u32; dst_shape.elem_count()];
|
|
||||||
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
|
|
||||||
dst[to_dst_index(unstr_index)] += src[src_index];
|
|
||||||
}
|
|
||||||
Ok(Self::U32(dst))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||||
@ -246,6 +320,42 @@ impl CpuStorage {
|
|||||||
let prod_pre_dim = dims[..dim].iter().product();
|
let prod_pre_dim = dims[..dim].iter().product();
|
||||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||||
match self {
|
match self {
|
||||||
|
Self::BF16(storage) => {
|
||||||
|
for pre_idx in 0..prod_pre_dim {
|
||||||
|
for post_idx in 0..prod_post_dim {
|
||||||
|
let mut sum = 0f64;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
sum += storage[idx].to_f64();
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
let sum = bf16::from_f64(sum);
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
storage[idx] /= sum;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self::F16(storage) => {
|
||||||
|
for pre_idx in 0..prod_pre_dim {
|
||||||
|
for post_idx in 0..prod_post_dim {
|
||||||
|
let mut sum = 0f64;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
sum += storage[idx].to_f64();
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
let sum = f16::from_f64(sum);
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
storage[idx] /= sum;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Self::F32(storage) => {
|
Self::F32(storage) => {
|
||||||
for pre_idx in 0..prod_pre_dim {
|
for pre_idx in 0..prod_pre_dim {
|
||||||
for post_idx in 0..prod_post_dim {
|
for post_idx in 0..prod_post_dim {
|
||||||
@ -297,17 +407,29 @@ impl CpuStorage {
|
|||||||
Self::U32(storage) => {
|
Self::U32(storage) => {
|
||||||
let mul = mul as u32;
|
let mul = mul as u32;
|
||||||
let add = add as u32;
|
let add = add as u32;
|
||||||
let data = unary_map(shape, stride, storage, |v| v * mul + add);
|
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
|
Self::BF16(storage) => {
|
||||||
|
let mul = bf16::from_f64(mul);
|
||||||
|
let add = bf16::from_f64(add);
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
Self::F16(storage) => {
|
||||||
|
let mul = f16::from_f64(mul);
|
||||||
|
let add = f16::from_f64(add);
|
||||||
|
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
Self::F32(storage) => {
|
Self::F32(storage) => {
|
||||||
let mul = mul as f32;
|
let mul = mul as f32;
|
||||||
let add = add as f32;
|
let add = add as f32;
|
||||||
let data = unary_map(shape, stride, storage, |v| v * mul + add);
|
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
Self::F64(storage) => {
|
Self::F64(storage) => {
|
||||||
let data = unary_map(shape, stride, storage, |v| v * mul + add);
|
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -315,16 +437,24 @@ impl CpuStorage {
|
|||||||
|
|
||||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
|
Self::BF16(storage) => {
|
||||||
|
let data = unary_map(storage, shape, stride, B::bf16);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
Self::F16(storage) => {
|
||||||
|
let data = unary_map(storage, shape, stride, B::f16);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
Self::F32(storage) => {
|
Self::F32(storage) => {
|
||||||
let data = unary_map(shape, stride, storage, B::f32);
|
let data = unary_map(storage, shape, stride, B::f32);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
Self::F64(storage) => {
|
Self::F64(storage) => {
|
||||||
let data = unary_map(shape, stride, storage, B::f64);
|
let data = unary_map(storage, shape, stride, B::f64);
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
Self::U32(storage) => {
|
Self::U32(storage) => {
|
||||||
let data = unary_map(shape, stride, storage, B::u32);
|
let data = unary_map(storage, shape, stride, B::u32);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -338,6 +468,14 @@ impl CpuStorage {
|
|||||||
rhs_stride: &[usize],
|
rhs_stride: &[usize],
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
|
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
||||||
|
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::F16(lhs), Self::F16(rhs)) => {
|
||||||
|
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
(Self::F32(lhs), Self::F32(rhs)) => {
|
||||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
|
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
@ -376,6 +514,12 @@ impl CpuStorage {
|
|||||||
(Self::U32(src), Self::U32(dst)) => {
|
(Self::U32(src), Self::U32(dst)) => {
|
||||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||||
}
|
}
|
||||||
|
(Self::BF16(src), Self::BF16(dst)) => {
|
||||||
|
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||||
|
}
|
||||||
|
(Self::F16(src), Self::F16(dst)) => {
|
||||||
|
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||||
|
}
|
||||||
(Self::F32(src), Self::F32(dst)) => {
|
(Self::F32(src), Self::F32(dst)) => {
|
||||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||||
}
|
}
|
||||||
@ -406,6 +550,14 @@ impl CpuStorage {
|
|||||||
// TODO: Support types that could be casted to a boolean.
|
// TODO: Support types that could be casted to a boolean.
|
||||||
let pred = self.as_slice::<u32>()?;
|
let pred = self.as_slice::<u32>()?;
|
||||||
match (t, f) {
|
match (t, f) {
|
||||||
|
(Self::BF16(t), Self::BF16(f)) => {
|
||||||
|
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||||
|
Ok(Self::BF16(data))
|
||||||
|
}
|
||||||
|
(Self::F16(t), Self::F16(f)) => {
|
||||||
|
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||||
|
Ok(Self::F16(data))
|
||||||
|
}
|
||||||
(Self::F32(t), Self::F32(f)) => {
|
(Self::F32(t), Self::F32(f)) => {
|
||||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
@ -435,20 +587,7 @@ impl CpuStorage {
|
|||||||
vocab_size: usize,
|
vocab_size: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ids = self.as_slice::<u32>()?;
|
let ids = self.as_slice::<u32>()?;
|
||||||
match vs {
|
map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
|
||||||
CpuStorage::F32(vs) => {
|
|
||||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
|
||||||
Ok(CpuStorage::F32(storage))
|
|
||||||
}
|
|
||||||
CpuStorage::F64(vs) => {
|
|
||||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
|
||||||
Ok(CpuStorage::F64(storage))
|
|
||||||
}
|
|
||||||
CpuStorage::U32(vs) => {
|
|
||||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
|
||||||
Ok(CpuStorage::U32(storage))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_impl(
|
pub(crate) fn matmul_impl(
|
||||||
@ -479,16 +618,17 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut dst = vec![0.0; b * m * n];
|
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
let dst_shape: Shape = (m, n).into();
|
||||||
let dst_strides = dst_shape.stride_contiguous();
|
let dst_strides = dst_shape.stride_contiguous();
|
||||||
let dst_rs = dst_strides[0];
|
let dst_rs = dst_strides[0];
|
||||||
let dst_cs = dst_strides[1];
|
let dst_cs = dst_strides[1];
|
||||||
|
|
||||||
|
match (self, rhs) {
|
||||||
|
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||||
|
let mut dst = vec![f16::ZERO; b * m * n];
|
||||||
for step in 0..b {
|
for step in 0..b {
|
||||||
let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
unsafe {
|
unsafe {
|
||||||
gemm(
|
gemm(
|
||||||
@ -519,9 +659,9 @@ impl CpuStorage {
|
|||||||
// rhs_rs: isize,
|
// rhs_rs: isize,
|
||||||
rhs_rs as isize,
|
rhs_rs as isize,
|
||||||
// alpha: T,
|
// alpha: T,
|
||||||
1.0,
|
f16::ONE,
|
||||||
// beta: T,
|
// beta: T,
|
||||||
1.0,
|
f16::ONE,
|
||||||
// conj_dst: bool,
|
// conj_dst: bool,
|
||||||
false,
|
false,
|
||||||
// conj_lhs: bool,
|
// conj_lhs: bool,
|
||||||
@ -534,8 +674,120 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let c = Self::F32(dst);
|
Ok(Self::F16(dst))
|
||||||
Ok(c)
|
}
|
||||||
|
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||||
|
let mut dst = vec![0f32; b * m * n];
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
|
unsafe {
|
||||||
|
gemm(
|
||||||
|
// m: usize,
|
||||||
|
m,
|
||||||
|
// n: usize,
|
||||||
|
n,
|
||||||
|
// k: usize,
|
||||||
|
k,
|
||||||
|
// dst: *mut T,
|
||||||
|
dst_p.as_mut_ptr(),
|
||||||
|
// dst_cs: isize,
|
||||||
|
dst_cs as isize,
|
||||||
|
// dst_rs: isize,
|
||||||
|
dst_rs as isize,
|
||||||
|
// read_dst: bool,
|
||||||
|
false,
|
||||||
|
// lhs: *const T,
|
||||||
|
lhs_p.as_ptr(),
|
||||||
|
// lhs_cs: isize,
|
||||||
|
lhs_cs as isize,
|
||||||
|
// lhs_rs: isize,
|
||||||
|
lhs_rs as isize,
|
||||||
|
// rhs: *const T,
|
||||||
|
rhs_p.as_ptr(),
|
||||||
|
// rhs_cs: isize,
|
||||||
|
rhs_cs as isize,
|
||||||
|
// rhs_rs: isize,
|
||||||
|
rhs_rs as isize,
|
||||||
|
// alpha: T,
|
||||||
|
1f32,
|
||||||
|
// beta: T,
|
||||||
|
1f32,
|
||||||
|
// conj_dst: bool,
|
||||||
|
false,
|
||||||
|
// conj_lhs: bool,
|
||||||
|
false,
|
||||||
|
// conj_rhs: bool,
|
||||||
|
true,
|
||||||
|
// parallelism: Parallelism
|
||||||
|
Parallelism::None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self::F32(dst))
|
||||||
|
}
|
||||||
|
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||||
|
let mut dst = vec![0f64; b * m * n];
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
|
unsafe {
|
||||||
|
gemm(
|
||||||
|
// m: usize,
|
||||||
|
m,
|
||||||
|
// n: usize,
|
||||||
|
n,
|
||||||
|
// k: usize,
|
||||||
|
k,
|
||||||
|
// dst: *mut T,
|
||||||
|
dst_p.as_mut_ptr(),
|
||||||
|
// dst_cs: isize,
|
||||||
|
dst_cs as isize,
|
||||||
|
// dst_rs: isize,
|
||||||
|
dst_rs as isize,
|
||||||
|
// read_dst: bool,
|
||||||
|
false,
|
||||||
|
// lhs: *const T,
|
||||||
|
lhs_p.as_ptr(),
|
||||||
|
// lhs_cs: isize,
|
||||||
|
lhs_cs as isize,
|
||||||
|
// lhs_rs: isize,
|
||||||
|
lhs_rs as isize,
|
||||||
|
// rhs: *const T,
|
||||||
|
rhs_p.as_ptr(),
|
||||||
|
// rhs_cs: isize,
|
||||||
|
rhs_cs as isize,
|
||||||
|
// rhs_rs: isize,
|
||||||
|
rhs_rs as isize,
|
||||||
|
// alpha: T,
|
||||||
|
1f64,
|
||||||
|
// beta: T,
|
||||||
|
1f64,
|
||||||
|
// conj_dst: bool,
|
||||||
|
false,
|
||||||
|
// conj_lhs: bool,
|
||||||
|
false,
|
||||||
|
// conj_rhs: bool,
|
||||||
|
true,
|
||||||
|
// parallelism: Parallelism
|
||||||
|
Parallelism::None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Self::F64(dst))
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// This should be covered by the dtype check above.
|
||||||
|
Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: self.dtype(),
|
||||||
|
rhs: rhs.dtype(),
|
||||||
|
op: "matmul",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||||
@ -545,6 +797,14 @@ impl CpuStorage {
|
|||||||
let data = vec![1u32; elem_count];
|
let data = vec![1u32; elem_count];
|
||||||
Self::U32(data)
|
Self::U32(data)
|
||||||
}
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let data = vec![bf16::ONE; elem_count];
|
||||||
|
Self::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let data = vec![f16::ONE; elem_count];
|
||||||
|
Self::F16(data)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = vec![1f32; elem_count];
|
let data = vec![1f32; elem_count];
|
||||||
Self::F32(data)
|
Self::F32(data)
|
||||||
@ -563,6 +823,14 @@ impl CpuStorage {
|
|||||||
let data = vec![0u32; elem_count];
|
let data = vec![0u32; elem_count];
|
||||||
Self::U32(data)
|
Self::U32(data)
|
||||||
}
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let data = vec![bf16::ZERO; elem_count];
|
||||||
|
Self::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let data = vec![f16::ZERO; elem_count];
|
||||||
|
Self::F16(data)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = vec![0f32; elem_count];
|
let data = vec![0f32; elem_count];
|
||||||
Self::F32(data)
|
Self::F32(data)
|
||||||
|
@ -2,6 +2,7 @@ use crate::{CpuStorage, DType, Shape};
|
|||||||
use candle_kernels as kernels;
|
use candle_kernels as kernels;
|
||||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||||
|
use half::{bf16, f16};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// cudarc related errors
|
/// cudarc related errors
|
||||||
@ -38,6 +39,12 @@ pub enum CudaError {
|
|||||||
expected: DType,
|
expected: DType,
|
||||||
got: DType,
|
got: DType,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
#[error("{cuda} when loading {module_name}")]
|
||||||
|
Load {
|
||||||
|
cuda: cudarc::driver::DriverError,
|
||||||
|
module_name: &'static str,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, CudaError>;
|
type Result<T> = std::result::Result<T, CudaError>;
|
||||||
@ -97,6 +104,14 @@ impl CudaDevice {
|
|||||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let data = self.alloc_zeros::<bf16>(elem_count)?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let data = self.alloc_zeros::<f16>(elem_count)?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
@ -124,6 +139,22 @@ impl CudaDevice {
|
|||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<bf16>(elem_count) }?;
|
||||||
|
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||||
|
let params = (&data, bf16::from_f64(v), elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f16>(elem_count) }?;
|
||||||
|
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||||
|
let params = (&data, f16::from_f64(v), elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||||
@ -157,6 +188,14 @@ impl CudaDevice {
|
|||||||
let data = self.htod_sync_copy(storage)?;
|
let data = self.htod_sync_copy(storage)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
|
CpuStorage::BF16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage)?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage)?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage)?;
|
let data = self.htod_sync_copy(storage)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
@ -178,7 +217,8 @@ impl CudaDevice {
|
|||||||
ptx: &'static str,
|
ptx: &'static str,
|
||||||
) -> Result<CudaFunction> {
|
) -> Result<CudaFunction> {
|
||||||
if !self.has_func(module_name, module_name) {
|
if !self.has_func(module_name, module_name) {
|
||||||
self.load_ptx(ptx.into(), module_name, &[module_name])?;
|
self.load_ptx(ptx.into(), module_name, &[module_name])
|
||||||
|
.map_err(|cuda| CudaError::Load { cuda, module_name })?;
|
||||||
}
|
}
|
||||||
self.get_func(module_name, module_name)
|
self.get_func(module_name, module_name)
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||||
@ -190,6 +230,8 @@ impl CudaDevice {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum CudaStorageSlice {
|
enum CudaStorageSlice {
|
||||||
U32(CudaSlice<u32>),
|
U32(CudaSlice<u32>),
|
||||||
|
BF16(CudaSlice<bf16>),
|
||||||
|
F16(CudaSlice<f16>),
|
||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
@ -265,6 +307,8 @@ impl CudaStorage {
|
|||||||
pub fn try_clone(&self) -> Result<Self> {
|
pub fn try_clone(&self) -> Result<Self> {
|
||||||
let slice = match &self.slice {
|
let slice = match &self.slice {
|
||||||
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
||||||
|
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
|
||||||
|
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
|
||||||
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
||||||
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
||||||
};
|
};
|
||||||
@ -275,6 +319,8 @@ impl CudaStorage {
|
|||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self.slice {
|
match self.slice {
|
||||||
CudaStorageSlice::U32(_) => DType::U32,
|
CudaStorageSlice::U32(_) => DType::U32,
|
||||||
|
CudaStorageSlice::BF16(_) => DType::BF16,
|
||||||
|
CudaStorageSlice::F16(_) => DType::F16,
|
||||||
CudaStorageSlice::F32(_) => DType::F32,
|
CudaStorageSlice::F32(_) => DType::F32,
|
||||||
CudaStorageSlice::F64(_) => DType::F64,
|
CudaStorageSlice::F64(_) => DType::F64,
|
||||||
}
|
}
|
||||||
@ -310,6 +356,40 @@ impl CudaStorage {
|
|||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
CudaStorageSlice::U32(out)
|
CudaStorageSlice::U32(out)
|
||||||
}
|
}
|
||||||
|
CudaStorageSlice::BF16(arg) => {
|
||||||
|
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
||||||
|
let params = (
|
||||||
|
el_count,
|
||||||
|
dims.len(),
|
||||||
|
&ds,
|
||||||
|
arg,
|
||||||
|
&out,
|
||||||
|
bf16::from_f64(mul),
|
||||||
|
bf16::from_f64(add),
|
||||||
|
);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F16(arg) => {
|
||||||
|
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
||||||
|
let params = (
|
||||||
|
el_count,
|
||||||
|
dims.len(),
|
||||||
|
&ds,
|
||||||
|
arg,
|
||||||
|
&out,
|
||||||
|
f16::from_f64(mul),
|
||||||
|
f16::from_f64(add),
|
||||||
|
);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -361,6 +441,22 @@ impl CudaStorage {
|
|||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
CudaStorageSlice::U32(out)
|
CudaStorageSlice::U32(out)
|
||||||
}
|
}
|
||||||
|
CudaStorageSlice::BF16(arg) => {
|
||||||
|
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<bf16>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F16(arg) => {
|
||||||
|
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<f16>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
||||||
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
||||||
@ -402,6 +498,24 @@ impl CudaStorage {
|
|||||||
CudaStorageSlice::U32(_arg) => {
|
CudaStorageSlice::U32(_arg) => {
|
||||||
todo!("No unary kernels for u32");
|
todo!("No unary kernels for u32");
|
||||||
}
|
}
|
||||||
|
CudaStorageSlice::BF16(arg) => {
|
||||||
|
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
||||||
|
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F16(arg) => {
|
||||||
|
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
||||||
|
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -438,6 +552,24 @@ impl CudaStorage {
|
|||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
|
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
|
||||||
let slice = match (&self.slice, &rhs.slice) {
|
let slice = match (&self.slice, &rhs.slice) {
|
||||||
|
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
||||||
|
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||||
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||||
|
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||||
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||||
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -479,6 +611,16 @@ impl CudaStorage {
|
|||||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
Ok(CpuStorage::U32(cpu_storage))
|
Ok(CpuStorage::U32(cpu_storage))
|
||||||
}
|
}
|
||||||
|
CudaStorageSlice::BF16(slice) => {
|
||||||
|
let dev = slice.device();
|
||||||
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
|
Ok(CpuStorage::BF16(cpu_storage))
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F16(slice) => {
|
||||||
|
let dev = slice.device();
|
||||||
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
|
Ok(CpuStorage::F16(cpu_storage))
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(slice) => {
|
CudaStorageSlice::F32(slice) => {
|
||||||
let dev = slice.device();
|
let dev = slice.device();
|
||||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
@ -515,6 +657,24 @@ impl CudaStorage {
|
|||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
|
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
|
||||||
let slice = match (&t.slice, &f.slice) {
|
let slice = match (&t.slice, &f.slice) {
|
||||||
|
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
|
||||||
|
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
|
||||||
|
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<f16>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
||||||
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -573,7 +733,7 @@ impl CudaStorage {
|
|||||||
let slice = match &rhs.slice {
|
let slice = match &rhs.slice {
|
||||||
// The kernels below assume that rhs is contiguous.
|
// The kernels below assume that rhs is contiguous.
|
||||||
CudaStorageSlice::U32(arg) => {
|
CudaStorageSlice::U32(arg) => {
|
||||||
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
|
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
|
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||||
@ -581,6 +741,24 @@ impl CudaStorage {
|
|||||||
unsafe { func.launch(cfg, params) }?;
|
unsafe { func.launch(cfg, params) }?;
|
||||||
CudaStorageSlice::U32(out)
|
CudaStorageSlice::U32(out)
|
||||||
}
|
}
|
||||||
|
CudaStorageSlice::BF16(arg) => {
|
||||||
|
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
CudaStorageSlice::F16(arg) => {
|
||||||
|
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
|
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -614,6 +792,19 @@ impl CudaStorage {
|
|||||||
let elem_count = b * m * n;
|
let elem_count = b * m * n;
|
||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let slice = match (&self.slice, &rhs.slice) {
|
let slice = match (&self.slice, &rhs.slice) {
|
||||||
|
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
|
||||||
|
todo!("bf16")
|
||||||
|
}
|
||||||
|
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||||
|
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||||
|
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||||
|
unsafe {
|
||||||
|
self.device
|
||||||
|
.blas
|
||||||
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||||
|
}?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||||
@ -657,6 +848,32 @@ impl CudaStorage {
|
|||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
||||||
match (&self.slice, &mut dst.slice) {
|
match (&self.slice, &mut dst.slice) {
|
||||||
|
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||||
|
let src = src.slice(src_offset..);
|
||||||
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
|
if src_shape.is_contiguous(src_stride) {
|
||||||
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
|
} else {
|
||||||
|
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||||
|
let src = src.slice(src_offset..);
|
||||||
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
|
if src_shape.is_contiguous(src_stride) {
|
||||||
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
|
} else {
|
||||||
|
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?
|
||||||
|
}
|
||||||
|
}
|
||||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let src = src.slice(src_offset..);
|
||||||
let mut dst = dst.slice_mut(dst_offset..);
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
@ -670,6 +887,19 @@ impl CudaStorage {
|
|||||||
unsafe { func.launch(cfg, params) }?
|
unsafe { func.launch(cfg, params) }?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||||
|
let src = src.slice(src_offset..);
|
||||||
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
|
if src_shape.is_contiguous(src_stride) {
|
||||||
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
|
} else {
|
||||||
|
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?
|
||||||
|
}
|
||||||
|
}
|
||||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||||
let src = src.slice(src_offset..);
|
let src = src.slice(src_offset..);
|
||||||
let mut dst = dst.slice_mut(dst_offset..);
|
let mut dst = dst.slice_mut(dst_offset..);
|
||||||
|
@ -3,6 +3,8 @@ use crate::{CpuStorage, Error, Result};
|
|||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum DType {
|
pub enum DType {
|
||||||
U32,
|
U32,
|
||||||
|
BF16,
|
||||||
|
F16,
|
||||||
F32,
|
F32,
|
||||||
F64,
|
F64,
|
||||||
}
|
}
|
||||||
@ -11,6 +13,8 @@ impl DType {
|
|||||||
pub fn size_in_bytes(&self) -> usize {
|
pub fn size_in_bytes(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
Self::U32 => 4,
|
Self::U32 => 4,
|
||||||
|
Self::BF16 => 2,
|
||||||
|
Self::F16 => 2,
|
||||||
Self::F32 => 4,
|
Self::F32 => 4,
|
||||||
Self::F64 => 8,
|
Self::F64 => 8,
|
||||||
}
|
}
|
||||||
@ -76,5 +80,7 @@ macro_rules! with_dtype {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
with_dtype!(u32, U32);
|
with_dtype!(u32, U32);
|
||||||
|
with_dtype!(half::f16, F16);
|
||||||
|
with_dtype!(half::bf16, BF16);
|
||||||
with_dtype!(f32, F32);
|
with_dtype!(f32, F32);
|
||||||
with_dtype!(f64, F64);
|
with_dtype!(f64, F64);
|
||||||
|
28
src/npy.rs
28
src/npy.rs
@ -27,6 +27,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||||
|
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{BufReader, Read, Write};
|
use std::io::{BufReader, Read, Write};
|
||||||
@ -80,6 +81,8 @@ impl Header {
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(",");
|
.join(",");
|
||||||
let descr = match self.descr {
|
let descr = match self.descr {
|
||||||
|
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
|
||||||
|
DType::F16 => "f2",
|
||||||
DType::F32 => "f4",
|
DType::F32 => "f4",
|
||||||
DType::F64 => "f8",
|
DType::F64 => "f8",
|
||||||
DType::U32 => "u4",
|
DType::U32 => "u4",
|
||||||
@ -152,7 +155,7 @@ impl Header {
|
|||||||
// int64, int32, int16, int8,
|
// int64, int32, int16, int8,
|
||||||
// uint8, and bool.
|
// uint8, and bool.
|
||||||
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
|
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
|
||||||
// "e" | "f2" => DType::F16,
|
"e" | "f2" => DType::F16,
|
||||||
"f" | "f4" => DType::F32,
|
"f" | "f4" => DType::F32,
|
||||||
"d" | "f8" => DType::F64,
|
"d" | "f8" => DType::F64,
|
||||||
// "i" | "i4" => DType::S32,
|
// "i" | "i4" => DType::S32,
|
||||||
@ -191,9 +194,20 @@ impl Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
|
// TODO: Add the possibility to read directly to a device?
|
||||||
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
match dtype {
|
match dtype {
|
||||||
|
DType::BF16 => {
|
||||||
|
let mut data_t = vec![bf16::ZERO; elem_count];
|
||||||
|
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
|
||||||
|
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let mut data_t = vec![f16::ZERO; elem_count];
|
||||||
|
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
|
||||||
|
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let mut data_t = vec![0f32; elem_count];
|
let mut data_t = vec![0f32; elem_count];
|
||||||
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
|
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
|
||||||
@ -289,6 +303,18 @@ impl Tensor {
|
|||||||
f.write_all(header.as_bytes())?;
|
f.write_all(header.as_bytes())?;
|
||||||
let elem_count = self.elem_count();
|
let elem_count = self.elem_count();
|
||||||
match self.dtype() {
|
match self.dtype() {
|
||||||
|
DType::BF16 => {
|
||||||
|
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
|
||||||
|
for &v in vs.reinterpret_cast() {
|
||||||
|
f.write_u16::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
|
||||||
|
for &v in vs.reinterpret_cast() {
|
||||||
|
f.write_u16::<LittleEndian>(v)?
|
||||||
|
}
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||||
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
|
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
|
||||||
|
284
src/op.rs
284
src/op.rs
@ -1,4 +1,6 @@
|
|||||||
use crate::Tensor;
|
use crate::Tensor;
|
||||||
|
use half::{bf16, f16};
|
||||||
|
use num_traits::float::Float;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) enum Op {
|
pub(crate) enum Op {
|
||||||
@ -40,10 +42,13 @@ pub(crate) enum Op {
|
|||||||
|
|
||||||
pub(crate) trait UnaryOp {
|
pub(crate) trait UnaryOp {
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
const KERNEL_BF16: &'static str;
|
||||||
// contiguous case separately as it's easy to optimize things out there.
|
const KERNEL_F16: &'static str;
|
||||||
const KERNEL_F32: &'static str;
|
const KERNEL_F32: &'static str;
|
||||||
const KERNEL_F64: &'static str;
|
const KERNEL_F64: &'static str;
|
||||||
|
const KERNEL_U32: &'static str;
|
||||||
|
fn bf16(v1: bf16) -> bf16;
|
||||||
|
fn f16(v1: f16) -> f16;
|
||||||
fn f32(v1: f32) -> f32;
|
fn f32(v1: f32) -> f32;
|
||||||
fn f64(v1: f64) -> f64;
|
fn f64(v1: f64) -> f64;
|
||||||
fn u32(v1: u32) -> u32;
|
fn u32(v1: u32) -> u32;
|
||||||
@ -51,11 +56,13 @@ pub(crate) trait UnaryOp {
|
|||||||
|
|
||||||
pub(crate) trait BinaryOp {
|
pub(crate) trait BinaryOp {
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
const KERNEL_BF16: &'static str;
|
||||||
// contiguous case separately as it's easy to optimize things out there.
|
const KERNEL_F16: &'static str;
|
||||||
const KERNEL_F32: &'static str;
|
const KERNEL_F32: &'static str;
|
||||||
const KERNEL_F64: &'static str;
|
const KERNEL_F64: &'static str;
|
||||||
const KERNEL_U32: &'static str;
|
const KERNEL_U32: &'static str;
|
||||||
|
fn bf16(v1: bf16, v2: bf16) -> bf16;
|
||||||
|
fn f16(v1: f16, v2: f16) -> f16;
|
||||||
fn f32(v1: f32, v2: f32) -> f32;
|
fn f32(v1: f32, v2: f32) -> f32;
|
||||||
fn f64(v1: f64, v2: f64) -> f64;
|
fn f64(v1: f64, v2: f64) -> f64;
|
||||||
fn u32(v1: u32, v2: u32) -> u32;
|
fn u32(v1: u32, v2: u32) -> u32;
|
||||||
@ -75,215 +82,116 @@ pub(crate) struct Sqr;
|
|||||||
pub(crate) struct Sqrt;
|
pub(crate) struct Sqrt;
|
||||||
pub(crate) struct Gelu;
|
pub(crate) struct Gelu;
|
||||||
|
|
||||||
impl BinaryOp for Add {
|
macro_rules! bin_op {
|
||||||
const NAME: &'static str = "add";
|
($op:ident, $name: literal, $e: expr) => {
|
||||||
const KERNEL_F32: &'static str = "badd_f32";
|
impl BinaryOp for $op {
|
||||||
const KERNEL_F64: &'static str = "badd_f64";
|
const NAME: &'static str = $name;
|
||||||
const KERNEL_U32: &'static str = "badd_u32";
|
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
|
||||||
|
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
|
||||||
|
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
|
||||||
|
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
|
||||||
|
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
|
||||||
|
fn bf16(v1: bf16, v2: bf16) -> bf16 {
|
||||||
|
$e(v1, v2)
|
||||||
|
}
|
||||||
|
fn f16(v1: f16, v2: f16) -> f16 {
|
||||||
|
$e(v1, v2)
|
||||||
|
}
|
||||||
fn f32(v1: f32, v2: f32) -> f32 {
|
fn f32(v1: f32, v2: f32) -> f32 {
|
||||||
v1 + v2
|
$e(v1, v2)
|
||||||
}
|
}
|
||||||
fn f64(v1: f64, v2: f64) -> f64 {
|
fn f64(v1: f64, v2: f64) -> f64 {
|
||||||
v1 + v2
|
$e(v1, v2)
|
||||||
}
|
}
|
||||||
fn u32(v1: u32, v2: u32) -> u32 {
|
fn u32(v1: u32, v2: u32) -> u32 {
|
||||||
v1 + v2
|
$e(v1, v2)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BinaryOp for Sub {
|
bin_op!(Add, "add", |v1, v2| v1 + v2);
|
||||||
const NAME: &'static str = "sub";
|
bin_op!(Sub, "sub", |v1, v2| v1 - v2);
|
||||||
const KERNEL_F32: &'static str = "bsub_f32";
|
bin_op!(Mul, "mul", |v1, v2| v1 * v2);
|
||||||
const KERNEL_F64: &'static str = "bsub_f64";
|
bin_op!(Div, "div", |v1, v2| v1 / v2);
|
||||||
const KERNEL_U32: &'static str = "bsub_u32";
|
|
||||||
fn f32(v1: f32, v2: f32) -> f32 {
|
|
||||||
v1 - v2
|
|
||||||
}
|
|
||||||
fn f64(v1: f64, v2: f64) -> f64 {
|
|
||||||
v1 - v2
|
|
||||||
}
|
|
||||||
fn u32(v1: u32, v2: u32) -> u32 {
|
|
||||||
v1 - v2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BinaryOp for Mul {
|
macro_rules! unary_op {
|
||||||
const NAME: &'static str = "mul";
|
($op: ident, $name: literal, $a: ident, $e: expr) => {
|
||||||
const KERNEL_F32: &'static str = "bmul_f32";
|
impl UnaryOp for $op {
|
||||||
const KERNEL_F64: &'static str = "bmul_f64";
|
const NAME: &'static str = $name;
|
||||||
const KERNEL_U32: &'static str = "bmul_u32";
|
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
|
||||||
fn f32(v1: f32, v2: f32) -> f32 {
|
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
|
||||||
v1 * v2
|
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
|
||||||
|
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
|
||||||
|
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
|
||||||
|
fn bf16($a: bf16) -> bf16 {
|
||||||
|
$e
|
||||||
}
|
}
|
||||||
fn f64(v1: f64, v2: f64) -> f64 {
|
fn f16($a: f16) -> f16 {
|
||||||
v1 * v2
|
$e
|
||||||
}
|
}
|
||||||
fn u32(v1: u32, v2: u32) -> u32 {
|
fn f32($a: f32) -> f32 {
|
||||||
v1 * v2
|
$e
|
||||||
}
|
}
|
||||||
}
|
fn f64($a: f64) -> f64 {
|
||||||
|
$e
|
||||||
impl BinaryOp for Div {
|
|
||||||
const NAME: &'static str = "div";
|
|
||||||
const KERNEL_F32: &'static str = "bdiv_f32";
|
|
||||||
const KERNEL_F64: &'static str = "bdiv_f64";
|
|
||||||
const KERNEL_U32: &'static str = "bdiv_u32";
|
|
||||||
fn f32(v1: f32, v2: f32) -> f32 {
|
|
||||||
v1 / v2
|
|
||||||
}
|
|
||||||
fn f64(v1: f64, v2: f64) -> f64 {
|
|
||||||
v1 / v2
|
|
||||||
}
|
|
||||||
fn u32(v1: u32, v2: u32) -> u32 {
|
|
||||||
v1 / v2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Exp {
|
|
||||||
const NAME: &'static str = "exp";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
v1.exp()
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
v1.exp()
|
|
||||||
}
|
|
||||||
fn u32(v1: u32) -> u32 {
|
|
||||||
(v1 as f64).exp() as u32
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "uexp_f32";
|
|
||||||
const KERNEL_F64: &'static str = "uexp_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Log {
|
|
||||||
const NAME: &'static str = "log";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
v1.ln()
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
v1.ln()
|
|
||||||
}
|
|
||||||
fn u32(v1: u32) -> u32 {
|
|
||||||
(v1 as f64).ln() as u32
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "ulog_f32";
|
|
||||||
const KERNEL_F64: &'static str = "ulog_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Sin {
|
|
||||||
const NAME: &'static str = "sin";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
v1.sin()
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
v1.sin()
|
|
||||||
}
|
}
|
||||||
fn u32(_: u32) -> u32 {
|
fn u32(_: u32) -> u32 {
|
||||||
0
|
todo!("no unary function for u32")
|
||||||
}
|
}
|
||||||
const KERNEL_F32: &'static str = "usin_f32";
|
}
|
||||||
const KERNEL_F64: &'static str = "usin_f64";
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UnaryOp for Cos {
|
unary_op!(Exp, "exp", v, v.exp());
|
||||||
const NAME: &'static str = "cos";
|
unary_op!(Log, "log", v, v.ln());
|
||||||
fn f32(v1: f32) -> f32 {
|
unary_op!(Sin, "sin", v, v.sin());
|
||||||
v1.cos()
|
unary_op!(Cos, "cos", v, v.cos());
|
||||||
}
|
unary_op!(Abs, "abs", v, v.abs());
|
||||||
fn f64(v1: f64) -> f64 {
|
unary_op!(Neg, "neg", v, -v);
|
||||||
v1.cos()
|
unary_op!(Sqr, "sqr", v, v * v);
|
||||||
}
|
unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
||||||
fn u32(_: u32) -> u32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "ucos_f32";
|
|
||||||
const KERNEL_F64: &'static str = "ucos_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Abs {
|
|
||||||
const NAME: &'static str = "abs";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
v1.abs()
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
v1.abs()
|
|
||||||
}
|
|
||||||
fn u32(v1: u32) -> u32 {
|
|
||||||
v1
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "uabs_f32";
|
|
||||||
const KERNEL_F64: &'static str = "uabs_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Neg {
|
|
||||||
const NAME: &'static str = "neg";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
-v1
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
-v1
|
|
||||||
}
|
|
||||||
fn u32(_: u32) -> u32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "uneg_f32";
|
|
||||||
const KERNEL_F64: &'static str = "uneg_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Sqr {
|
|
||||||
const NAME: &'static str = "sqr";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
v1 * v1
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
v1 * v1
|
|
||||||
}
|
|
||||||
fn u32(v: u32) -> u32 {
|
|
||||||
v * v
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "usqr_f32";
|
|
||||||
const KERNEL_F64: &'static str = "usqr_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOp for Sqrt {
|
|
||||||
const NAME: &'static str = "sqrt";
|
|
||||||
fn f32(v1: f32) -> f32 {
|
|
||||||
v1.sqrt()
|
|
||||||
}
|
|
||||||
fn f64(v1: f64) -> f64 {
|
|
||||||
v1.sqrt()
|
|
||||||
}
|
|
||||||
fn u32(v: u32) -> u32 {
|
|
||||||
(v as f64).sqrt() as u32
|
|
||||||
}
|
|
||||||
const KERNEL_F32: &'static str = "usqrt_f32";
|
|
||||||
const KERNEL_F64: &'static str = "usqrt_f64";
|
|
||||||
}
|
|
||||||
|
|
||||||
/// `gelu` operation
|
/// `gelu` operation
|
||||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||||
#[inline]
|
|
||||||
pub fn gelu_f32(v: f32) -> f32 {
|
|
||||||
0.5 * v
|
|
||||||
* (1.0 + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
|
||||||
}
|
|
||||||
/// `gelu` operation
|
|
||||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
|
||||||
#[inline]
|
|
||||||
pub fn gelu_f64(v: f64) -> f64 {
|
|
||||||
0.5 * v
|
|
||||||
* (1.0 + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
|
||||||
}
|
|
||||||
impl UnaryOp for Gelu {
|
impl UnaryOp for Gelu {
|
||||||
const NAME: &'static str = "gelu";
|
const NAME: &'static str = "gelu";
|
||||||
fn f32(v1: f32) -> f32 {
|
fn bf16(v: bf16) -> bf16 {
|
||||||
gelu_f32(v1)
|
bf16::from_f32_const(0.5)
|
||||||
|
* v
|
||||||
|
* (bf16::ONE
|
||||||
|
+ bf16::tanh(
|
||||||
|
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
|
||||||
|
* v
|
||||||
|
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
fn f64(v1: f64) -> f64 {
|
fn f16(v: f16) -> f16 {
|
||||||
gelu_f64(v1)
|
f16::from_f32_const(0.5)
|
||||||
|
* v
|
||||||
|
* (f16::ONE
|
||||||
|
+ f16::tanh(
|
||||||
|
(f16::from_f32_const(2.0) / f16::PI).sqrt()
|
||||||
|
* v
|
||||||
|
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
fn u32(v1: u32) -> u32 {
|
fn f32(v: f32) -> f32 {
|
||||||
gelu_f64(v1 as f64) as u32
|
0.5 * v
|
||||||
|
* (1.0
|
||||||
|
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||||
}
|
}
|
||||||
|
fn f64(v: f64) -> f64 {
|
||||||
|
0.5 * v
|
||||||
|
* (1.0
|
||||||
|
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||||
|
}
|
||||||
|
fn u32(_: u32) -> u32 {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
const KERNEL_BF16: &'static str = "gelu_bf16";
|
||||||
|
const KERNEL_F16: &'static str = "gelu_f16";
|
||||||
const KERNEL_F32: &'static str = "gelu_f32";
|
const KERNEL_F32: &'static str = "gelu_f32";
|
||||||
const KERNEL_F64: &'static str = "gelu_f64";
|
const KERNEL_F64: &'static str = "gelu_f64";
|
||||||
|
const KERNEL_U32: &'static str = "gelu_u32";
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user