mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Cuda kernels for IndexAdd/ScatterAdd. (#236)
* Skeleton methods for IndexAdd/ScatterAdd. * Add a Map2InPlace trait. * Add the glue code for the index-add/scatter-add kernels. * Tweak the file name: embeddings -> indexing. * Add the cuda kernel for indexadd. * And add the scatter-add kernels.
This commit is contained in:
@ -398,12 +398,42 @@ trait Map2 {
|
|||||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||||
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||||
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
};
|
};
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait Map2InPlace {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
dst_shape: &Shape,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
src_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()>;
|
||||||
|
|
||||||
|
fn map(
|
||||||
|
&self,
|
||||||
|
dst: &mut S,
|
||||||
|
dst_s: &Shape,
|
||||||
|
src: &S,
|
||||||
|
src_l: &Layout,
|
||||||
|
d: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
match (dst, src) {
|
||||||
|
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
trait Map2Any {
|
trait Map2Any {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
&self,
|
&self,
|
||||||
@ -651,7 +681,7 @@ impl<'a> Map1 for Embedding<'a> {
|
|||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
|
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
|
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
|
||||||
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
||||||
@ -696,7 +726,7 @@ impl<'a> Map1 for IndexSelect<'a> {
|
|||||||
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
||||||
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
||||||
let dim_size = src_l.dims()[self.2];
|
let dim_size = src_l.dims()[self.2];
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
|
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
|
||||||
let params = (
|
let params = (
|
||||||
@ -752,7 +782,7 @@ impl<'a> Map1 for Gather<'a> {
|
|||||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
let src_dim_sz = src_l.dims()[dim];
|
let src_dim_sz = src_l.dims()[dim];
|
||||||
let ids_dim_sz = ids_l.dims()[dim];
|
let ids_dim_sz = ids_l.dims()[dim];
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (
|
let params = (
|
||||||
@ -764,6 +794,97 @@ impl<'a> Map1 for Gather<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||||
|
impl<'a> Map2InPlace for IndexAdd<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
dst_shape: &Shape,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
src_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
let ids = &self.0;
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let dim = self.2;
|
||||||
|
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
||||||
|
Some(o12) => o12,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let (name, ids) = match &ids.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||||
|
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "index-add ids should be u8 or u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: ids.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let src = match src_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||||
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
|
let src_dim_sz = src_l.dims()[dim];
|
||||||
|
let dst_dim_sz = dst_shape.dims()[dim];
|
||||||
|
let ids_dim_sz = ids_l.dims()[0];
|
||||||
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (
|
||||||
|
ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
|
||||||
|
);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||||
|
impl<'a> Map2InPlace for ScatterAdd<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
_dst_shape: &Shape,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
src_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
let ids = &self.0;
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let dim = self.2;
|
||||||
|
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
||||||
|
Some(o12) => o12,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let (name, ids) = match &ids.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||||
|
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "scatter-add ids should be u8 or u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: ids.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let src = match src_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
|
};
|
||||||
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||||
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
|
let src_dim_sz = src_l.dims()[dim];
|
||||||
|
let ids_dim_sz = ids_l.dims()[dim];
|
||||||
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let params = (ids, &src, dst, left_sz, src_dim_sz, ids_dim_sz, right_sz);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||||
impl<'a> Map2 for Conv1D<'a> {
|
impl<'a> Map2 for Conv1D<'a> {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
@ -1004,8 +1125,7 @@ fn gemm_config<T>(
|
|||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})
|
})?
|
||||||
.w()?
|
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
@ -1017,8 +1137,7 @@ fn gemm_config<T>(
|
|||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})
|
})?
|
||||||
.w()?
|
|
||||||
};
|
};
|
||||||
// The setup below was copied from:
|
// The setup below was copied from:
|
||||||
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
|
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
|
||||||
@ -1043,8 +1162,7 @@ fn gemm_config<T>(
|
|||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})
|
})?,
|
||||||
.w()?,
|
|
||||||
};
|
};
|
||||||
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
@ -1054,8 +1172,7 @@ fn gemm_config<T>(
|
|||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})
|
})?,
|
||||||
.w()?,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(StridedBatchedConfig {
|
Ok(StridedBatchedConfig {
|
||||||
@ -1281,25 +1398,33 @@ impl BackendStorage for CudaStorage {
|
|||||||
}
|
}
|
||||||
fn scatter_add(
|
fn scatter_add(
|
||||||
&self,
|
&self,
|
||||||
_: &Layout,
|
l: &Layout,
|
||||||
_: &Self,
|
ids: &Self,
|
||||||
_: &Layout,
|
ids_l: &Layout,
|
||||||
_: &Self,
|
src: &Self,
|
||||||
_: &Layout,
|
src_l: &Layout,
|
||||||
_: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Err(CudaError::InternalError("TODO: implement scatter-add").into())
|
let device = self.device().clone();
|
||||||
|
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
|
||||||
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
|
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||||
|
Ok(acc)
|
||||||
}
|
}
|
||||||
fn index_add(
|
fn index_add(
|
||||||
&self,
|
&self,
|
||||||
_: &Layout,
|
l: &Layout,
|
||||||
_: &Self,
|
ids: &Self,
|
||||||
_: &Layout,
|
ids_l: &Layout,
|
||||||
_: &Self,
|
src: &Self,
|
||||||
_: &Layout,
|
src_l: &Layout,
|
||||||
_: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Err(CudaError::InternalError("TODO: implement index-add").into())
|
let device = self.device().clone();
|
||||||
|
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
|
||||||
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
|
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||||
|
Ok(acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
@ -1364,7 +1489,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
.w()?;
|
.w()?;
|
||||||
CudaStorageSlice::F64(out)
|
CudaStorageSlice::F64(out)
|
||||||
}
|
}
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
|
_ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?,
|
||||||
};
|
};
|
||||||
let device = dev.clone();
|
let device = dev.clone();
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
@ -1452,8 +1577,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
}
|
}
|
||||||
_ => Err(CudaError::InternalError(
|
_ => Err(CudaError::InternalError(
|
||||||
"dtype mismatch in copy_strided op",
|
"dtype mismatch in copy_strided op",
|
||||||
))
|
))?,
|
||||||
.w()?,
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -105,6 +105,79 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
const size_t right_size \
|
const size_t right_size \
|
||||||
) { gather(numel, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
|
) { gather(numel, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
__device__ void index_add(
|
||||||
|
const I *ids,
|
||||||
|
const size_t ids_dim_size,
|
||||||
|
const T *inp,
|
||||||
|
T *out,
|
||||||
|
const size_t left_size,
|
||||||
|
const size_t src_dim_size,
|
||||||
|
const size_t dst_dim_size,
|
||||||
|
const size_t right_size
|
||||||
|
) {
|
||||||
|
const size_t numel = left_size * right_size;
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||||
|
const size_t pre = i / right_size;
|
||||||
|
const size_t post = i % right_size;
|
||||||
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
|
const size_t idx = ids[j];
|
||||||
|
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||||
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
|
out[dst_i] += inp[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const INDEX_TYPENAME *ids, \
|
||||||
|
const size_t ids_dim_size, \
|
||||||
|
const TYPENAME *inp, \
|
||||||
|
TYPENAME *out, \
|
||||||
|
const size_t left_size, \
|
||||||
|
const size_t src_dim_size, \
|
||||||
|
const size_t dst_dim_size, \
|
||||||
|
const size_t right_size \
|
||||||
|
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
__device__ void scatter_add(
|
||||||
|
const I *ids,
|
||||||
|
const size_t ids_dim_size,
|
||||||
|
const T *inp,
|
||||||
|
T *out,
|
||||||
|
const size_t left_size,
|
||||||
|
const size_t src_dim_size,
|
||||||
|
const size_t dst_dim_size,
|
||||||
|
const size_t right_size
|
||||||
|
) {
|
||||||
|
const size_t numel = left_size * right_size;
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||||
|
const size_t pre = i / right_size;
|
||||||
|
const size_t post = i % right_size;
|
||||||
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
|
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||||
|
const size_t idx = ids[src_i];
|
||||||
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
|
out[dst_i] += inp[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const INDEX_TYPENAME *ids, \
|
||||||
|
const size_t ids_dim_size, \
|
||||||
|
const TYPENAME *inp, \
|
||||||
|
TYPENAME *out, \
|
||||||
|
const size_t left_size, \
|
||||||
|
const size_t src_dim_size, \
|
||||||
|
const size_t dst_dim_size, \
|
||||||
|
const size_t right_size \
|
||||||
|
) { scatter_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||||
|
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
|
EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
|
||||||
EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
|
EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
|
||||||
@ -112,6 +185,10 @@ IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
|
|||||||
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
|
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
|
||||||
GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
|
GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
|
||||||
GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
|
GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
|
||||||
|
IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
|
||||||
|
IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
|
||||||
|
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
|
||||||
|
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
@ -121,6 +198,10 @@ IS_OP(__half, uint32_t, is_u32_f16)
|
|||||||
IS_OP(__half, uint8_t, is_u8_f16)
|
IS_OP(__half, uint8_t, is_u8_f16)
|
||||||
GATHER_OP(__half, uint32_t, gather_u32_f16)
|
GATHER_OP(__half, uint32_t, gather_u32_f16)
|
||||||
GATHER_OP(__half, uint8_t, gather_u8_f16)
|
GATHER_OP(__half, uint8_t, gather_u8_f16)
|
||||||
|
IA_OP(__half, uint32_t, ia_u32_f16)
|
||||||
|
IA_OP(__half, uint8_t, ia_u8_f16)
|
||||||
|
SA_OP(__half, uint32_t, sa_u32_f16)
|
||||||
|
SA_OP(__half, uint8_t, sa_u8_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
EMB_OP(float, uint32_t, emb_u32_f32)
|
EMB_OP(float, uint32_t, emb_u32_f32)
|
||||||
@ -152,3 +233,23 @@ GATHER_OP(float, uint8_t, gather_u8_f32)
|
|||||||
GATHER_OP(double, uint8_t, gather_u8_f64)
|
GATHER_OP(double, uint8_t, gather_u8_f64)
|
||||||
GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
|
GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
|
||||||
GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
|
GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
|
||||||
|
|
||||||
|
IA_OP(float, uint32_t, ia_u32_f32)
|
||||||
|
IA_OP(double, uint32_t, ia_u32_f64)
|
||||||
|
IA_OP(uint8_t, uint32_t, ia_u32_u8)
|
||||||
|
IA_OP(uint32_t, uint32_t, ia_u32_u32)
|
||||||
|
|
||||||
|
IA_OP(float, uint8_t, ia_u8_f32)
|
||||||
|
IA_OP(double, uint8_t, ia_u8_f64)
|
||||||
|
IA_OP(uint8_t, uint8_t, ia_u8_u8)
|
||||||
|
IA_OP(uint32_t, uint8_t, ia_u8_u32)
|
||||||
|
|
||||||
|
SA_OP(float, uint32_t, sa_u32_f32)
|
||||||
|
SA_OP(double, uint32_t, sa_u32_f64)
|
||||||
|
SA_OP(uint8_t, uint32_t, sa_u32_u8)
|
||||||
|
SA_OP(uint32_t, uint32_t, sa_u32_u32)
|
||||||
|
|
||||||
|
SA_OP(float, uint8_t, sa_u8_f32)
|
||||||
|
SA_OP(double, uint8_t, sa_u8_f64)
|
||||||
|
SA_OP(uint8_t, uint8_t, sa_u8_u8)
|
||||||
|
SA_OP(uint32_t, uint8_t, sa_u8_u32)
|
@ -2,8 +2,8 @@ pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
|||||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||||
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
|
||||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||||
|
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||||
|
Reference in New Issue
Block a user