mirror of https://github.com/huggingface/candle.git

More f16 and bf16 support.
@@ -23,7 +23,7 @@ candle-kernels = { path = "kernels", optional = true }
 gemm = "0.15.4"
 zip = { version = "0.6.6", default-features=false }
 byteorder = "1.4.3"
-half = "2.3.1"
+half = { version = "2.3.1", features = ["num-traits"] }
 num-traits = "0.2.15"
 
 [dev-dependencies]
@@ -33,5 +33,5 @@ rand = "0.8.5"
 tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
 
 [features]
-default = []
+default = ["cuda"]
 cuda = ["dep:cudarc", "dep:candle-kernels"]
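The switch to `half = { version = "2.3.1", features = ["num-traits"] }` is what makes the generic CPU helpers below possible: with that feature enabled, `bf16` and `f16` implement the `num_traits` numeric traits (`Zero`, `One`, `NumAssign`, ...). A minimal sketch of what this unlocks; illustrative code, not part of the commit:

```rust
use half::{bf16, f16};
use num_traits::NumAssign;

// Generic accumulation over any dtype implementing `NumAssign`, the same
// bound used by `sum_impl1` below; `T::zero()` comes from the `Num` supertrait.
fn accumulate<T: Copy + NumAssign>(xs: &[T]) -> T {
    let mut acc = T::zero();
    for &x in xs {
        acc += x;
    }
    acc
}

fn main() {
    let xs = [bf16::from_f32(1.0), bf16::from_f32(2.5)];
    println!("{}", accumulate(&xs)); // 3.5
    let ys = [f16::from_f32(0.5); 4];
    println!("{}", accumulate(&ys)); // 2
}
```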
@@ -46,10 +46,37 @@ fn wcond<T: Copy>(
     }
 }
 
+macro_rules! map1 {
+    ($v: expr, $fn: ident, $( $args:expr ),*) => {{
+        let v = match $v {
+            CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
+            CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
+            CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
+            CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
+            CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
+        };
+        Ok(v)
+    }};
+}
+
+fn sum_impl1<T: Copy + num_traits::NumAssign>(
+    src: &[T],
+    dst_shape: &Shape,
+    src_dims: &[usize],
+    stride: &[usize],
+    to_dst_index: impl Fn(usize) -> usize,
+) -> Result<Vec<T>> {
+    let mut dst = vec![T::zero(); dst_shape.elem_count()];
+    for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
+        dst[to_dst_index(unstr_index)] += src[src_index];
+    }
+    Ok(dst)
+}
+
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
+    vs: &[T],
     shape: &Shape,
     stride: &[usize],
-    vs: &[T],
     mut f: F,
 ) -> Vec<U> {
     if shape.is_contiguous(stride) {
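The `map1!` macro removes the per-dtype boilerplate: given a generic function whose extra arguments are dtype-independent, it expands to a match over every `CpuStorage` variant and monomorphizes the call. A self-contained sketch of the pattern (the enum and helper here are simplified stand-ins, not the crate's definitions):

```rust
// Simplified stand-ins for CpuStorage and a generic per-dtype helper.
enum Storage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

macro_rules! map1 {
    ($v:expr, $fn:ident, $($args:expr),*) => {
        match $v {
            Storage::F32(s) => Storage::F32($fn::<f32>(&s, $($args),*)),
            Storage::F64(s) => Storage::F64($fn::<f64>(&s, $($args),*)),
        }
    };
}

// Analogous to `sum_impl1`/`take_impl1`: generic over the element type,
// with extra arguments that do not depend on the dtype.
fn take_first<T: Copy>(src: &[T], n: usize) -> Vec<T> {
    src[..n.min(src.len())].to_vec()
}

fn main() {
    let s = Storage::F32(vec![1.0, 2.0, 3.0]);
    // Expands to a match that calls `take_first::<f32>` on the F32 arm.
    let _first_two = map1!(s, take_first, 2);
}
```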
@@ -83,11 +110,11 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
-fn take<T: Copy>(
+fn take_impl1<T: Copy>(
+    vs: &[T],
     ids: &[u32],
     shape: &Shape,
     stride: &[usize],
-    vs: &[T],
     vocab_size: usize,
     hidden_size: usize,
 ) -> Result<Vec<T>> {
@@ -153,40 +180,104 @@ impl CpuStorage {
     pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
         // TODO: find a way around the quadratic number of cases below.
         match (self, dtype) {
+            (Self::U32(storage), DType::BF16) => {
+                let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
+                Ok(Self::BF16(data))
+            }
+            (Self::BF16(storage), DType::BF16) => {
+                let data = unary_map(storage, shape, stride, |v| v);
+                Ok(Self::BF16(data))
+            }
+            (Self::F16(storage), DType::BF16) => {
+                let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
+                Ok(Self::BF16(data))
+            }
+            (Self::F32(storage), DType::BF16) => {
+                let data = unary_map(storage, shape, stride, bf16::from_f32);
+                Ok(Self::BF16(data))
+            }
+            (Self::F64(storage), DType::BF16) => {
+                let data = unary_map(storage, shape, stride, bf16::from_f64);
+                Ok(Self::BF16(data))
+            }
+            (Self::U32(storage), DType::F16) => {
+                let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
+                Ok(Self::F16(data))
+            }
+            (Self::BF16(storage), DType::F16) => {
+                let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
+                Ok(Self::F16(data))
+            }
+            (Self::F16(storage), DType::F16) => {
+                let data = unary_map(storage, shape, stride, |v| v);
+                Ok(Self::F16(data))
+            }
+            (Self::F32(storage), DType::F16) => {
+                let data = unary_map(storage, shape, stride, f16::from_f32);
+                Ok(Self::F16(data))
+            }
+            (Self::F64(storage), DType::F16) => {
+                let data = unary_map(storage, shape, stride, f16::from_f64);
+                Ok(Self::F16(data))
+            }
             (Self::U32(storage), DType::F32) => {
-                let data = unary_map(shape, stride, storage, |v| v as f32);
+                let data = unary_map(storage, shape, stride, |v| v as f32);
                 Ok(Self::F32(data))
             }
+            (Self::BF16(storage), DType::F32) => {
+                let data = unary_map(storage, shape, stride, |v| v.to_f32());
+                Ok(Self::F32(data))
+            }
+            (Self::F16(storage), DType::F32) => {
+                let data = unary_map(storage, shape, stride, |v| v.to_f32());
+                Ok(Self::F32(data))
+            }
             (Self::F32(storage), DType::F32) => {
-                let data = unary_map(shape, stride, storage, |v| v);
+                let data = unary_map(storage, shape, stride, |v| v);
                 Ok(Self::F32(data))
             }
             (Self::F64(storage), DType::F32) => {
-                let data = unary_map(shape, stride, storage, |v| v as f32);
+                let data = unary_map(storage, shape, stride, |v| v as f32);
                 Ok(Self::F32(data))
             }
             (Self::U32(storage), DType::U32) => {
-                let data = unary_map(shape, stride, storage, |v| v);
+                let data = unary_map(storage, shape, stride, |v| v);
                 Ok(Self::U32(data))
             }
+            (Self::BF16(storage), DType::U32) => {
+                let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
+                Ok(Self::U32(data))
+            }
+            (Self::F16(storage), DType::U32) => {
+                let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
+                Ok(Self::U32(data))
+            }
             (Self::F32(storage), DType::U32) => {
-                let data = unary_map(shape, stride, storage, |v| v as u32);
+                let data = unary_map(storage, shape, stride, |v| v as u32);
                 Ok(Self::U32(data))
             }
             (Self::F64(storage), DType::U32) => {
-                let data = unary_map(shape, stride, storage, |v| v as u32);
+                let data = unary_map(storage, shape, stride, |v| v as u32);
                 Ok(Self::U32(data))
             }
             (Self::U32(storage), DType::F64) => {
-                let data = unary_map(shape, stride, storage, |v| v as f64);
+                let data = unary_map(storage, shape, stride, |v| v as f64);
                 Ok(Self::F64(data))
             }
+            (Self::BF16(storage), DType::F64) => {
+                let data = unary_map(storage, shape, stride, |v| v.to_f64());
+                Ok(Self::F64(data))
+            }
+            (Self::F16(storage), DType::F64) => {
+                let data = unary_map(storage, shape, stride, |v| v.to_f64());
+                Ok(Self::F64(data))
+            }
             (Self::F32(storage), DType::F64) => {
-                let data = unary_map(shape, stride, storage, |v| v as f64);
+                let data = unary_map(storage, shape, stride, |v| v as f64);
                 Ok(Self::F64(data))
             }
             (Self::F64(storage), DType::F64) => {
-                let data = unary_map(shape, stride, storage, |v| v);
+                let data = unary_map(storage, shape, stride, |v| v);
                 Ok(Self::F64(data))
             }
         }
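The half conversions above round-trip through `f32`/`f64` because that is where the arithmetic is defined: widening `bf16`/`f16` to `f32` is exact, while narrowing rounds. A small illustration (values chosen for the example, not taken from the commit):

```rust
use half::{bf16, f16};

fn main() {
    // 1 + 2^-10 is exactly representable in f16 (10 mantissa bits) but not
    // in bf16 (7 mantissa bits), so `bf16::from_f32(v.to_f32())` rounds.
    let x = f16::from_f32(1.0009765625);
    let y = bf16::from_f32(x.to_f32());
    assert_eq!(y, bf16::from_f32(1.0));
    println!("{} -> {}", x, y);
}
```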
@@ -219,29 +310,7 @@ impl CpuStorage {
             dst_index
         };
         // TODO: Maybe provide an implementation with higher precision accumulators?
-        match self {
-            Self::F32(src) => {
-                let mut dst = vec![0f32; dst_shape.elem_count()];
-                for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
-                    dst[to_dst_index(unstr_index)] += src[src_index];
-                }
-                Ok(Self::F32(dst))
-            }
-            Self::F64(src) => {
-                let mut dst = vec![0f64; dst_shape.elem_count()];
-                for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
-                    dst[to_dst_index(unstr_index)] += src[src_index];
-                }
-                Ok(Self::F64(dst))
-            }
-            Self::U32(src) => {
-                let mut dst = vec![0u32; dst_shape.elem_count()];
-                for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
-                    dst[to_dst_index(unstr_index)] += src[src_index];
-                }
-                Ok(Self::U32(dst))
-            }
-        }
+        map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
     }
 
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
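The TODO kept above is worth taking seriously for the new dtypes: accumulating a long sum directly in `bf16`/`f16` stalls once the running total dwarfs each addend. A hedged sketch of the higher-precision variant the comment hints at (names are illustrative, not the crate's API):

```rust
use half::f16;

// Accumulate in f32 and round once at the end instead of accumulating
// in the half-precision dtype itself.
fn sum_f16_via_f32(src: &[f16]) -> f16 {
    let acc: f32 = src.iter().map(|v| v.to_f32()).sum();
    f16::from_f32(acc)
}

fn main() {
    // 4096 * 0.25 = 1024: fine in f32, but a running f16 sum would stop
    // registering the 0.25 increments once the total passes 512, where the
    // f16 spacing between representable values grows to 0.5.
    let xs = vec![f16::from_f32(0.25); 4096];
    println!("{}", sum_f16_via_f32(&xs));
}
```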
@@ -251,6 +320,42 @@ impl CpuStorage {
         let prod_pre_dim = dims[..dim].iter().product();
         let prod_post_dim = dims[dim + 1..].iter().product();
         match self {
+            Self::BF16(storage) => {
+                for pre_idx in 0..prod_pre_dim {
+                    for post_idx in 0..prod_post_dim {
+                        let mut sum = 0f64;
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
+                            sum += storage[idx].to_f64();
+                            idx += prod_post_dim
+                        }
+                        let sum = bf16::from_f64(sum);
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
+                            storage[idx] /= sum;
+                            idx += prod_post_dim
+                        }
+                    }
+                }
+            }
+            Self::F16(storage) => {
+                for pre_idx in 0..prod_pre_dim {
+                    for post_idx in 0..prod_post_dim {
+                        let mut sum = 0f64;
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
+                            sum += storage[idx].to_f64();
+                            idx += prod_post_dim
+                        }
+                        let sum = f16::from_f64(sum);
+                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+                        for _ in 0..elem_per_slice {
+                            storage[idx] /= sum;
+                            idx += prod_post_dim
+                        }
+                    }
+                }
+            }
             Self::F32(storage) => {
                 for pre_idx in 0..prod_pre_dim {
                     for post_idx in 0..prod_post_dim {
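Both new arms accumulate the slice sum in `f64` and walk each slice along `dim` with the same pre/post index arithmetic. A standalone sketch of just the index walk (function name and return shape are illustrative):

```rust
// Enumerate the flat indices of every slice along dimension `dim` of a
// contiguous tensor, mirroring the pre/post loop structure above.
fn slice_indices(dims: &[usize], dim: usize) -> Vec<Vec<usize>> {
    let elem_per_slice = dims[dim];
    let prod_pre_dim: usize = dims[..dim].iter().product();
    let prod_post_dim: usize = dims[dim + 1..].iter().product();
    let mut slices = Vec::new();
    for pre_idx in 0..prod_pre_dim {
        for post_idx in 0..prod_post_dim {
            let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
            let mut slice = Vec::new();
            for _ in 0..elem_per_slice {
                slice.push(idx);
                idx += prod_post_dim;
            }
            slices.push(slice);
        }
    }
    slices
}

fn main() {
    // For a (2, 3) tensor and dim = 1, each slice walks one row.
    println!("{:?}", slice_indices(&[2, 3], 1)); // [[0, 1, 2], [3, 4, 5]]
}
```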
@@ -302,17 +407,29 @@ impl CpuStorage {
             Self::U32(storage) => {
                 let mul = mul as u32;
                 let add = add as u32;
-                let data = unary_map(shape, stride, storage, |v| v * mul + add);
+                let data = unary_map(storage, shape, stride, |v| v * mul + add);
                 Ok(Self::U32(data))
             }
+            Self::BF16(storage) => {
+                let mul = bf16::from_f64(mul);
+                let add = bf16::from_f64(add);
+                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                Ok(Self::BF16(data))
+            }
+            Self::F16(storage) => {
+                let mul = f16::from_f64(mul);
+                let add = f16::from_f64(add);
+                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                Ok(Self::F16(data))
+            }
             Self::F32(storage) => {
                 let mul = mul as f32;
                 let add = add as f32;
-                let data = unary_map(shape, stride, storage, |v| v * mul + add);
+                let data = unary_map(storage, shape, stride, |v| v * mul + add);
                 Ok(Self::F32(data))
             }
             Self::F64(storage) => {
-                let data = unary_map(shape, stride, storage, |v| v * mul + add);
+                let data = unary_map(storage, shape, stride, |v| v * mul + add);
                 Ok(Self::F64(data))
             }
         }
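Each new affine arm converts the `f64` constants into the storage dtype once, outside the per-element closure, so the hot loop stays in native `bf16`/`f16` arithmetic. The pattern in isolation (illustrative helper, not the crate's API):

```rust
use half::bf16;

// Convert the f64 constants to the storage dtype once, then apply the
// affine transform element-wise, mirroring the `Self::BF16` arm above.
fn affine_bf16(vs: &[bf16], mul: f64, add: f64) -> Vec<bf16> {
    let mul = bf16::from_f64(mul);
    let add = bf16::from_f64(add);
    vs.iter().map(|&v| v * mul + add).collect()
}

fn main() {
    let vs = [bf16::from_f32(1.0), bf16::from_f32(2.0)];
    println!("{:?}", affine_bf16(&vs, 2.0, 0.5)); // [2.5, 4.5]
}
```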
@@ -320,16 +437,24 @@ impl CpuStorage {
 
     pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
         match self {
+            Self::BF16(storage) => {
+                let data = unary_map(storage, shape, stride, B::bf16);
+                Ok(Self::BF16(data))
+            }
+            Self::F16(storage) => {
+                let data = unary_map(storage, shape, stride, B::f16);
+                Ok(Self::F16(data))
+            }
             Self::F32(storage) => {
-                let data = unary_map(shape, stride, storage, B::f32);
+                let data = unary_map(storage, shape, stride, B::f32);
                 Ok(Self::F32(data))
             }
             Self::F64(storage) => {
-                let data = unary_map(shape, stride, storage, B::f64);
+                let data = unary_map(storage, shape, stride, B::f64);
                 Ok(Self::F64(data))
             }
             Self::U32(storage) => {
-                let data = unary_map(shape, stride, storage, B::u32);
+                let data = unary_map(storage, shape, stride, B::u32);
                 Ok(Self::U32(data))
             }
         }
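For `B::bf16` and `B::f16` to exist, the `UnaryOp` trait must expose one method per dtype. A hedged sketch of the shape of such a trait (the real trait in the crate also carries kernel names and more machinery; this is illustrative only):

```rust
use half::{bf16, f16};

trait UnaryOp {
    fn bf16(v: bf16) -> bf16;
    fn f16(v: f16) -> f16;
    fn f32(v: f32) -> f32;
    fn f64(v: f64) -> f64;
    fn u32(v: u32) -> u32;
}

// Example implementor: negation, which `half` supports natively.
struct Neg;

impl UnaryOp for Neg {
    fn bf16(v: bf16) -> bf16 { -v }
    fn f16(v: f16) -> f16 { -v }
    fn f32(v: f32) -> f32 { -v }
    fn f64(v: f64) -> f64 { -v }
    fn u32(v: u32) -> u32 { v } // negation is not defined for u32
}
```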
@@ -343,6 +468,14 @@ impl CpuStorage {
         rhs_stride: &[usize],
     ) -> Result<Self> {
         match (self, rhs) {
+            (Self::BF16(lhs), Self::BF16(rhs)) => {
+                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
+                Ok(Self::BF16(data))
+            }
+            (Self::F16(lhs), Self::F16(rhs)) => {
+                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
+                Ok(Self::F16(data))
+            }
             (Self::F32(lhs), Self::F32(rhs)) => {
                 let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
                 Ok(Self::F32(data))
@@ -381,6 +514,12 @@ impl CpuStorage {
             (Self::U32(src), Self::U32(dst)) => {
                 copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
+            (Self::BF16(src), Self::BF16(dst)) => {
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
+            }
+            (Self::F16(src), Self::F16(dst)) => {
+                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
+            }
             (Self::F32(src), Self::F32(dst)) => {
                 copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
             }
@@ -411,6 +550,14 @@ impl CpuStorage {
         // TODO: Support types that could be casted to a boolean.
         let pred = self.as_slice::<u32>()?;
         match (t, f) {
+            (Self::BF16(t), Self::BF16(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::BF16(data))
+            }
+            (Self::F16(t), Self::F16(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F16(data))
+            }
             (Self::F32(t), Self::F32(f)) => {
                 let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
                 Ok(Self::F32(data))
@@ -440,20 +587,7 @@ impl CpuStorage {
         vocab_size: usize,
     ) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        match vs {
-            CpuStorage::F32(vs) => {
-                let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
-                Ok(CpuStorage::F32(storage))
-            }
-            CpuStorage::F64(vs) => {
-                let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
-                Ok(CpuStorage::F64(storage))
-            }
-            CpuStorage::U32(vs) => {
-                let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
-                Ok(CpuStorage::U32(storage))
-            }
-        }
+        map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
     }
 
     pub(crate) fn matmul_impl(

@@ -133,6 +133,22 @@ impl CudaDevice {
                 unsafe { func.launch(cfg, params) }?;
                 CudaStorageSlice::U32(data)
             }
+            DType::BF16 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<bf16>(elem_count) }?;
+                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
+                let params = (&data, bf16::from_f64(v), elem_count);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(data)
+            }
+            DType::F16 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<f16>(elem_count) }?;
+                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
+                let params = (&data, f16::from_f64(v), elem_count);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(data)
+            }
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f32>(elem_count) }?;
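CUDA kernels are monomorphic, so the host side picks a dtype-mangled kernel name (`fill_bf16`, `fill_f16`, ...) rather than dispatching generically. A sketch of that selection; only the bf16/f16 names (and the implied f32 one) are visible in this diff, the f64/u32 names follow the same convention by assumption:

```rust
enum DType { BF16, F16, F32, F64, U32 }

// Map a dtype to the name of the compiled fill kernel.
fn fill_kernel_name(dtype: &DType) -> &'static str {
    match dtype {
        DType::BF16 => "fill_bf16",
        DType::F16 => "fill_f16",
        DType::F32 => "fill_f32",
        DType::F64 => "fill_f64",
        DType::U32 => "fill_u32",
    }
}
```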
@@ -166,6 +182,14 @@ impl CudaDevice {
                 let data = self.htod_sync_copy(storage)?;
                 CudaStorageSlice::U32(data)
             }
+            CpuStorage::BF16(storage) => {
+                let data = self.htod_sync_copy(storage)?;
+                CudaStorageSlice::BF16(data)
+            }
+            CpuStorage::F16(storage) => {
+                let data = self.htod_sync_copy(storage)?;
+                CudaStorageSlice::F16(data)
+            }
             CpuStorage::F32(storage) => {
                 let data = self.htod_sync_copy(storage)?;
                 CudaStorageSlice::F32(data)
@@ -325,6 +349,40 @@ impl CudaStorage {
                 unsafe { func.launch(cfg, params) }?;
                 CudaStorageSlice::U32(out)
             }
+            CudaStorageSlice::BF16(arg) => {
+                let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<bf16>(el_count) }?;
+                let params = (
+                    el_count,
+                    dims.len(),
+                    &ds,
+                    arg,
+                    &out,
+                    bf16::from_f64(mul),
+                    bf16::from_f64(add),
+                );
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            CudaStorageSlice::F16(arg) => {
+                let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f16>(el_count) }?;
+                let params = (
+                    el_count,
+                    dims.len(),
+                    &ds,
+                    arg,
+                    &out,
+                    f16::from_f64(mul),
+                    f16::from_f64(add),
+                );
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
             CudaStorageSlice::F32(arg) => {
                 let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
@@ -376,6 +434,22 @@ impl CudaStorage {
                 unsafe { func.launch(cfg, params) }?;
                 CudaStorageSlice::U32(out)
             }
+            CudaStorageSlice::BF16(arg) => {
+                let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
+                let out = dev.alloc_zeros::<bf16>(dst_el)?;
+                let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            CudaStorageSlice::F16(arg) => {
+                let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
+                let out = dev.alloc_zeros::<f16>(dst_el)?;
+                let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
             CudaStorageSlice::F32(arg) => {
                 let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
                 let out = dev.alloc_zeros::<f32>(dst_el)?;
@@ -417,6 +491,24 @@ impl CudaStorage {
             CudaStorageSlice::U32(_arg) => {
                 todo!("No unary kernels for u32");
             }
+            CudaStorageSlice::BF16(arg) => {
+                let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<bf16>(el_count) }?;
+                let params = (el_count, dims.len(), &ds, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            CudaStorageSlice::F16(arg) => {
+                let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f16>(el_count) }?;
+                let params = (el_count, dims.len(), &ds, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
             CudaStorageSlice::F32(arg) => {
                 let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
@@ -453,6 +545,24 @@ impl CudaStorage {
         let dev = self.device();
         let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
         let slice = match (&self.slice, &rhs.slice) {
+            (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
+                let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
+                let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
+                let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f16>(elem_count) }?;
+                let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
                 // SAFETY: Set later by running the kernel.
@@ -494,6 +604,16 @@ impl CudaStorage {
                 let cpu_storage = dev.dtoh_sync_copy(slice)?;
                 Ok(CpuStorage::U32(cpu_storage))
             }
+            CudaStorageSlice::BF16(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                Ok(CpuStorage::BF16(cpu_storage))
+            }
+            CudaStorageSlice::F16(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                Ok(CpuStorage::F16(cpu_storage))
+            }
             CudaStorageSlice::F32(slice) => {
                 let dev = slice.device();
                 let cpu_storage = dev.dtoh_sync_copy(slice)?;
@@ -530,6 +650,24 @@ impl CudaStorage {
         let dev = self.device();
         let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
         let slice = match (&t.slice, &f.slice) {
+            (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
+                let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<bf16>(el) }?;
+                let params = (el, dims.len(), &ds, ids, t, f, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
+                let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f16>(el) }?;
+                let params = (el, dims.len(), &ds, ids, t, f, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
             (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
                 let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
                 // SAFETY: Set later by running the kernel.
@@ -596,6 +734,24 @@ impl CudaStorage {
                 unsafe { func.launch(cfg, params) }?;
                 CudaStorageSlice::U32(out)
             }
+            CudaStorageSlice::BF16(arg) => {
+                let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            CudaStorageSlice::F16(arg) => {
+                let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
             CudaStorageSlice::F32(arg) => {
                 let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
@@ -629,6 +785,12 @@ impl CudaStorage {
         let elem_count = b * m * n;
         let dev = &self.device;
        let slice = match (&self.slice, &rhs.slice) {
+            (CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
+                todo!("bf16")
+            }
+            (CudaStorageSlice::F16(_lhs), CudaStorageSlice::F16(_rhs)) => {
+                todo!("f16")
+            }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
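The `todo!()` arms leave half-precision GEMM for later (cuBLAS needs a different entry point for f16/bf16 than for f32). One possible interim fallback, sketched here and clearly not what the commit does, is to upcast to f32, reuse the f32 path, and round the result back:

```rust
use half::f16;

// Hypothetical fallback for the `todo!("f16")` arm: upcast, run a plain
// f32 matmul (m x k times k x n, row-major), then round back to f16.
fn matmul_f16_via_f32(lhs: &[f16], rhs: &[f16], m: usize, n: usize, k: usize) -> Vec<f16> {
    let lhs: Vec<f32> = lhs.iter().map(|v| v.to_f32()).collect();
    let rhs: Vec<f32> = rhs.iter().map(|v| v.to_f32()).collect();
    let mut out = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for l in 0..k {
                out[i * n + j] += lhs[i * k + l] * rhs[l * n + j];
            }
        }
    }
    out.into_iter().map(f16::from_f32).collect()
}
```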
@@ -672,6 +834,32 @@ impl CudaStorage {
         let dev = &self.device;
         let ds = dev.htod_copy([dims, src_stride].concat())?;
         match (&self.slice, &mut dst.slice) {
+            (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
+                let src = src.slice(src_offset..);
+                let mut dst = dst.slice_mut(dst_offset..);
+                if src_shape.is_contiguous(src_stride) {
+                    dev.dtod_copy(&src, &mut dst)?
+                } else {
+                    let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
+                    // SAFETY: Set later by running the kernel.
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    // SAFETY: ffi.
+                    unsafe { func.launch(cfg, params) }?
+                }
+            }
+            (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
+                let src = src.slice(src_offset..);
+                let mut dst = dst.slice_mut(dst_offset..);
+                if src_shape.is_contiguous(src_stride) {
+                    dev.dtod_copy(&src, &mut dst)?
+                } else {
+                    let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
+                    // SAFETY: Set later by running the kernel.
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    // SAFETY: ffi.
+                    unsafe { func.launch(cfg, params) }?
+                }
+            }
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
                 let src = src.slice(src_offset..);
                 let mut dst = dst.slice_mut(dst_offset..);
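Every copy arm branches on the same test: a contiguous source takes the fast `dtod_copy` memcpy path, a strided one launches the `ucopy_*` gather kernel. The contiguity test itself amounts to checking for row-major strides; a minimal sketch (matching `Shape::is_contiguous` in spirit, not verbatim):

```rust
// A layout is contiguous when its strides are exactly the row-major
// strides implied by the shape.
fn is_contiguous(dims: &[usize], stride: &[usize]) -> bool {
    let mut expected = 1;
    for (&dim, &s) in dims.iter().zip(stride.iter()).rev() {
        if s != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

fn main() {
    assert!(is_contiguous(&[2, 3], &[3, 1])); // row-major
    assert!(!is_contiguous(&[2, 3], &[1, 2])); // transposed view
}
```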
@@ -685,6 +873,19 @@ impl CudaStorage {
                     unsafe { func.launch(cfg, params) }?
                 }
             }
+            (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
+                let src = src.slice(src_offset..);
+                let mut dst = dst.slice_mut(dst_offset..);
+                if src_shape.is_contiguous(src_stride) {
+                    dev.dtod_copy(&src, &mut dst)?
+                } else {
+                    let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
+                    // SAFETY: Set later by running the kernel.
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    // SAFETY: ffi.
+                    unsafe { func.launch(cfg, params) }?
+                }
+            }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                 let src = src.slice(src_offset..);
                 let mut dst = dst.slice_mut(dst_offset..);