Merge pull request #33 from LaurentMazare/cuda-map

Simplify the dtype matchings in the cuda backend
This commit is contained in:
Laurent Mazare
2023-06-29 10:14:12 +01:00
committed by GitHub
6 changed files with 292 additions and 455 deletions

View File

@ -487,6 +487,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng(); let mut rng = thread_rng();
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0..args.sample_len { for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..]; let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &device)?; let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?; let logits = llama.forward(&input, &freqs_cis)?;
@ -496,6 +497,7 @@ fn main() -> Result<()> {
let next_token = distr.sample(&mut rng) as u32; let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token); tokens.push(next_token);
new_tokens.push(next_token); new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!( println!(
"{} token: {} '{}'", "{} token: {} '{}'",
index + 1, index + 1,

View File

@ -1,7 +1,9 @@
use crate::{CpuStorage, DType, Layout, Shape}; use crate::{CpuStorage, DType, Layout, Shape, WithDType};
use candle_kernels as kernels; use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig}; use cudarc::driver::{
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use half::{bf16, f16}; use half::{bf16, f16};
use std::sync::Arc; use std::sync::Arc;
@ -242,6 +244,260 @@ enum CudaStorageSlice {
F32(CudaSlice<f32>), F32(CudaSlice<f32>),
F64(CudaSlice<f64>), F64(CudaSlice<f64>),
} }
type S = CudaStorageSlice;
trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U32(s) => S::U32(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
};
Ok(out)
}
}
struct Clone;
impl Map1 for Clone {
fn f<T: DeviceRepr>(
&self,
s: &CudaSlice<T>,
_: &CudaDevice,
_: &Layout,
) -> Result<CudaSlice<T>> {
Ok(s.try_clone()?)
}
}
fn kernel_name<T: WithDType>(root: &str) -> String {
let dtype = T::DTYPE.as_str();
format!("{root}_{dtype}")
}
struct Affine(f64, f64);
impl Map1 for Affine {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }?;
let params = (
el,
dims.len(),
&ds,
src,
&out,
T::from_f64(self.0),
T::from_f64(self.1),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let src_dims = shape.dims();
let el = shape.elem_count();
let mut dst_el = el;
for &sum_dim in self.0.iter() {
dst_el /= src_dims[sum_dim];
}
let mut sum_dims = self.0.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
let sum_dims_s: Vec<usize> = sum_dims
.iter()
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
let out = dev.alloc_zeros::<T>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
impl<U: crate::op::UnaryOp> Map1 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }?;
let params = (el_count, dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map1 for Embedding<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
rhs: &CudaSlice<T>,
dev: &CudaDevice,
rhs_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.0.dtype(),
})?,
};
let ids = &ids;
let shape = ids_l.shape();
let (v_size, h_size) = rhs_l
.shape()
.r2()
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
t: &CudaSlice<T>,
layout_t: &Layout,
f: &CudaSlice<T>,
layout_f: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.0.dtype(),
})?,
};
let ids = &ids;
let shape = ids_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds =
dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
impl<U: crate::op::BinaryOp> Map2 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
lhs: &CudaSlice<T>,
lhs_l: &Layout,
rhs: &CudaSlice<T>,
rhs_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let shape = lhs_l.shape();
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
fn slice_src_and_dst<'a, T>( fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>, src: &'a CudaSlice<T>,
@ -332,14 +588,8 @@ fn gemm_config<T>(
} }
impl CudaStorage { impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> { pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
let slice = match &self.slice { let slice = Clone.map(&self.slice, self.device(), layout)?;
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
};
let device = self.device.clone(); let device = self.device.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
@ -420,152 +670,14 @@ impl CudaStorage {
} }
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> { pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let shape = layout.shape(); let device = self.device().clone();
let dims = shape.dims(); let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
bf16::from_f64(mul),
bf16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
f16::from_f64(mul),
f16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let shape = layout.shape(); let device = self.device().clone();
let src_dims = shape.dims(); let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
let el = shape.elem_count();
let mut dst_el = el;
for &sum_dim in sum_dims.iter() {
dst_el /= src_dims[sum_dim];
}
let mut sum_dims = sum_dims.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
let sum_dims_s: Vec<usize> = sum_dims
.iter()
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<u32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<bf16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f64>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
@ -576,58 +688,8 @@ impl CudaStorage {
} }
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> { pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
let shape = layout.shape(); let device = self.device().clone();
let dims = shape.dims(); let slice = U::V.map(&self.slice, &device, layout)?;
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(_arg) => {
todo!("No unary kernels for u32");
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
@ -637,72 +699,8 @@ impl CudaStorage {
lhs_l: &Layout, lhs_l: &Layout,
rhs_l: &Layout, rhs_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
let shape = lhs_l.shape(); let device = self.device().clone();
let dims = shape.dims(); let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dev = self.device();
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
@ -740,163 +738,18 @@ impl CudaStorage {
&self, &self,
layout: &Layout, layout: &Layout,
t: &Self, t: &Self,
layout_t: &Layout, t_l: &Layout,
f: &Self, f: &Self,
layout_f: &Layout, f_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
let ids = match &self.slice { let device = self.device().clone();
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let ids = &ids;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds =
dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<f64>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = match &self.slice { let device = self.device().clone();
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let ids = &ids;
let shape = layout.shape();
let (v_size, h_size) = rhs_l
.shape()
.r2()
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &rhs.slice {
// The kernels below assume that rhs is contiguous.
CudaStorageSlice::U32(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device }) Ok(Self { slice, device })
} }

View File

@ -44,7 +44,7 @@ impl CudaDevice {
pub struct CudaStorage; pub struct CudaStorage;
impl CudaStorage { impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> { pub fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }

View File

@ -43,11 +43,8 @@ pub(crate) enum Op {
pub(crate) trait UnaryOp { pub(crate) trait UnaryOp {
const NAME: &'static str; const NAME: &'static str;
const KERNEL_BF16: &'static str; const KERNEL: &'static str;
const KERNEL_F16: &'static str; const V: Self;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16) -> bf16; fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16; fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32; fn f32(v1: f32) -> f32;
@ -57,11 +54,8 @@ pub(crate) trait UnaryOp {
pub(crate) trait BinaryOp { pub(crate) trait BinaryOp {
const NAME: &'static str; const NAME: &'static str;
const KERNEL_BF16: &'static str; const KERNEL: &'static str;
const KERNEL_F16: &'static str; const V: Self;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16, v2: bf16) -> bf16; fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16; fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32; fn f32(v1: f32, v2: f32) -> f32;
@ -88,11 +82,8 @@ macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => { ($op:ident, $name: literal, $e: expr) => {
impl BinaryOp for $op { impl BinaryOp for $op {
const NAME: &'static str = $name; const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16"); const KERNEL: &'static str = concat!("b", $name);
const KERNEL_F16: &'static str = concat!("b", $name, "_f16"); const V: Self = $op;
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
fn bf16(v1: bf16, v2: bf16) -> bf16 { fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2) $e(v1, v2)
} }
@ -121,11 +112,8 @@ macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => { ($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op { impl UnaryOp for $op {
const NAME: &'static str = $name; const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16"); const KERNEL: &'static str = concat!("u", $name);
const KERNEL_F16: &'static str = concat!("u", $name, "_f16"); const V: Self = $op;
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
fn bf16($a: bf16) -> bf16 { fn bf16($a: bf16) -> bf16 {
$e $e
} }
@ -158,6 +146,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt());
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOp for Gelu { impl UnaryOp for Gelu {
const NAME: &'static str = "gelu"; const NAME: &'static str = "gelu";
const V: Self = Gelu;
fn bf16(v: bf16) -> bf16 { fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5) bf16::from_f32_const(0.5)
* v * v
@ -191,20 +180,13 @@ impl UnaryOp for Gelu {
fn u32(_: u32) -> u32 { fn u32(_: u32) -> u32 {
0 0
} }
const KERNEL_BF16: &'static str = "ugelu_bf16"; const KERNEL: &'static str = "ugelu";
const KERNEL_F16: &'static str = "ugelu_f16";
const KERNEL_F32: &'static str = "ugelu_f32";
const KERNEL_F64: &'static str = "ugelu_f64";
const KERNEL_U32: &'static str = "ugelu_u32";
} }
impl UnaryOp for Relu { impl UnaryOp for Relu {
const NAME: &'static str = "relu"; const NAME: &'static str = "relu";
const KERNEL_BF16: &'static str = "urelu_bf16"; const KERNEL: &'static str = "urelu";
const KERNEL_F16: &'static str = "urelu_f16"; const V: Self = Relu;
const KERNEL_F32: &'static str = "urelu_f32";
const KERNEL_F64: &'static str = "urelu_f64";
const KERNEL_U32: &'static str = "urelu_u32";
fn bf16(v: bf16) -> bf16 { fn bf16(v: bf16) -> bf16 {
v.max(bf16::ZERO) v.max(bf16::ZERO)
} }

View File

@ -9,11 +9,11 @@ pub enum Storage {
} }
impl Storage { impl Storage {
pub fn try_clone(&self) -> Result<Self> { pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
match self { match self {
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())), Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
Self::Cuda(storage) => { Self::Cuda(storage) => {
let storage = storage.try_clone()?; let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage)) Ok(Self::Cuda(storage))
} }
} }

View File

@ -709,7 +709,7 @@ impl Tensor {
pub fn copy(&self) -> Result<Tensor> { pub fn copy(&self) -> Result<Tensor> {
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
storage: Arc::new(self.storage.try_clone()?), storage: Arc::new(self.storage.try_clone(self.layout())?),
layout: self.layout.clone(), layout: self.layout.clone(),
op: None, // TODO op: None, // TODO
is_variable: false, is_variable: false,