From 122e334d0cf9c6b56adc2f6f287617141841f636 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 29 Jun 2023 09:21:11 +0100
Subject: [PATCH] Simplify the pattern matching logic in the cuda backend.

---
 candle-core/examples/llama/main.rs    |   2 +
 candle-core/src/cuda_backend.rs       | 157 ++++++++++++--------
 candle-core/src/dummy_cuda_backend.rs |   2 +-
 candle-core/src/storage.rs            |   4 +-
 candle-core/src/tensor.rs             |   2 +-
 5 files changed, 78 insertions(+), 89 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index eb681f4b..3fc893e3 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -487,6 +487,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
+        let start_gen = std::time::Instant::now();
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
         let input = Tensor::new(ctxt, &device)?;
         let logits = llama.forward(&input, &freqs_cis)?;
@@ -496,6 +497,7 @@ fn main() -> Result<()> {
         let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
+        println!("> {:?}", start_gen.elapsed());
         println!(
             "{} token: {} '{}'",
             index + 1,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9d9a5f99..7dfbb468 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,7 @@
-use crate::{CpuStorage, DType, Layout, Shape};
+use crate::{CpuStorage, DType, Layout, Shape, WithDType};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
-use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig};
 use half::{bf16, f16};
 use std::sync::Arc;
 
@@ -243,6 +243,72 @@ enum CudaStorageSlice {
     F64(CudaSlice<f64>),
 }
 
+trait Map1 {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>>;
+
+    fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result<CudaStorageSlice> {
+        let out = match s {
+            CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?),
+            CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?),
+            CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?),
+            CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?),
+            CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?),
+        };
+        Ok(out)
+    }
+}
+
+struct Clone;
+impl Map1 for Clone {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        s: &CudaSlice<T>,
+        _: &CudaDevice,
+        _: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        Ok(s.try_clone()?)
+    }
+}
+
+struct Affine(f64, f64);
+
+impl Map1 for Affine {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let src = &src.slice(layout.start_offset()..);
+        let kernel_name = format!("affine_{}", T::DTYPE.as_str());
+        let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el) }?;
+        let params = (
+            el,
+            dims.len(),
+            &ds,
+            src,
+            &out,
+            T::from_f64(self.0),
+            T::from_f64(self.1),
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 fn slice_src_and_dst<'a, T>(
     src: &'a CudaSlice<T>,
     src_l: &Layout,
@@ -332,14 +398,8 @@ fn gemm_config<T>(
 }
 
 impl CudaStorage {
-    pub fn try_clone(&self) -> Result<Self> {
-        let slice = match &self.slice {
-            CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
-            CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
-            CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
-            CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
-            CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
-        };
+    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
+        let slice = Clone.map(&self.slice, self.device(), layout)?;
         let device = self.device.clone();
         Ok(Self { slice, device })
     }
@@ -420,81 +480,8 @@ impl CudaStorage {
     }
 
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
-        let shape = layout.shape();
-        let dims = shape.dims();
-        let el_count = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
-        let slice = match &self.slice {
-            CudaStorageSlice::U32(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<u32>(el_count) }?;
-                let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::U32(out)
-            }
-            CudaStorageSlice::BF16(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<bf16>(el_count) }?;
-                let params = (
-                    el_count,
-                    dims.len(),
-                    &ds,
-                    arg,
-                    &out,
-                    bf16::from_f64(mul),
-                    bf16::from_f64(add),
-                );
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::BF16(out)
-            }
-            CudaStorageSlice::F16(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f16>(el_count) }?;
-                let params = (
-                    el_count,
-                    dims.len(),
-                    &ds,
-                    arg,
-                    &out,
-                    f16::from_f64(mul),
-                    f16::from_f64(add),
-                );
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F16(out)
-            }
-            CudaStorageSlice::F32(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f32>(el_count) }?;
-                let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F32(out)
-            }
-            CudaStorageSlice::F64(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f64>(el_count) }?;
-                let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F64(out)
-            }
-        };
-        let device = dev.clone();
+        let device = self.device().clone();
+        let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 8193b1af..b025eeab 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -44,7 +44,7 @@ impl CudaDevice {
 pub struct CudaStorage;
 
 impl CudaStorage {
-    pub fn try_clone(&self) -> Result<Self> {
+    pub fn try_clone(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 7acf6dd0..4e630a58 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -9,11 +9,11 @@ pub enum Storage {
 }
 
 impl Storage {
-    pub fn try_clone(&self) -> Result<Self> {
+    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
             Self::Cuda(storage) => {
-                let storage = storage.try_clone()?;
+                let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
         }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f64bd6f2..4b9b3306 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -709,7 +709,7 @@ impl Tensor {
     pub fn copy(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: Arc::new(self.storage.try_clone()?),
+            storage: Arc::new(self.storage.try_clone(self.layout())?),
             layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
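
Note (not part of the patch above): the Map1 trait introduced in cuda_backend.rs factors the per-dtype match out of each element-wise method, so a new unary op only has to provide the generic launch in f and gets dtype dispatch from map for free. As a rough illustration, a hypothetical Scale(f64) op could mirror Affine; the "scale_*" kernel names and the kernels::UNARY module are assumptions made for this sketch, not kernels this patch adds.

// Hypothetical sketch only: a multiply-by-constant op reusing the Map1 trait.
// The "scale_*" kernel names and kernels::UNARY are illustrative assumptions.
struct Scale(f64);

impl Map1 for Scale {
    fn f<T: DeviceRepr + WithDType>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        // Ship shape and stride metadata to the device, exactly as Affine does.
        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
        let src = &src.slice(layout.start_offset()..);
        let kernel_name = format!("scale_{}", T::DTYPE.as_str());
        let func = dev.get_or_load_func(&kernel_name, kernels::UNARY)?;
        // SAFETY: every element is written by the kernel before being read.
        let out = unsafe { dev.alloc::<T>(el) }?;
        let params = (el, dims.len(), &ds, src, &out, T::from_f64(self.0));
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }?;
        Ok(out)
    }
}

A caller would then get full dtype dispatch in one line, e.g. let slice = Scale(2.0).map(&self.slice, &device, layout)?;, which is the same pattern try_clone and affine now use.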