Softmax implementation for cuda. (#747)

Author: Laurent Mazare
Date: 2023-09-05 19:38:03 +02:00
Committed by: GitHub
Parent: 94c6a8d3d3
Commit: a0d65585db
2 changed files with 61 additions and 18 deletions


@@ -126,19 +126,62 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         }
     }
 
     #[cfg(feature = "cuda")]
     fn cuda_fwd(
         &self,
-        _storage: &candle::CudaStorage,
-        _layout: &Layout,
+        storage: &candle::CudaStorage,
+        layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        candle::bail!("TODO: implement a cuda kernel")
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S;
+        impl Map1 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                src: &CudaSlice<T>,
+                dev: &CudaDevice,
+                layout: &Layout,
+            ) -> Result<CudaSlice<T>> {
+                let src = match layout.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let el = layout.shape().elem_count();
+                let dims = layout.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+                let cfg = LaunchConfig {
+                    grid_dim: (n_rows as u32, 1, 1),
+                    block_dim: (1, 32, 1),
+                    shared_mem_bytes: 0,
+                };
+                let src = &src.slice(layout.start_offset()..);
+                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (src, &dst, n_cols as i32);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(dst)
+            }
+        }
+
+        use candle::backend::BackendStorage;
+        let dev = storage.device();
+        let slice = S.map(&storage.slice, dev, layout)?;
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, layout.shape().clone()))
     }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
-    if xs.device().is_cpu() {
-        xs.apply_op1_no_bwd(&SoftmaxLastDim)
-    } else {
-        softmax(xs, candle::D::Minus1)
-    }
+    xs.apply_op1_no_bwd(&SoftmaxLastDim)
 }
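
With this change, softmax_last_dim always dispatches to the SoftmaxLastDim custom op: the previous branch that fell back to softmax(xs, candle::D::Minus1) on non-CPU devices is gone, since the op now has a CUDA forward (the "softmax" kernel from the REDUCE kernel set) alongside the CPU one. As a minimal usage sketch, not part of this commit, assuming the candle-core and candle-nn crates as published and a build with the "cuda" feature when a GPU is present:

    use candle_core::{Device, Result, Tensor};
    use candle_nn::ops::softmax_last_dim;

    fn main() -> Result<()> {
        // Picks CUDA device 0 when available, otherwise falls back to the CPU.
        let device = Device::cuda_if_available(0)?;
        // Two rows of three logits; softmax is applied over the last dimension.
        let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
        let probs = softmax_last_dim(&xs)?;
        println!("{probs}"); // each row now sums to 1
        Ok(())
    }

On the CUDA path, the launch configuration flattens the input into n_rows = elem_count / last_dim rows of n_cols = last_dim elements each, then launches one block per row with 32 threads along the y dimension, and the kernel receives the column count as its third parameter.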