Merge pull request #33 from LaurentMazare/cuda-map

Simplify the dtype matching in the CUDA backend
Laurent Mazare
2023-06-29 10:14:12 +01:00
committed by GitHub
6 changed files with 292 additions and 455 deletions
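
The per-dtype match blocks that were duplicated across every CUDA op are replaced by two small traits, Map1 for one-input ops and Map2 for two-input ops: each op implements a generic f over the element type, and the trait's map method performs the dtype dispatch once for all implementors. Ops become small structs (Affine, Sum, Embedding, WhereCond, Clone) or blanket impls for UnaryOp/BinaryOp, and the per-dtype KERNEL_* constants collapse into a single KERNEL root whose dtype suffix is appended by kernel_name. Below is a minimal, self-contained sketch of the Map1 pattern, shrunk to two dtypes and an in-memory stand-in for the kernel launch; names other than Map1, map and f are illustrative, not the crate's API.

enum Slice {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

trait Map1 {
    // Per-op work, written once, generic over the element type.
    fn f<T: Copy>(&self, src: &[T]) -> Vec<T>;

    // The dtype match, written once for every implementor.
    fn map(&self, s: &Slice) -> Slice {
        match s {
            Slice::F32(s) => Slice::F32(self.f(s)),
            Slice::F64(s) => Slice::F64(self.f(s)),
        }
    }
}

// A toy op: reversing stands in for launching a CUDA kernel.
struct Reverse;
impl Map1 for Reverse {
    fn f<T: Copy>(&self, src: &[T]) -> Vec<T> {
        src.iter().rev().copied().collect()
    }
}

fn main() {
    let s = Slice::F32(vec![1.0, 2.0, 3.0]);
    // One call site, no per-dtype match: this is what turns the long match in
    // affine below into Affine(mul, add).map(&self.slice, &device, layout)?.
    match Reverse.map(&s) {
        Slice::F32(v) => println!("{v:?}"),
        Slice::F64(v) => println!("{v:?}"),
    }
}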


@@ -487,6 +487,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng();
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?;
@@ -496,6 +497,7 @@
let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
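
The change above adds a per-token timer: the inner start_gen shadows the outer whole-run binding for the body of the loop, so the elapsed time printed after sampling measures a single iteration. A standalone sketch of the same shadowing pattern, with a sleep standing in for the forward pass:

use std::thread::sleep;
use std::time::{Duration, Instant};

fn main() {
    let start_gen = Instant::now(); // whole-run timer, unchanged
    for _index in 0..3 {
        let start_gen = Instant::now(); // shadows the outer timer inside the loop
        sleep(Duration::from_millis(10)); // stand-in for llama.forward + sampling
        println!("> {:?}", start_gen.elapsed()); // per-token latency
    }
    println!("total: {:?}", start_gen.elapsed()); // outer binding, back in scope
}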


@@ -1,7 +1,9 @@
use crate::{CpuStorage, DType, Layout, Shape};
use crate::{CpuStorage, DType, Layout, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
use cudarc::driver::{
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::Arc;
@@ -242,6 +244,260 @@ enum CudaStorageSlice {
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
type S = CudaStorageSlice;
trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U32(s) => S::U32(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
};
Ok(out)
}
}
struct Clone;
impl Map1 for Clone {
fn f<T: DeviceRepr>(
&self,
s: &CudaSlice<T>,
_: &CudaDevice,
_: &Layout,
) -> Result<CudaSlice<T>> {
Ok(s.try_clone()?)
}
}
fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
struct Affine(f64, f64);
impl Map1 for Affine {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }?;
let params = (
el,
dims.len(),
&ds,
src,
&out,
T::from_f64(self.0),
T::from_f64(self.1),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let src_dims = shape.dims();
let el = shape.elem_count();
let mut dst_el = el;
for &sum_dim in self.0.iter() {
dst_el /= src_dims[sum_dim];
}
let mut sum_dims = self.0.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
let sum_dims_s: Vec<usize> = sum_dims
.iter()
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
let out = dev.alloc_zeros::<T>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
impl<U: crate::op::UnaryOp> Map1 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }?;
let params = (el_count, dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map1 for Embedding<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
rhs: &CudaSlice<T>,
dev: &CudaDevice,
rhs_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.0.dtype(),
})?,
};
let ids = &ids;
let shape = ids_l.shape();
let (v_size, h_size) = rhs_l
.shape()
.r2()
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
t: &CudaSlice<T>,
layout_t: &Layout,
f: &CudaSlice<T>,
layout_f: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.0.dtype(),
})?,
};
let ids = &ids;
let shape = ids_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds =
dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
impl<U: crate::op::BinaryOp> Map2 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
lhs: &CudaSlice<T>,
lhs_l: &Layout,
rhs: &CudaSlice<T>,
rhs_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let shape = lhs_l.shape();
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
@@ -332,14 +588,8 @@ fn gemm_config<T>(
}
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
let slice = match &self.slice {
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
};
pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
let slice = Clone.map(&self.slice, self.device(), layout)?;
let device = self.device.clone();
Ok(Self { slice, device })
}
@@ -420,152 +670,14 @@ impl CudaStorage {
}
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
bf16::from_f64(mul),
bf16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
f16::from_f64(mul),
f16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
let device = self.device().clone();
let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let shape = layout.shape();
let src_dims = shape.dims();
let el = shape.elem_count();
let mut dst_el = el;
for &sum_dim in sum_dims.iter() {
dst_el /= src_dims[sum_dim];
}
let mut sum_dims = sum_dims.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
let sum_dims_s: Vec<usize> = sum_dims
.iter()
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<u32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<bf16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f64>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
let device = self.device().clone();
let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
@@ -576,58 +688,8 @@ impl CudaStorage {
}
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(_arg) => {
todo!("No unary kernels for u32");
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
@@ -637,72 +699,8 @@ impl CudaStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let shape = lhs_l.shape();
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dev = self.device();
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
};
let device = dev.clone();
let device = self.device().clone();
let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
Ok(Self { slice, device })
}
@@ -740,163 +738,18 @@ impl CudaStorage {
&self,
layout: &Layout,
t: &Self,
layout_t: &Layout,
t_l: &Layout,
f: &Self,
layout_f: &Layout,
f_l: &Layout,
) -> Result<Self> {
let ids = match &self.slice {
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let ids = &ids;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds =
dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<f64>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
let device = dev.clone();
let device = self.device().clone();
let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
Ok(Self { slice, device })
}
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = match &self.slice {
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let ids = &ids;
let shape = layout.shape();
let (v_size, h_size) = rhs_l
.shape()
.r2()
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &rhs.slice {
// The kernels below assume that rhs is contiguous.
CudaStorageSlice::U32(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
let device = self.device().clone();
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
Ok(Self { slice, device })
}
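
Every generic f above resolves its concrete CUDA kernel at runtime through kernel_name, which appends the element type's dtype suffix to a root such as "affine", "sum" or "emb". A self-contained sketch of that naming scheme; the DtypeStr trait here is an illustrative stand-in for WithDType::DTYPE.as_str():

trait DtypeStr {
    const DTYPE_STR: &'static str;
}

impl DtypeStr for f32 {
    const DTYPE_STR: &'static str = "f32";
}

impl DtypeStr for f64 {
    const DTYPE_STR: &'static str = "f64";
}

fn kernel_name<T: DtypeStr>(root: &str) -> String {
    // "affine" + f32 -> "affine_f32", matching the hand-written names the
    // old per-dtype match arms passed to get_or_load_func.
    format!("{root}_{}", T::DTYPE_STR)
}

fn main() {
    assert_eq!(kernel_name::<f32>("affine"), "affine_f32");
    assert_eq!(kernel_name::<f64>("emb"), "emb_f64");
}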


@@ -44,7 +44,7 @@ impl CudaDevice {
pub struct CudaStorage;
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
pub fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}


@@ -43,11 +43,8 @@ pub(crate) enum Op {
pub(crate) trait UnaryOp {
const NAME: &'static str;
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
const KERNEL: &'static str;
const V: Self;
fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
@@ -57,11 +54,8 @@ pub(crate) trait UnaryOp {
pub(crate) trait BinaryOp {
const NAME: &'static str;
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
const KERNEL: &'static str;
const V: Self;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
@@ -88,11 +82,8 @@ macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => {
impl BinaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
@@ -121,11 +112,8 @@ macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
fn bf16($a: bf16) -> bf16 {
$e
}
@@ -158,6 +146,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt());
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
@@ -191,20 +180,13 @@ impl UnaryOp for Gelu {
fn u32(_: u32) -> u32 {
0
}
const KERNEL_BF16: &'static str = "ugelu_bf16";
const KERNEL_F16: &'static str = "ugelu_f16";
const KERNEL_F32: &'static str = "ugelu_f32";
const KERNEL_F64: &'static str = "ugelu_f64";
const KERNEL_U32: &'static str = "ugelu_u32";
const KERNEL: &'static str = "ugelu";
}
impl UnaryOp for Relu {
const NAME: &'static str = "relu";
const KERNEL_BF16: &'static str = "urelu_bf16";
const KERNEL_F16: &'static str = "urelu_f16";
const KERNEL_F32: &'static str = "urelu_f32";
const KERNEL_F64: &'static str = "urelu_f64";
const KERNEL_U32: &'static str = "urelu_u32";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;
fn bf16(v: bf16) -> bf16 {
v.max(bf16::ZERO)
}
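
With this change each op stores a single kernel root ("b" or "u" plus the op name) and the CUDA backend appends the dtype suffix at dispatch time, so the five per-dtype constants the macros used to emit are recovered by concatenation. A small check of that equivalence; the constant names are illustrative:

fn main() {
    // What bin_op! now emits for "add":
    const KERNEL: &str = concat!("b", "add"); // "badd"
    // What kernel_name builds in the CUDA backend at dispatch time:
    let full = format!("{KERNEL}_{}", "f32");
    // ...matching the old KERNEL_F32 value, concat!("b", "add", "_f32"):
    assert_eq!(full, "badd_f32");
    println!("{full}");
}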


@@ -9,11 +9,11 @@ pub enum Storage {
}
impl Storage {
pub fn try_clone(&self) -> Result<Self> {
pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
match self {
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
Self::Cuda(storage) => {
let storage = storage.try_clone()?;
let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
}
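
Note that Clone's f in the CUDA backend ignores its device and layout arguments; try_clone gains the layout parameter so the clone can be routed through Map1::map, and the same signature change is threaded through Storage::try_clone here, the dummy backend above, and Tensor::copy below.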


@@ -709,7 +709,7 @@ impl Tensor {
pub fn copy(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(self.storage.try_clone()?),
storage: Arc::new(self.storage.try_clone(self.layout())?),
layout: self.layout.clone(),
op: None, // TODO
is_variable: false,