use crate::backend::{BackendDevice, BackendStorage};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
    CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};

/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
    #[error(transparent)]
    Cuda(#[from] cudarc::driver::DriverError),
    #[error(transparent)]
    Compiler(#[from] cudarc::nvrtc::CompileError),
    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),
    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),
    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },
    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },
    #[error("internal error '{0}'")]
    InternalError(&'static str),
    #[error("internal error '{0}'")]
    WrappedError(Box<dyn std::error::Error + Send + Sync>),
    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },
    #[error("{cuda} when loading {module_name}")]
    Load {
        cuda: cudarc::driver::DriverError,
        module_name: String,
    },
}

impl From<CudaError> for crate::Error {
    fn from(val: CudaError) -> Self {
        crate::Error::Cuda(Box::new(val)).bt()
    }
}

/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct DeviceId(usize);

impl DeviceId {
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}

#[derive(Clone)]
pub struct CudaDevice {
    id: DeviceId,
    device: Arc<cudarc::driver::CudaDevice>,
    blas: Arc<cudarc::cublas::CudaBlas>,
    curand: Arc<Mutex<CudaRng>>,
}

impl std::fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CudaDevice({:?})", self.id)
    }
}

impl std::ops::Deref for CudaDevice {
    type Target = Arc<cudarc::driver::CudaDevice>;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

/// Small helper to convert cudarc errors into the crate-level error type.
trait WrapErr<O> {
    fn w(self) -> std::result::Result<O, crate::Error>;
}

impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
    fn w(self) -> std::result::Result<O, crate::Error> {
        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
    }
}

impl CudaDevice {
    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let slice = match dtype {
            DType::U8 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
                let params = (&data, v as u8, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
                let params = (&data, v as u32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(data)
            }
            DType::BF16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
                let params = (&data, bf16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
                let params = (&data, f16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                let params = (&data, v as f32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                let params = (&data, v, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
        if !self.has_func(module_name, module_name) {
            // Leaking the string here is a bit sad but we need a &'static str and this is only
            // done once per kernel name.
            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
            self.load_ptx(ptx.into(), module_name, &[static_module_name])
                .map_err(|cuda| CudaError::Load {
                    cuda,
                    module_name: module_name.to_string(),
                })
                .w()?;
        }
        self.get_func(module_name, module_name)
            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
            // able to only build the error value if needed.
            .ok_or(CudaError::MissingKernel {
                module_name: module_name.to_string(),
            })
            .w()
    }
}

impl BackendDevice for CudaDevice {
    type Storage = CudaStorage;

    fn new(ordinal: usize) -> Result<Self> {
        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
        Ok(Self {
            id: DeviceId::new(),
            device,
            blas: Arc::new(blas),
            curand: Arc::new(Mutex::new(CudaRng(curand))),
        })
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),
        }
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.id == rhs.id
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc_zeros::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc_zeros::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::BF16 => {
                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc_zeros::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc_zeros::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc_zeros::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let mut slice = match dtype {
            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
                dtype,
                op: "rand_uniform",
            })
            .w()?,
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        if lo != 0.0 || up != 1.0 {
            // Rescale the [0, 1) samples to [lo, up). The affine map produces a new slice so it
            // has to replace the original one.
            let layout = Layout::contiguous(shape);
            slice = Affine(up - lo, lo).map(&slice, self, &layout)?;
        }
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
                dtype,
                op: "rand_normal",
            })
            .w()?,
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand
                    .0
                    .fill_with_normal(&mut data, mean as f32, std as f32)
                    .w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                curand.0.fill_with_normal(&mut data, mean, std).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
}

#[derive(Debug)]
enum CudaStorageSlice {
    U8(CudaSlice<u8>),
    U32(CudaSlice<u32>),
    BF16(CudaSlice<bf16>),
    F16(CudaSlice<f16>),
    F32(CudaSlice<f32>),
    F64(CudaSlice<f64>),
}

type S = CudaStorageSlice;

/// Dtype-generic unary helper: implementors provide `f` for a single `CudaSlice<T>` and get a
/// `map` that dispatches over every `CudaStorageSlice` variant.
trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}

/// Same as `Map1` but for binary ops; both inputs must share the same dtype.
trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}

struct Clone;
impl Map1 for Clone {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        s: &CudaSlice<T>,
        _: &CudaDevice,
        _: &Layout,
    ) -> Result<CudaSlice<T>> {
        s.try_clone().w()
    }
}

fn kernel_name<T: WithDType>(root: &str) -> String {
    let dtype = T::DTYPE.as_str();
    format!("{root}_{dtype}")
}

struct Affine(f64, f64);
impl Map1 for Affine {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (
            el,
            dims.len(),
            &ds,
            src,
            &out,
            T::from_f64(self.0),
            T::from_f64(self.1),
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct Elu(f64);
impl Map1 for Elu {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let src_dims = shape.dims();
        let el = shape.elem_count();
        let mut dst_el = el;
        for &sum_dim in self.0.iter() {
            dst_el /= src_dims[sum_dim];
        }
        let mut sum_dims = self.0.to_vec();
        // Sort the sum_dims as they have to be processed from left to right when converting the
        // indexes.
        sum_dims.sort();
        let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
        let sum_dims_s: Vec<usize> = sum_dims
            .iter()
            .map(|&d| src_dims[d + 1..].iter().product::<usize>())
            .collect();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev
            .htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
            .w()?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
        let out = dev.alloc_zeros::<T>(dst_el).w()?;
        let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct FastSum<'a>(&'a [usize]);
impl<'a> Map1 for FastSum<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
        let src_el: usize = src_dims.iter().product();
        // Source dims and strides with the sum dims at the end.
        let mut dims = vec![];
        let mut stride = vec![];
        let mut dst_el: usize = 1;
        for (dim_idx, &d) in src_dims.iter().enumerate() {
            if !self.0.contains(&dim_idx) {
                dst_el *= d;
                dims.push(d);
                stride.push(src_stride[dim_idx]);
            }
        }
        for &dim_idx in self.0.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }
        let el_to_sum_per_block = src_el / dst_el;
        // The reduction loop requires the shared array to be properly initialized and for
        // this we want the number of threads to be a power of two.
        let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
        let cfg = LaunchConfig {
            // TODO: Maybe use grid_y if the output is too large?
            // TODO: Specialized implementation when reducing on no or all dimensions or when the
            // reduction only aggregates a small number of elements together.
            grid_dim: (dst_el as u32, 1, 1),
            block_dim: (block_dim as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        let ds = dev
            .htod_copy([dims.as_slice(), stride.as_slice()].concat())
            .w()?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
        let out = dev.alloc_zeros::<T>(dst_el).w()?;
        let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

impl<U: crate::op::UnaryOp> Map1 for U {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el_count as u32);
        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
        let params = (el_count, dims.len(), &ds, src, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct Embedding<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map1 for Embedding<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        rhs: &CudaSlice<T>,
        dev: &CudaDevice,
        rhs_l: &Layout,
    ) -> Result<CudaSlice<T>> {
        let ids_l = &self.1;
        let ids = match &self.0.slice {
            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
            _ => Err(CudaError::UnexpectedDType {
                msg: "embedding ids should be u32",
                expected: DType::U32,
                got: self.0.dtype(),
            })
            .w()?,
        };
        let ids = &ids;
        let shape = ids_l.shape();
        let (v_size, h_size) = rhs_l
            .shape()
            .r2()
            .map_err(|e| CudaError::WrappedError(Box::new(e)))
            .w()?;
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
        let rhs = &rhs.slice(rhs_l.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
        let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        inp: &CudaSlice<T>,
        inp_l: &Layout,
        k: &CudaSlice<T>,
        k_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        // Kernel shape: (c_out, c_in_k, k_size)
        // Input shape: (b_size, c_in, l_in) or (c_in, l_in)
        let p = &self.0;
        let inp = &inp.slice(inp_l.start_offset()..);
        let k = &k.slice(k_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let l_out = p.l_out();
        let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let ds = if dims.len() == 3 {
            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
        } else if dims.len() == 2 {
            [&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
        } else {
            panic!("unexpected input shape for conv1d {dims:?}")
        };
        let ds = dev.htod_copy(ds).w()?;
        let params = (el, l_out, p.stride, &ds, inp, k, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        t: &CudaSlice<T>,
        layout_t: &Layout,
        f: &CudaSlice<T>,
        layout_f: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        let ids_l = &self.1;
        let ids = match &self.0.slice {
            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
            _ => Err(CudaError::UnexpectedDType {
                msg: "where conditions should be u32",
                expected: DType::U32,
                got: self.0.dtype(),
            })
            .w()?,
        };
        let ids = &ids;
        let shape = ids_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev
            .htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
            .w()?;
        let t = &t.slice(layout_t.start_offset()..);
        let f = &f.slice(layout_f.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (el, dims.len(), &ds, ids, t, f, &out);
        // SAFETY: ffi
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

impl<U: crate::op::BinaryOp> Map2 for U {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        lhs: &CudaSlice<T>,
        lhs_l: &Layout,
        rhs: &CudaSlice<T>,
        rhs_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        let shape = lhs_l.shape();
        let dims = shape.dims();
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let dims_and_strides = dev
            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
            .w()?;
        let lhs = &lhs.slice(lhs_l.start_offset()..);
        let rhs = &rhs.slice(rhs_l.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
        let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
        // SAFETY: ffi
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

fn slice_src_and_dst<'a, T>(
    src: &'a CudaSlice<T>,
    src_l: &Layout,
    dst: &'a mut CudaSlice<T>,
    dst_offset: usize,
) -> (
    cudarc::driver::CudaView<'a, T>,
    cudarc::driver::CudaViewMut<'a, T>,
) {
    let src_offset = src_l.start_offset();
    let to_copy = dst
        .len()
        .saturating_sub(dst_offset)
        .min(src.len().saturating_sub(src_offset));
    let src = src.slice(src_offset..src_offset + to_copy);
    let dst = dst.slice_mut(dst_offset..dst_offset + to_copy);
    (src, dst)
}

#[derive(Debug)]
pub struct CudaStorage {
    slice: CudaStorageSlice,
    device: CudaDevice,
}

/// Build a strided-batched GEMM configuration for cuBLAS. cuBLAS expects column-major data
/// while candle tensors are row-major, so the operands are swapped (the rhs is passed as the
/// `a` matrix and the lhs as `b`) and `m`/`n` are exchanged, which yields the row-major result
/// directly.
fn gemm_config<T>(
    alpha: T,
    beta: T,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_l: &Layout,
    rhs_l: &Layout,
) -> Result<StridedBatchedConfig<T>> {
    // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
    use cudarc::cublas::sys::cublasOperation_t;
    let lhs_stride = lhs_l.stride();
    let rhs_stride = rhs_l.stride();
    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
    // The a tensor has dims batching, k, n (rhs)
    let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
        (n as i32, cublasOperation_t::CUBLAS_OP_N)
    } else if rhs_m1 == k && rhs_m2 == 1 {
        (k as i32, cublasOperation_t::CUBLAS_OP_T)
    } else {
        Err(CudaError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })
        .w()?
    };
    // The b tensor has dims batching, m, k (lhs)
    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
        (k as i32, cublasOperation_t::CUBLAS_OP_N)
    } else if lhs_m1 == m && lhs_m2 == 1 {
        (m as i32, cublasOperation_t::CUBLAS_OP_T)
    } else {
        Err(CudaError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })
        .w()?
    };
    // The setup below was copied from:
    // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
    let gemm = GemmConfig {
        alpha,
        beta,
        m: n as i32,
        n: m as i32,
        k: k as i32,
        lda,
        ldb,
        ldc: n as i32,
        transa,
        transb,
    };
    let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
        [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
        [stride] => stride,
        [] => m * k,
        _ => Err(CudaError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })
        .w()?,
    };
    let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
        [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
        [stride] => stride,
        [] => n * k,
        _ => Err(CudaError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })
        .w()?,
    };

    Ok(StridedBatchedConfig {
        batch_size: b as i32,
        gemm,
        stride_a: stride_a as i64,
        stride_b: stride_b as i64,
        stride_c: (m * n) as i64,
    })
}

impl BackendStorage for CudaStorage {
    type Device = CudaDevice;

    fn try_clone(&self, layout: &Layout) -> Result<Self> {
        let slice = Clone.map(&self.slice, self.device(), layout)?;
        let device = self.device.clone();
        Ok(Self { slice, device })
    }

    fn dtype(&self) -> DType {
        match self.slice {
            CudaStorageSlice::U8(_) => DType::U8,
            CudaStorageSlice::U32(_) => DType::U32,
            CudaStorageSlice::BF16(_) => DType::BF16,
            CudaStorageSlice::F16(_) => DType::F16,
            CudaStorageSlice::F32(_) => DType::F32,
            CudaStorageSlice::F64(_) => DType::F64,
        }
    }

    fn device(&self) -> &CudaDevice {
        &self.device
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        use cudarc::driver::DevicePtr;
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let dev = self.device();
        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
        let start_o = layout.start_offset();
        // This returns an i64 rather than a &i64, this is useful to get around some temporary
        // lifetime issue and is safe as long as self.slice does not go out of scope before inp
        // is used.
        let inp = match &self.slice {
            CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
            CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
        };
        let inp = &inp;
        let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
        let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
        let slice = match dtype {
            DType::U8 => {
                let out = unsafe { dev.alloc::<u8>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U8(out)
            }
            DType::U32 => {
                let out = unsafe { dev.alloc::<u32>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(out)
            }
            DType::BF16 => {
                let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::BF16(out)
            }
            DType::F16 => {
                let out = unsafe { dev.alloc::<f16>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F16(out)
            }
            DType::F32 => {
                let out = unsafe { dev.alloc::<f32>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F32(out)
            }
            DType::F64 => {
                let out = unsafe { dev.alloc::<f64>(el) }.w()?;
                let params = (el, dims.len(), &ds, *inp, &out);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F64(out)
            }
        };
        Ok(Self {
            slice,
            device: dev.clone(),
        })
    }

    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        let device = self.device().clone();
        let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        let device = self.device().clone();
        let slice = Elu(alpha).map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

    fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        let device = self.device().clone();
        let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
        Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
    }

    fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
        let device = self.device().clone();
        let slice = U::V.map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

    fn binary_impl<B: crate::op::BinaryOp>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let device = self.device().clone();
        let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
        Ok(Self { slice, device })
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        match &self.slice {
            CudaStorageSlice::U8(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::U8(cpu_storage))
            }
            CudaStorageSlice::U32(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::U32(cpu_storage))
            }
            CudaStorageSlice::BF16(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::BF16(cpu_storage))
            }
            CudaStorageSlice::F16(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::F16(cpu_storage))
            }
            CudaStorageSlice::F32(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::F32(cpu_storage))
            }
            CudaStorageSlice::F64(slice) => {
                let dev = slice.device();
                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                Ok(CpuStorage::F64(cpu_storage))
            }
        }
    }

    fn where_cond(
        &self,
        layout: &Layout,
        t: &Self,
        t_l: &Layout,
        f: &Self,
        f_l: &Layout,
    ) -> Result<Self> {
        let device = self.device().clone();
        let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
        Ok(Self { slice, device })
    }

    fn conv1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        let device = self.device().clone();
        let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
        Ok(Self { slice, device })
    }

    fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
        let device = self.device().clone();
        let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
        Ok(Self { slice, device })
    }

    fn matmul(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let elem_count = b * m * n;
        let dev = &self.device;
        let slice = match (&self.slice, &rhs.slice) {
            (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
                let lhs = &lhs.slice(lhs_l.start_offset()..);
                let rhs = &rhs.slice(rhs_l.start_offset()..);
                let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
                let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
                unsafe {
                    self.device
                        .blas
                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
                }
                .w()?;
                CudaStorageSlice::BF16(out)
            }
            (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
                let lhs = &lhs.slice(lhs_l.start_offset()..);
                let rhs = &rhs.slice(rhs_l.start_offset()..);
                let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
                let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
                unsafe {
                    self.device
                        .blas
                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
                }
                .w()?;
                CudaStorageSlice::F16(out)
            }
            (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
                let lhs = &lhs.slice(lhs_l.start_offset()..);
                let rhs = &rhs.slice(rhs_l.start_offset()..);
                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
                unsafe {
                    self.device
                        .blas
                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
                }
                .w()?;
                CudaStorageSlice::F32(out)
            }
            (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
                let lhs = &lhs.slice(lhs_l.start_offset()..);
                let rhs = &rhs.slice(rhs_l.start_offset()..);
                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
                unsafe {
                    self.device
                        .blas
                        .gemm_strided_batched(cfg, rhs, lhs, &mut out)
                }
                .w()?;
                CudaStorageSlice::F64(out)
            }
            _ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
        };
        let device = dev.clone();
        Ok(Self { slice, device })
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        let src_shape = src_l.shape();
        let dims = src_shape.dims();
        let el_count = src_shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el_count as u32);
        let dev = &self.device;
        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
        match (&self.slice, &mut dst.slice) {
            (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?
                }
            }
            (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                if src_l.is_contiguous() {
                    dev.dtod_copy(&src, &mut dst).w()?
                } else {
                    let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
                    // SAFETY: Set later by running the kernel.
                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
                    // SAFETY: ffi.
                    unsafe { func.launch(cfg, params) }.w()?;
                }
            }
            _ => Err(CudaError::InternalError(
                "dtype mismatch in copy_strided op",
            ))
            .w()?,
        }
        Ok(())
    }
}
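
// The test module below is an illustrative, hand-written sketch rather than part of the
// backend itself. It assumes a CUDA-capable device with ordinal 0 and that `Shape` can be
// built from a `(usize, usize)` tuple as it is elsewhere in the crate; the test is ignored by
// default so it only runs when a GPU is available. It exercises the round trip provided by the
// backend traits: allocate zeros on the device, copy them back and check the host contents.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a CUDA device"]
    fn zeros_round_trip() -> Result<()> {
        // Create the device (ordinal 0) and a small 2x3 f32 storage filled with zeros.
        let dev = CudaDevice::new(0)?;
        let shape = Shape::from((2usize, 3usize));
        let storage = dev.zeros_impl(&shape, DType::F32)?;
        // Copy back to the host and check that all six elements are zero.
        match storage.to_cpu_storage()? {
            CpuStorage::F32(data) => assert_eq!(data, vec![0f32; 6]),
            _ => panic!("unexpected dtype for zeros_impl output"),
        }
        Ok(())
    }
}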