Cleanup some todos. (#226)
* Cleanup some todos.
* Fix more todo.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the intdtype trait for more ops.
* Remove a todo.
* Remove a todo.
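The changes below make the CPU index and mask kernels (where-cond, embedding, index-select, gather, scatter-add, index-add) generic over the integer dtype of the ids or predicate via a new IntDType trait, so U8 storage is accepted alongside U32. A minimal usage sketch, assuming the candle_core crate name and its Tensor::new / where_cond / to_vec1 API; this is an illustration, not part of the commit:

```rust
// Minimal usage sketch (not part of this commit): with the IntDType-based
// dispatch below, a u8 tensor can serve as the predicate of where_cond on
// the CPU backend, in addition to the previously supported u32.
// Assumes the candle_core crate and its Tensor API as in this repo.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Predicate stored as u8: a non-zero entry picks the value from `on_true`.
    let mask = Tensor::new(&[1u8, 0, 0, 1], &dev)?;
    let on_true = Tensor::new(&[1f32, 2., 3., 4.], &dev)?;
    let on_false = Tensor::new(&[-1f32, -2., -3., -4.], &dev)?;
    let picked = mask.where_cond(&on_true, &on_false)?;
    assert_eq!(picked.to_vec1::<f32>()?, [1., -2., -3., 4.]);
    Ok(())
}
```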
@@ -1,6 +1,6 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};

// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -133,9 +133,9 @@ impl Map2U8 for Cmp {
}
}

struct WCond<'a>(&'a [u32], &'a Layout);
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<'a> Map2 for WCond<'a> {
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@@ -150,14 +150,20 @@ impl<'a> Map2 for WCond<'a> {
let f = &f[o_f1..o_f2];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
.map(|(p, (&t, &f))| if p.is_true() { t } else { f })
.collect::<Vec<_>>()
}
_ => self
.1
.strided_index()
.zip(t_l.strided_index().zip(f_l.strided_index()))
.map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
.map(|(i_p, (i_t, i_f))| {
if self.0[i_p].is_true() {
t[i_t]
} else {
f[i_f]
}
})
.collect::<Vec<_>>(),
};
Ok(vs)
@@ -628,13 +634,13 @@ impl Map1 for Affine {
}
}

struct Gather<'a> {
ids: &'a [u32],
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}

impl<'a> Map1 for Gather<'a> {
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
@@ -663,7 +669,7 @@ impl<'a> Map1 for Gather<'a> {
let start_dst_idx = start_dst_idx + i * dst_right_len;
for right_i in 0..dst_right_len {
let dst_idx = start_dst_idx + right_i;
let index = ids[dst_idx] as usize;
let index = ids[dst_idx].as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
@@ -681,13 +687,13 @@ impl<'a> Map1 for Gather<'a> {
}
}

struct IndexSelect<'a> {
ids: &'a [u32],
struct IndexSelect<'a, T: IntDType> {
ids: &'a [T],
ids_l: &'a Layout,
dim: usize,
}

impl<'a> Map1 for IndexSelect<'a> {
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
@@ -714,7 +720,7 @@ impl<'a> Map1 for IndexSelect<'a> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
@@ -733,13 +739,13 @@ impl<'a> Map1 for IndexSelect<'a> {
}
}

struct ScatterAdd<'a> {
ids: &'a [u32],
struct ScatterAdd<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}

impl<'a> Map2 for ScatterAdd<'a> {
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
@@ -771,7 +777,7 @@ impl<'a> Map2 for ScatterAdd<'a> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
let index = ids[ids_idx] as usize;
let index = ids[ids_idx].as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
@@ -790,12 +796,12 @@ impl<'a> Map2 for ScatterAdd<'a> {
}
}

struct IndexAdd<'a> {
ids: &'a [u32],
struct IndexAdd<'a, I: IntDType> {
ids: &'a [I],
dim: usize,
}

impl<'a> Map2 for IndexAdd<'a> {
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
@@ -811,8 +817,8 @@ impl<'a> Map2 for IndexAdd<'a> {
let max_idx = l1.dims()[dim];
let stride = src_l.stride()[dim];
if dim == 0 {
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx as usize;
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
index: dst_idx,
@@ -831,8 +837,8 @@ impl<'a> Map2 for IndexAdd<'a> {
} else {
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx as usize;
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
index: dst_idx,
@@ -856,31 +862,52 @@ impl<'a> Map2 for IndexAdd<'a> {
}
}

struct Embedding<'a> {
struct Embedding<'a, I: IntDType> {
vocab_size: usize,
hidden_size: usize,
ids: &'a [u32],
ids: &'a [I],
ids_l: &'a Layout,
}

impl<'a> Map1 for Embedding<'a> {
impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
// TODO: We assume that vs is contiguous here.
if !layout.is_contiguous() {
Err(Error::RequiresContiguous { op: "embedding" })?
}
let vs = &vs[layout.start_offset()..];
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
// TODO: Optimize for the case where ids are contiguous.
for index in self.ids_l.strided_index() {
let index = self.ids[index].try_into()?;
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
size: self.vocab_size,
op: "take",
match self.ids_l.contiguous_offsets() {
Some((o1, o2)) => {
for index in self.ids[o1..o2].iter() {
let index = index.as_usize();
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
size: self.vocab_size,
op: "take",
}
.bt())?
} else {
let hidden_size = self.hidden_size;
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
}
None => {
for index in self.ids_l.strided_index() {
let index = self.ids[index].as_usize();
if index >= self.vocab_size {
Err(Error::InvalidIndex {
index,
size: self.vocab_size,
op: "take",
}
.bt())?
} else {
let hidden_size = self.hidden_size;
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
.bt())?
} else {
let hidden_size = self.hidden_size;
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(values)
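The rewritten Embedding kernel above resolves the "optimize for the contiguous case" TODO: when the id layout reports contiguous offsets it walks a plain slice, and only falls back to the strided index iterator otherwise. The standalone sketch below isolates that pattern; MiniLayout and lookup are hypothetical stand-ins, not candle types:

```rust
// Standalone illustration of the contiguous fast path used above; `MiniLayout`
// and `lookup` are made-up names, not candle APIs.
struct MiniLayout {
    start: usize,
    len: usize,
    contiguous: bool,
}

impl MiniLayout {
    // Mirrors the idea of Layout::contiguous_offsets: return slice bounds when
    // the data can be read back-to-back, otherwise None.
    fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        self.contiguous.then(|| (self.start, self.start + self.len))
    }
    // Fallback index iterator (trivial here; strided in the real kernel).
    fn strided_index(&self) -> impl Iterator<Item = usize> + '_ {
        self.start..self.start + self.len
    }
}

// Gather rows of `table` (row length `hidden`) for each id, preferring the
// contiguous slice when available.
fn lookup(ids: &[u32], ids_l: &MiniLayout, table: &[f32], hidden: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(ids_l.len * hidden);
    match ids_l.contiguous_offsets() {
        Some((o1, o2)) => {
            for &id in &ids[o1..o2] {
                let id = id as usize;
                out.extend_from_slice(&table[id * hidden..(id + 1) * hidden]);
            }
        }
        None => {
            for i in ids_l.strided_index() {
                let id = ids[i] as usize;
                out.extend_from_slice(&table[id * hidden..(id + 1) * hidden]);
            }
        }
    }
    out
}

fn main() {
    let table = [0.0f32, 0.1, 1.0, 1.1, 2.0, 2.1]; // 3 rows of width 2
    let ids = [2u32, 0];
    let l = MiniLayout { start: 0, len: 2, contiguous: true };
    assert_eq!(lookup(&ids, &l, &table, 2), vec![2.0, 2.1, 0.0, 0.1]);
}
```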
@@ -1671,9 +1698,11 @@ impl BackendStorage for CpuStorage {
f: &Self,
f_l: &Layout,
) -> Result<Self> {
// TODO: Support types that could be casted to a boolean.
let pred = self.as_slice::<u32>()?;
WCond(pred, layout).map(t, t_l, f, f_l)
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
}

fn conv1d(
@@ -1687,25 +1716,40 @@ impl BackendStorage for CpuStorage {
}

fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = self.as_slice::<u32>()?;
let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
Embedding {
vocab_size,
hidden_size,
ids,
ids_l,
match self {
Self::U8(ids) => Embedding {
vocab_size,
hidden_size,
ids,
ids_l,
}
.map(rhs, rhs_l),
Self::U32(ids) => Embedding {
vocab_size,
hidden_size,
ids,
ids_l,
}
.map(rhs, rhs_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
}
.map(rhs, rhs_l)
}

fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let ids = ids.as_slice::<u32>()?;
IndexSelect { ids, ids_l, dim }.map(self, l)
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
}
}

fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let ids = ids.as_slice::<u32>()?;
Gather { ids, ids_l, dim }.map(self, l)
match ids {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
}
}

fn scatter_add(
@@ -1717,8 +1761,11 @@ impl BackendStorage for CpuStorage {
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let ids = ids.as_slice::<u32>()?;
ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
}
}

fn index_add(
@@ -1730,12 +1777,23 @@ impl BackendStorage for CpuStorage {
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let ids = ids.as_slice::<u32>()?;
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
match ids {
Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
}
}

fn matmul(
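In the CpuStorage methods above, each op matches on the enum variant that holds the ids (or the predicate) and instantiates the same generic kernel per integer type, returning UnsupportedDTypeForOp for anything else. The toy model below reproduces that dispatch shape; IdStorage, kernel and count_ids are made-up names, not candle APIs:

```rust
// Toy model of the Self::U8 / Self::U32 dispatch pattern used above.
#[derive(Debug)]
enum IdStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    F32(Vec<f32>),
}

#[derive(Debug)]
enum OpError {
    UnsupportedDTypeForOp(&'static str, &'static str),
}

// A generic "kernel" written once for any integer id type: count ids that
// fall below `bound`.
fn kernel<I: Copy + Into<u64>>(ids: &[I], bound: u64) -> usize {
    ids.iter().filter(|&&i| i.into() < bound).count()
}

// The dispatch: pull the concrete slice out of the enum and run the same
// generic kernel, rejecting dtypes that cannot act as indices.
fn count_ids(storage: &IdStorage, bound: u64) -> Result<usize, OpError> {
    match storage {
        IdStorage::U8(ids) => Ok(kernel(ids, bound)),
        IdStorage::U32(ids) => Ok(kernel(ids, bound)),
        _ => Err(OpError::UnsupportedDTypeForOp("f32", "count-ids")),
    }
}

fn main() {
    let ids = IdStorage::U32(vec![1, 5, 7]);
    assert_eq!(count_ids(&ids, 6).unwrap(), 2);
    assert!(count_ids(&IdStorage::F32(vec![0.5]), 6).is_err());
}
```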
@@ -119,3 +119,26 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);

pub trait IntDType {
fn is_true(&self) -> bool;
fn as_usize(&self) -> usize;
}

impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}

impl IntDType for u8 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
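The IntDType trait added above is all an integer-index kernel needs: is_true for mask semantics and as_usize for index semantics. A self-contained sketch of writing helpers once against that interface; the trait and impls are restated locally so the snippet compiles on its own, and count_true / take1d are hypothetical helpers, not candle functions:

```rust
// Local restatement of the trait added in this commit (in candle it lives in
// dtype.rs and is re-exported as candle_core::IntDType).
pub trait IntDType {
    fn is_true(&self) -> bool;
    fn as_usize(&self) -> usize;
}

impl IntDType for u32 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

impl IntDType for u8 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

// Mask-style helper: how many predicate entries are truthy (the where-cond case).
fn count_true<I: IntDType>(mask: &[I]) -> usize {
    mask.iter().filter(|m| m.is_true()).count()
}

// Index-style helper: pick elements of `src` by position (the gather /
// index-select / embedding case).
fn take1d<I: IntDType>(ids: &[I], src: &[f32]) -> Vec<f32> {
    ids.iter().map(|i| src[i.as_usize()]).collect()
}

fn main() {
    let src = [10.0f32, 20.0, 30.0, 40.0];
    // The same helpers accept u8 or u32 ids without duplication.
    assert_eq!(count_true(&[1u8, 0, 3]), 2);
    assert_eq!(take1d(&[3u32, 0, 2], &src), vec![40.0, 10.0, 30.0]);
}
```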
@@ -61,7 +61,7 @@ mod variable;

pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use dtype::{DType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
@@ -206,7 +206,6 @@ impl Storage {
}

pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
@@ -270,7 +270,11 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
// TODO: This is not the expected answer, to be fixed!
// PyTorch equivalent:
// import torch
// t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]])
// t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]])
// torch.cat([t1.t(), t2.t()], dim=1).t()
assert_eq!(
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
.t()?
@@ -282,7 +286,6 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
// TODO: This is not the expected answer, to be fixed!
assert_eq!(
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
[
@@ -1,26 +1,20 @@
// TODO: Use a proper distributed reduction rather than atomicAdd.
// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
#include "cuda_utils.cuh"
#include<stdint.h>
#include<cmath>
#include <cmath>
#include <stdint.h>

const int BLOCK_SIZE = 1024;

// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 but
// also expect a f32 output so that this can be used for normalization e.g. in softmax.
// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
// but also expect a f32 output so that this can be used for normalization e.g.
// in softmax.

// Fast reduce sum kernel, this assumes that the dimensions to loop over are at
// the end, each block is responsible for populating one value in the output array.
// There are at most 1024 threads per block.
// the end, each block is responsible for populating one value in the output
// array. There are at most 1024 threads per block.
template <typename T>
__device__ void fast_sum(
const size_t src_numel,
const size_t el_to_sum_per_block,
const size_t num_dims,
const size_t *info,
const T *src,
T *dst
) {
__device__ void
fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
const size_t num_dims, const size_t *info, const T *src, T *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;

@@ -47,21 +41,18 @@ __device__ void fast_sum(
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
if (tid < s) shr[tid] += shr[tid + s];
if (tid < s)
shr[tid] += shr[tid + s];
}

if (tid == 0) dst[dst_id] = shr[0];
if (tid == 0)
dst[dst_id] = shr[0];
}

template <typename T>
__device__ void fast_max(
const size_t src_numel,
const size_t el_to_sum_per_block,
const size_t num_dims,
const size_t *info,
const T *src,
T *dst
) {
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
const size_t num_dims, const size_t *info, const T *src, T *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;

@@ -88,21 +79,18 @@ __device__ void fast_max(
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]);
if (tid < s)
shr[tid] = maxg(shr[tid], shr[tid + s]);
}

if (tid == 0) dst[dst_id] = shr[0];
if (tid == 0)
dst[dst_id] = shr[0];
}

template <typename T>
__device__ void fast_min(
const size_t src_numel,
const size_t el_to_sum_per_block,
const size_t num_dims,
const size_t *info,
const T *src,
T *dst
) {
__device__ void
fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
const size_t num_dims, const size_t *info, const T *src, T *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;

@@ -129,83 +117,69 @@ __device__ void fast_min(
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]);
if (tid < s)
shr[tid] = ming(shr[tid], shr[tid + s]);
}

if (tid == 0) dst[dst_id] = shr[0];
if (tid == 0)
dst[dst_id] = shr[0];
}

#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
extern "C" __global__ void MIN_NAME( \
const size_t src_numel, \
const size_t el_to_sum_per_block, \
const size_t num_dims, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void MAX_NAME( \
const size_t src_numel, \
const size_t el_to_sum_per_block, \
const size_t num_dims, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void SUM_NAME( \
const size_t src_numel, \
const size_t el_to_sum_per_block, \
const size_t num_dims, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
extern "C" __global__ void MIN_NAME( \
const size_t src_numel, const size_t el_to_sum_per_block, \
const size_t num_dims, const size_t *info, const TYPENAME *src, \
TYPENAME *dst) { \
fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void MAX_NAME( \
const size_t src_numel, const size_t el_to_sum_per_block, \
const size_t num_dims, const size_t *info, const TYPENAME *src, \
TYPENAME *dst) { \
fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void SUM_NAME( \
const size_t src_numel, const size_t el_to_sum_per_block, \
const size_t num_dims, const size_t *info, const TYPENAME *src, \
TYPENAME *dst) { \
fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
}

#define SUM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t num_sum_dims, \
const size_t *info, \
const TYPENAME *inp, \
TYPENAME *out \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
const size_t *sum_dims_l = info + 2*num_dims; \
const size_t *sum_dims_s = info + 2*num_dims + num_sum_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
size_t dst_index = i; \
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
size_t stride = sum_dims_s[nd]; \
size_t pre = dst_index / stride; \
size_t post = dst_index % stride; \
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
} \
atomicAdd(out + dst_index, inp[i]); \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
size_t dst_index = i; \
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
size_t stride = sum_dims_s[nd]; \
size_t pre = dst_index / stride; \
size_t post = dst_index % stride; \
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
} \
atomicAdd(out + dst_index, inp[strided_i]); \
} \
} \
} \
#define SUM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, const size_t num_dims, const size_t num_sum_dims, \
const size_t *info, const TYPENAME *inp, TYPENAME *out) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
const size_t *sum_dims_l = info + 2 * num_dims; \
const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
i += blockDim.x * gridDim.x) { \
size_t dst_index = i; \
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
size_t stride = sum_dims_s[nd]; \
size_t pre = dst_index / stride; \
size_t post = dst_index % stride; \
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
} \
atomicAdd(out + dst_index, inp[i]); \
} \
} else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
size_t dst_index = i; \
for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
size_t stride = sum_dims_s[nd]; \
size_t pre = dst_index / stride; \
size_t post = dst_index % stride; \
dst_index = (pre / sum_dims_l[nd]) * stride + post; \
} \
atomicAdd(out + dst_index, inp[strided_i]); \
} \
} \
}

#if __CUDA_ARCH__ >= 800
SUM_OP(__nv_bfloat16, sum_bf16)