/// Helper functions to plug cuda kernels in candle. use crate::{Layout, Result, WithDType}; pub use cudarc; use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits}; use super::{CudaDevice, CudaError, WrapErr}; pub type S = super::CudaStorageSlice; pub trait Map1 { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result>; fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), }; Ok(out) } } pub trait Map2 { fn f( &self, src1: &CudaSlice, layout1: &Layout, src2: &CudaSlice, layout2: &Layout, dev: &CudaDevice, ) -> Result>; fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { let out = match (s1, s2) { (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) } } pub trait Map3 { #[allow(clippy::too_many_arguments)] fn f( &self, src1: &CudaSlice, layout1: &Layout, src2: &CudaSlice, layout2: &Layout, src3: &CudaSlice, layout3: &Layout, dev: &CudaDevice, ) -> Result>; #[allow(clippy::too_many_arguments)] fn map( &self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, s3: &S, l3: &Layout, d: &CudaDevice, ) -> Result { let out = match (s1, s2, s3) { (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, }; Ok(out) } } pub trait Map2InPlace { fn f( &self, dst: &mut CudaSlice, dst_l: &Layout, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, ) -> Result<()>; fn map( &self, dst: &mut S, dst_l: &Layout, src: &S, src_l: &Layout, d: &CudaDevice, ) -> Result<()> { match (dst, src) { (S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d), (S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d), (S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d), (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d), (S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } } pub trait Map1Any { fn f) -> S>( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, wrap: W, ) -> Result; fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, }; Ok(out) } } pub trait Map2Any { fn f( &self, src1: &CudaSlice, layout1: &Layout, src2: &CudaSlice, layout2: &Layout, dev: &CudaDevice, ) -> Result; fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { let out = match (s1, s2) { (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?, (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?, (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, }; Ok(out) } }