mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Indexing cuda (#235)
* Allow using uint8_t for indexing. * Revert the default cuda feature. * Add a cuda-kernel for index-select. * Add a test for gather.
This commit is contained in:
@ -5,7 +5,8 @@ use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{
|
||||
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
|
||||
ValidAsZeroBits,
|
||||
};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -34,9 +35,6 @@ pub enum CudaError {
|
||||
#[error("internal error '{0}'")]
|
||||
InternalError(&'static str),
|
||||
|
||||
#[error("internal error '{0}'")]
|
||||
WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
@ -632,28 +630,28 @@ impl<'a> Map1 for Embedding<'a> {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let ids_l = &self.1;
|
||||
let ids = match &self.0.slice {
|
||||
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
|
||||
let (name, ids) = match &self.0.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "embedding ids should be u32",
|
||||
msg: "embedding ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
.w()?,
|
||||
};
|
||||
let ids = &ids;
|
||||
let shape = ids_l.shape();
|
||||
let (v_size, h_size) = rhs_l
|
||||
.shape()
|
||||
.dims2()
|
||||
.map_err(|e| CudaError::WrappedError(Box::new(e)))
|
||||
.w()?;
|
||||
let (v_size, h_size) = rhs_l.shape().dims2()?;
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
|
||||
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
||||
@ -663,6 +661,109 @@ impl<'a> Map1 for Embedding<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||
impl<'a> Map1 for IndexSelect<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
src_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let ids_l = &self.1;
|
||||
let (name, ids) = match &self.0.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index_select ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
.w()?,
|
||||
};
|
||||
let ids_shape = ids_l.shape();
|
||||
let ids_dims = ids_shape.dims();
|
||||
let ids_el = ids_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
|
||||
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
|
||||
};
|
||||
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
||||
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
||||
let dim_size = src_l.dims()[self.2];
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
|
||||
let params = (
|
||||
ids_el,
|
||||
ids_dims.len(),
|
||||
&ds,
|
||||
ids,
|
||||
&src,
|
||||
&out,
|
||||
left_size,
|
||||
dim_size,
|
||||
right_size,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||
impl<'a> Map1 for Gather<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
src_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let ids = &self.0;
|
||||
let ids_l = &self.1;
|
||||
let dim = self.2;
|
||||
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
};
|
||||
let (name, ids) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "gather ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let el = ids_l.shape().elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
};
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let ids_dim_sz = ids_l.dims()[dim];
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||
impl<'a> Map2 for Conv1D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -991,7 +1092,6 @@ impl BackendStorage for CudaStorage {
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
use cudarc::driver::DevicePtr;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
@ -1169,11 +1269,15 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement index-select").into())
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement gather").into())
|
||||
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
fn scatter_add(
|
||||
&self,
|
||||
|
@ -316,10 +316,7 @@ fn cmp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select() -> Result<()> {
|
||||
// TODO: Test on cuda once the kernel is available.
|
||||
let device = &Device::Cpu;
|
||||
fn index_select(device: &Device) -> Result<()> {
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
@ -349,6 +346,38 @@ fn index_select() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn gather(device: &Device) -> Result<()> {
|
||||
let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
let hs = t.gather(&ids, 1)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0], [5.0], [7.0], [9.0]]);
|
||||
let ids = Tensor::new(
|
||||
&[[0u32, 0u32], [2u32, 0u32], [1u32, 1u32], [0u32, 2u32]],
|
||||
device,
|
||||
)?;
|
||||
let hs = t.gather(&ids, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 0.0], [5.0, 3.0], [7.0, 7.0], [9.0, 11.0]]
|
||||
);
|
||||
let ids = Tensor::new(&[[0u32, 2u32, 0u32]], device)?;
|
||||
let hs = t.gather(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0]]);
|
||||
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
|
||||
let hs = t.gather(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn matmul(device: &Device) -> Result<()> {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||
@ -513,3 +542,5 @@ test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
||||
test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(gather, gather_cpu, gather_gpu);
|
||||
|
@ -3,12 +3,12 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define EMB_OP(TYPENAME, FN_NAME) \
|
||||
#define EMB_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const uint32_t *ids, \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t h_size, \
|
||||
@ -29,15 +29,126 @@ extern "C" __global__ void FN_NAME( \
|
||||
} \
|
||||
} \
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void index_select(
|
||||
const size_t numel,
|
||||
const size_t num_dims,
|
||||
const size_t *info,
|
||||
const I *ids,
|
||||
const T *inp,
|
||||
T *out,
|
||||
const size_t left_size,
|
||||
const size_t dim_size,
|
||||
const size_t right_size
|
||||
) {
|
||||
const size_t *dims = info;
|
||||
const size_t *strides = info + num_dims;
|
||||
if (is_contiguous(num_dims, dims, strides)) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
for (unsigned int j = 0; j < left_size; ++j) {
|
||||
memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[i]) * right_size], right_size * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
for (unsigned int j = 0; j < left_size; ++j) {
|
||||
memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[strided_i]) * right_size], right_size * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define IS_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t left_size, \
|
||||
const size_t dim_size, \
|
||||
const size_t right_size \
|
||||
) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void gather(
|
||||
const size_t numel,
|
||||
const I *ids,
|
||||
const T *inp,
|
||||
T *out,
|
||||
const size_t left_size,
|
||||
const size_t src_dim_size,
|
||||
const size_t ids_dim_size,
|
||||
const size_t right_size
|
||||
) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
size_t post = i % right_size;
|
||||
size_t idx = ids[i];
|
||||
size_t pre = i / (right_size * ids_dim_size);
|
||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||
out[i] = inp[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
#define GATHER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t left_size, \
|
||||
const size_t src_dim_size, \
|
||||
const size_t ids_dim_size, \
|
||||
const size_t right_size \
|
||||
) { gather(numel, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
EMB_OP(__nv_bfloat16, emb_bf16)
|
||||
EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
|
||||
EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
|
||||
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
|
||||
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
|
||||
GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
|
||||
GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
EMB_OP(__half, emb_f16)
|
||||
EMB_OP(__half, uint32_t, emb_u32_f16)
|
||||
EMB_OP(__half, uint8_t, emb_u8_f16)
|
||||
IS_OP(__half, uint32_t, is_u32_f16)
|
||||
IS_OP(__half, uint8_t, is_u8_f16)
|
||||
GATHER_OP(__half, uint32_t, gather_u32_f16)
|
||||
GATHER_OP(__half, uint8_t, gather_u8_f16)
|
||||
#endif
|
||||
|
||||
EMB_OP(float, emb_f32)
|
||||
EMB_OP(double, emb_f64)
|
||||
EMB_OP(uint8_t, emb_u8)
|
||||
EMB_OP(uint32_t, emb_u32)
|
||||
EMB_OP(float, uint32_t, emb_u32_f32)
|
||||
EMB_OP(double, uint32_t, emb_u32_f64)
|
||||
EMB_OP(uint8_t, uint32_t, emb_u32_u8)
|
||||
EMB_OP(uint32_t, uint32_t, emb_u32_u32)
|
||||
|
||||
EMB_OP(float, uint8_t, emb_u8_f32)
|
||||
EMB_OP(double, uint8_t, emb_u8_f64)
|
||||
EMB_OP(uint8_t, uint8_t, emb_u8_u8)
|
||||
EMB_OP(uint32_t, uint8_t, emb_u8_u32)
|
||||
|
||||
IS_OP(float, uint32_t, is_u32_f32)
|
||||
IS_OP(double, uint32_t, is_u32_f64)
|
||||
IS_OP(uint8_t, uint32_t, is_u32_u8)
|
||||
IS_OP(uint32_t, uint32_t, is_u32_u32)
|
||||
|
||||
IS_OP(float, uint8_t, is_u8_f32)
|
||||
IS_OP(double, uint8_t, is_u8_f64)
|
||||
IS_OP(uint8_t, uint8_t, is_u8_u8)
|
||||
IS_OP(uint32_t, uint8_t, is_u8_u32)
|
||||
|
||||
GATHER_OP(float, uint32_t, gather_u32_f32)
|
||||
GATHER_OP(double, uint32_t, gather_u32_f64)
|
||||
GATHER_OP(uint8_t, uint32_t, gather_u32_u8)
|
||||
GATHER_OP(uint32_t, uint32_t, gather_u32_u32)
|
||||
|
||||
GATHER_OP(float, uint8_t, gather_u8_f32)
|
||||
GATHER_OP(double, uint8_t, gather_u8_f64)
|
||||
GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
|
||||
GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
|
||||
|
Reference in New Issue
Block a user