From a0d65585db0323747f71b4f33831a165b56a759b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 5 Sep 2023 19:38:03 +0200
Subject: [PATCH] Softmax implementation for cuda. (#747)

---
 candle-core/src/cuda_backend.rs | 20 +++++------
 candle-nn/src/ops.rs            | 59 ++++++++++++++++++++++++++++-----
 2 files changed, 61 insertions(+), 18 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 663f2319..2180be5e 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
-use candle_kernels as kernels;
+pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
@@ -383,7 +383,7 @@ impl BackendDevice for CudaDevice {
 }

 #[derive(Debug)]
-enum CudaStorageSlice {
+pub enum CudaStorageSlice {
     U8(CudaSlice<u8>),
     U32(CudaSlice<u32>),
     I64(CudaSlice<i64>),
@@ -394,7 +394,7 @@ enum CudaStorageSlice {
 }
 type S = CudaStorageSlice;

-trait Map1 {
+pub trait Map1 {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
@@ -416,7 +416,7 @@ trait Map1 {
     }
 }

-trait Map2 {
+pub trait Map2 {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src1: &CudaSlice<T>,
@@ -441,7 +441,7 @@ trait Map2 {
     }
 }

-trait Map2InPlace {
+pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
@@ -472,7 +472,7 @@ trait Map2InPlace {
     }
 }

-trait Map1Any {
+pub trait Map1Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
         &self,
         src: &CudaSlice<T>,
@@ -495,7 +495,7 @@ trait Map1Any {
     }
 }

-trait Map2Any {
+pub trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src1: &CudaSlice<T>,
@@ -532,7 +532,7 @@ impl Map1 for Clone {
     }
 }

-fn kernel_name<T: WithDType>(root: &str) -> String {
+pub fn kernel_name<T: WithDType>(root: &str) -> String {
     let dtype = T::DTYPE.as_str();
     format!("{root}_{dtype}")
 }
@@ -1310,8 +1310,8 @@ fn slice_src_and_dst<'a, T>(

 #[derive(Debug)]
 pub struct CudaStorage {
-    slice: CudaStorageSlice,
-    device: CudaDevice,
+    pub slice: CudaStorageSlice,
+    pub device: CudaDevice,
 }

 pub trait CudaDType: Sized {
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 73214077..adf1451c 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -126,19 +126,62 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         }
     }

+    #[cfg(feature = "cuda")]
     fn cuda_fwd(
         &self,
-        _storage: &candle::CudaStorage,
-        _layout: &Layout,
+        storage: &candle::CudaStorage,
+        layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        candle::bail!("TODO: implement a cuda kernel")
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S;
+        impl Map1 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                src: &CudaSlice<T>,
+                dev: &CudaDevice,
+                layout: &Layout,
+            ) -> Result<CudaSlice<T>> {
+                let src = match layout.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let el = layout.shape().elem_count();
+                let dims = layout.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+                let cfg = LaunchConfig {
+                    grid_dim: (n_rows as u32, 1, 1),
+                    block_dim: (1, 32, 1),
+                    shared_mem_bytes: 0,
+                };
+                let src = &src.slice(layout.start_offset()..);
+                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (src, &dst, n_cols as i32);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(dst)
+            }
+        }
+
+        use candle::backend::BackendStorage;
+        let dev = storage.device();
+        let slice = S.map(&storage.slice, dev, layout)?;
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, layout.shape().clone()))
     }
 }

 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
-    if xs.device().is_cpu() {
-        xs.apply_op1_no_bwd(&SoftmaxLastDim)
-    } else {
-        softmax(xs, candle::D::Minus1)
-    }
+    xs.apply_op1_no_bwd(&SoftmaxLastDim)
 }
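
Usage sketch (illustrative, not part of the upstream patch): after this change,
softmax_last_dim always dispatches through the SoftmaxLastDim custom op, so on a
CUDA tensor it runs the "softmax" kernel from candle-kernels' REDUCE module,
launching one block of 32 threads per row of the last dimension. The sketch below
assumes the candle-core and candle-nn crates are built with the "cuda" feature and
that a CUDA device is visible as ordinal 0:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Assumption: a CUDA-capable GPU is available as device 0.
        let device = Device::new_cuda(0)?;
        // 4 rows of 128 logits; the kernel reduces over the last dimension,
        // one block per row as set up by the LaunchConfig in this patch.
        let xs = Tensor::randn(0f32, 1.0, (4, 128), &device)?;
        let ys = candle_nn::ops::softmax_last_dim(&xs)?;
        // Each row of `ys` is a probability distribution summing to 1.
        println!("{ys}");
        Ok(())
    }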