diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs index eb681f4b..3fc893e3 100644 --- a/candle-core/examples/llama/main.rs +++ b/candle-core/examples/llama/main.rs @@ -487,6 +487,7 @@ fn main() -> Result<()> { let mut rng = thread_rng(); let start_gen = std::time::Instant::now(); for index in 0..args.sample_len { + let start_gen = std::time::Instant::now(); let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..]; let input = Tensor::new(ctxt, &device)?; let logits = llama.forward(&input, &freqs_cis)?; @@ -496,6 +497,7 @@ fn main() -> Result<()> { let next_token = distr.sample(&mut rng) as u32; tokens.push(next_token); new_tokens.push(next_token); + println!("> {:?}", start_gen.elapsed()); println!( "{} token: {} '{}'", index + 1, diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9d9a5f99..40b7e67f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,9 @@ -use crate::{CpuStorage, DType, Layout, Shape}; +use crate::{CpuStorage, DType, Layout, Shape, WithDType}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; -use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig}; +use cudarc::driver::{ + CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, +}; use half::{bf16, f16}; use std::sync::Arc; @@ -242,6 +244,260 @@ enum CudaStorageSlice { F32(CudaSlice), F64(CudaSlice), } +type S = CudaStorageSlice; + +trait Map1 { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result>; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U32(s) => S::U32(self.f(s, d, l)?), + S::BF16(s) => S::BF16(self.f(s, d, l)?), + S::F16(s) => S::F16(self.f(s, d, l)?), + S::F32(s) => S::F32(self.f(s, d, l)?), + S::F64(s) => S::F64(self.f(s, d, l)?), + }; + Ok(out) + } +} + +trait Map2 { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result>; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), + (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), + (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), + (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), + }; + Ok(out) + } +} + +struct Clone; +impl Map1 for Clone { + fn f( + &self, + s: &CudaSlice, + _: &CudaDevice, + _: &Layout, + ) -> Result> { + Ok(s.try_clone()?) + } +} + +fn kernel_name(root: &str) -> String { + let dtype = T::DTYPE.as_str(); + format!("{root}_{dtype}") +} + +struct Affine(f64, f64); +impl Map1 for Affine { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("affine"), kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = ( + el, + dims.len(), + &ds, + src, + &out, + T::from_f64(self.0), + T::from_f64(self.1), + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + +struct Sum<'a>(&'a [usize]); +impl<'a> Map1 for Sum<'a> { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let src_dims = shape.dims(); + let el = shape.elem_count(); + let mut dst_el = el; + for &sum_dim in self.0.iter() { + dst_el /= src_dims[sum_dim]; + } + let mut sum_dims = self.0.to_vec(); + // Sort the sum_dims as they have to be processed from left to right when converting the + // indexes. + sum_dims.sort(); + let sum_dims_l: Vec = sum_dims.iter().map(|&d| src_dims[d]).collect(); + let sum_dims_s: Vec = sum_dims + .iter() + .map(|&d| src_dims[d + 1..].iter().product::()) + .collect(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("sum"), kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + +impl Map1 for U { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + +struct Embedding<'a>(&'a CudaStorage, &'a Layout); +impl<'a> Map1 for Embedding<'a> { + fn f( + &self, + rhs: &CudaSlice, + dev: &CudaDevice, + rhs_l: &Layout, + ) -> Result> { + let ids_l = &self.1; + let ids = match &self.0.slice { + CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..), + _ => Err(CudaError::UnexpectedDType { + msg: "embedding ids should be u32", + expected: DType::U32, + got: self.0.dtype(), + })?, + }; + let ids = &ids; + let shape = ids_l.shape(); + let (v_size, h_size) = rhs_l + .shape() + .r2() + .map_err(|e| CudaError::WrappedError(Box::new(e)))?; + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([dims, ids_l.stride()].concat())?; + let rhs = &rhs.slice(rhs_l.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("emb"), kernels::EMBEDDINGS)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el * h_size) }?; + let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + +struct WhereCond<'a>(&'a CudaStorage, &'a Layout); +impl<'a> Map2 for WhereCond<'a> { + fn f( + &self, + t: &CudaSlice, + layout_t: &Layout, + f: &CudaSlice, + layout_f: &Layout, + dev: &CudaDevice, + ) -> Result> { + let ids_l = &self.1; + let ids = match &self.0.slice { + CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..), + _ => Err(CudaError::UnexpectedDType { + msg: "where conditions should be u32", + expected: DType::U32, + got: self.0.dtype(), + })?, + }; + let ids = &ids; + let shape = ids_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = + dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?; + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("where"), kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + +impl Map2 for U { + fn f( + &self, + lhs: &CudaSlice, + lhs_l: &Layout, + rhs: &CudaSlice, + rhs_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + let shape = lhs_l.shape(); + let dims = shape.dims(); + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, @@ -332,14 +588,8 @@ fn gemm_config( } impl CudaStorage { - pub fn try_clone(&self) -> Result { - let slice = match &self.slice { - CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?), - CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?), - CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?), - CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?), - CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?), - }; + pub fn try_clone(&self, layout: &Layout) -> Result { + let slice = Clone.map(&self.slice, self.device(), layout)?; let device = self.device.clone(); Ok(Self { slice, device }) } @@ -420,152 +670,14 @@ impl CudaStorage { } pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { - let shape = layout.shape(); - let dims = shape.dims(); - let el_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el_count as u32); - let dev = self.device(); - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = ( - el_count, - dims.len(), - &ds, - arg, - &out, - bf16::from_f64(mul), - bf16::from_f64(add), - ); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = ( - el_count, - dims.len(), - &ds, - arg, - &out, - f16::from_f64(mul), - f16::from_f64(add), - ); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out, mul, add); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Affine(mul, add).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { - let shape = layout.shape(); - let src_dims = shape.dims(); - let el = shape.elem_count(); - let mut dst_el = el; - for &sum_dim in sum_dims.iter() { - dst_el /= src_dims[sum_dim]; - } - let mut sum_dims = sum_dims.to_vec(); - // Sort the sum_dims as they have to be processed from left to right when converting the - // indexes. - sum_dims.sort(); - let sum_dims_l: Vec = sum_dims.iter().map(|&d| src_dims[d]).collect(); - let sum_dims_s: Vec = sum_dims - .iter() - .map(|&d| src_dims[d + 1..].iter().product::()) - .collect(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Sum(sum_dims).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } @@ -576,58 +688,8 @@ impl CudaStorage { } pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { - let shape = layout.shape(); - let dims = shape.dims(); - let el_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el_count as u32); - let dev = &self.device; - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(_arg) => { - todo!("No unary kernels for u32"); - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = U::V.map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } @@ -637,72 +699,8 @@ impl CudaStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let shape = lhs_l.shape(); - let dims = shape.dims(); - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dev = self.device(); - let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; - let slice = match (&self.slice, &rhs.slice) { - (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - // The dtypes should have been checked at this point so this is an internal error. - _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; Ok(Self { slice, device }) } @@ -740,163 +738,18 @@ impl CudaStorage { &self, layout: &Layout, t: &Self, - layout_t: &Layout, + t_l: &Layout, f: &Self, - layout_f: &Layout, + f_l: &Layout, ) -> Result { - let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), - _ => Err(CudaError::UnexpectedDType { - msg: "where conditions should be u32", - expected: DType::U32, - got: self.dtype(), - })?, - }; - let ids = &ids; - let shape = layout.shape(); - let dims = shape.dims(); - let el = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = - dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?; - let slice = match (&t.slice, &f.slice) { - (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?; - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?; - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - // The dtypes should have been checked at this point so this is an internal error. - _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?; Ok(Self { slice, device }) } pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { - let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), - _ => Err(CudaError::UnexpectedDType { - msg: "embedding ids should be u32", - expected: DType::U32, - got: self.dtype(), - })?, - }; - let ids = &ids; - let shape = layout.shape(); - let (v_size, h_size) = rhs_l - .shape() - .r2() - .map_err(|e| CudaError::WrappedError(Box::new(e)))?; - let dims = shape.dims(); - let el = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &rhs.slice { - // The kernels below assume that rhs is contiguous. - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?; Ok(Self { slice, device }) } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 8193b1af..b025eeab 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -44,7 +44,7 @@ impl CudaDevice { pub struct CudaStorage; impl CudaStorage { - pub fn try_clone(&self) -> Result { + pub fn try_clone(&self, _: &Layout) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 7b0e18fe..db6ef87f 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -43,11 +43,8 @@ pub(crate) enum Op { pub(crate) trait UnaryOp { const NAME: &'static str; - const KERNEL_BF16: &'static str; - const KERNEL_F16: &'static str; - const KERNEL_F32: &'static str; - const KERNEL_F64: &'static str; - const KERNEL_U32: &'static str; + const KERNEL: &'static str; + const V: Self; fn bf16(v1: bf16) -> bf16; fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; @@ -57,11 +54,8 @@ pub(crate) trait UnaryOp { pub(crate) trait BinaryOp { const NAME: &'static str; - const KERNEL_BF16: &'static str; - const KERNEL_F16: &'static str; - const KERNEL_F32: &'static str; - const KERNEL_F64: &'static str; - const KERNEL_U32: &'static str; + const KERNEL: &'static str; + const V: Self; fn bf16(v1: bf16, v2: bf16) -> bf16; fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; @@ -88,11 +82,8 @@ macro_rules! bin_op { ($op:ident, $name: literal, $e: expr) => { impl BinaryOp for $op { const NAME: &'static str = $name; - const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16"); - const KERNEL_F16: &'static str = concat!("b", $name, "_f16"); - const KERNEL_F32: &'static str = concat!("b", $name, "_f32"); - const KERNEL_F64: &'static str = concat!("b", $name, "_f64"); - const KERNEL_U32: &'static str = concat!("b", $name, "_u32"); + const KERNEL: &'static str = concat!("b", $name); + const V: Self = $op; fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) } @@ -121,11 +112,8 @@ macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { impl UnaryOp for $op { const NAME: &'static str = $name; - const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16"); - const KERNEL_F16: &'static str = concat!("u", $name, "_f16"); - const KERNEL_F32: &'static str = concat!("u", $name, "_f32"); - const KERNEL_F64: &'static str = concat!("u", $name, "_f64"); - const KERNEL_U32: &'static str = concat!("u", $name, "_u32"); + const KERNEL: &'static str = concat!("u", $name); + const V: Self = $op; fn bf16($a: bf16) -> bf16 { $e } @@ -158,6 +146,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt()); /// impl UnaryOp for Gelu { const NAME: &'static str = "gelu"; + const V: Self = Gelu; fn bf16(v: bf16) -> bf16 { bf16::from_f32_const(0.5) * v @@ -191,20 +180,13 @@ impl UnaryOp for Gelu { fn u32(_: u32) -> u32 { 0 } - const KERNEL_BF16: &'static str = "ugelu_bf16"; - const KERNEL_F16: &'static str = "ugelu_f16"; - const KERNEL_F32: &'static str = "ugelu_f32"; - const KERNEL_F64: &'static str = "ugelu_f64"; - const KERNEL_U32: &'static str = "ugelu_u32"; + const KERNEL: &'static str = "ugelu"; } impl UnaryOp for Relu { const NAME: &'static str = "relu"; - const KERNEL_BF16: &'static str = "urelu_bf16"; - const KERNEL_F16: &'static str = "urelu_f16"; - const KERNEL_F32: &'static str = "urelu_f32"; - const KERNEL_F64: &'static str = "urelu_f64"; - const KERNEL_U32: &'static str = "urelu_u32"; + const KERNEL: &'static str = "urelu"; + const V: Self = Relu; fn bf16(v: bf16) -> bf16 { v.max(bf16::ZERO) } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 7acf6dd0..4e630a58 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -9,11 +9,11 @@ pub enum Storage { } impl Storage { - pub fn try_clone(&self) -> Result { + pub fn try_clone(&self, layout: &Layout) -> Result { match self { Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())), Self::Cuda(storage) => { - let storage = storage.try_clone()?; + let storage = storage.try_clone(layout)?; Ok(Self::Cuda(storage)) } } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f64bd6f2..4b9b3306 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -709,7 +709,7 @@ impl Tensor { pub fn copy(&self) -> Result { let tensor_ = Tensor_ { id: TensorId::new(), - storage: Arc::new(self.storage.try_clone()?), + storage: Arc::new(self.storage.try_clone(self.layout())?), layout: self.layout.clone(), op: None, // TODO is_variable: false,