Indexing cuda (#235)

* Allow using uint8_t for indexing (see the usage sketch below).

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
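
For context, here is a minimal usage sketch of the two operations this commit wires up on CUDA. It is illustrative only and not code from the commit; it assumes the candle_core crate name and the Tensor::new / Device::new_cuda / index_select / gather APIs of this era.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]], &dev)?;

    // u8 index tensors are now accepted on CUDA alongside u32 ones.
    let ids = Tensor::new(&[2u8, 0], &dev)?;
    let rows = src.index_select(&ids, 0)?; // picks rows 2 and 0 -> shape (2, 3)
    println!("{:?}", rows.to_vec2::<f32>()?);

    // gather along dim 1: out[i][j] = src[i][ids[i][j]]
    let ids = Tensor::new(&[[0u8, 2], [1, 1], [2, 0]], &dev)?;
    let picked = src.gather(&ids, 1)?; // shape (3, 2)
    println!("{:?}", picked.to_vec2::<f32>()?);
    Ok(())
}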
Author: Laurent Mazare
Date: 2023-07-24 20:22:47 +01:00
Committed by: GitHub
Parent: b50f932e7c
Commit: 581b104f97
3 changed files with 277 additions and 31 deletions

@@ -5,7 +5,8 @@ use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
@@ -34,9 +35,6 @@ pub enum CudaError {
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("internal error '{0}'")]
WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
@@ -632,28 +630,28 @@ impl<'a> Map1 for Embedding<'a> {
rhs_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
let (name, ids) = match &self.0.slice {
CudaStorageSlice::U32(slice) => {
("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::U8(slice) => {
("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
msg: "embedding ids should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
.w()?,
};
let ids = &ids;
let shape = ids_l.shape();
let (v_size, h_size) = rhs_l
.shape()
.dims2()
.map_err(|e| CudaError::WrappedError(Box::new(e)))
.w()?;
let (v_size, h_size) = rhs_l.shape().dims2()?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
@@ -663,6 +661,109 @@ impl<'a> Map1 for Embedding<'a> {
}
}
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for IndexSelect<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
src_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let (name, ids) = match &self.0.slice {
CudaStorageSlice::U32(slice) => {
("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
.w()?,
};
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ids_el = ids_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let dim_size = src_l.dims()[self.2];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let params = (
ids_el,
ids_dims.len(),
&ds,
ids,
&src,
&out,
left_size,
dim_size,
right_size,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for Gather<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
src_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let el = ids_l.shape().elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let ids_dim_sz = ids_l.dims()[dim];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
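
Both IndexSelect and Gather above view the contiguous source as a (left_size, dim_size, right_size) block around the indexed dimension. The following CPU reference sketch of that arithmetic for a row-major tensor is illustrative only and is not part of this commit or of candle-kernels.

// Reference index-select: pick `ids` entries along `dim` of a contiguous tensor.
fn index_select_ref<T: Copy>(src: &[T], dims: &[usize], dim: usize, ids: &[usize]) -> Vec<T> {
    let left_size: usize = dims[..dim].iter().product();
    let dim_size = dims[dim];
    let right_size: usize = dims[dim + 1..].iter().product();
    assert_eq!(src.len(), left_size * dim_size * right_size);
    let mut out = Vec::with_capacity(left_size * ids.len() * right_size);
    for l in 0..left_size {
        for &id in ids {
            // Each selected index contributes a contiguous chunk of `right_size` elements.
            let start = (l * dim_size + id) * right_size;
            out.extend_from_slice(&src[start..start + right_size]);
        }
    }
    out
}

fn main() {
    // A 3x3 matrix stored row-major: rows [0,1,2], [3,4,5], [6,7,8].
    let src: Vec<i32> = (0..9).collect();
    // Select rows 2 and 0 (dim = 0).
    assert_eq!(index_select_ref(&src, &[3, 3], 0, &[2, 0]), vec![6, 7, 8, 0, 1, 2]);
    // Select column 1 twice (dim = 1).
    assert_eq!(index_select_ref(&src, &[3, 3], 1, &[1, 1]), vec![1, 1, 4, 4, 7, 7]);
}

This decomposition matches the allocations in the diff: index-select allocates ids_el * left_size * right_size output elements, while gather allocates exactly as many elements as the index tensor holds.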
@@ -991,7 +1092,6 @@ impl BackendStorage for CudaStorage {
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
use cudarc::driver::DevicePtr;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
@@ -1169,11 +1269,15 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-select").into())
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement gather").into())
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn scatter_add(
&self,