Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Cuda support for the mnist training. (#277)

* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
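A minimal usage sketch of what this change enables (not part of the commit itself; it assumes a CUDA-enabled build of candle and uses the same tensor API exercised by the tests further down): reductions such as argmax now run on the GPU when the tensor lives on a CUDA device.

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Falls back to the CPU when no CUDA device is present.
    let dev = Device::cuda_if_available(0)?;
    let t = Tensor::new(&[[3u32, 1, 4], [1, 5, 9]], &dev)?;
    // Index of the maximum along the last dimension, keeping that dimension with size 1.
    let idx = t.argmax_keepdim(1)?;
    println!("{:?}", idx.to_vec2::<u32>()?);
    Ok(())
}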
@@ -244,7 +244,7 @@ impl ReduceIndex {
val = s
}
}
dst[unstr_index] = g(val, acc)
dst_to_set[unstr_index] = g(val, acc)
}
}
}

@@ -438,6 +438,28 @@ trait Map2InPlace {
}
}

trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;

fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}

trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,

@@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> {
}

struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1 for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
impl<'a> Map1Any for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
wrap: W,
) -> Result<S> {
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();

@@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> {
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
let name = match self.1 {
ReduceOp::Sum => "fast_sum",
ReduceOp::Min => "fast_min",
ReduceOp::Max => "fast_max",
ReduceOp::ArgMin => "fast_argmin",
ReduceOp::ArgMax => "fast_argmax",
let (name, check_empty, return_index) = match self.1 {
ReduceOp::Sum => ("fast_sum", false, false),
ReduceOp::Min => ("fast_min", true, false),
ReduceOp::Max => ("fast_max", true, false),
ReduceOp::ArgMin => ("fast_argmin", true, true),
ReduceOp::ArgMax => ("fast_argmax", true, true),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
if return_index {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(out))
} else {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(wrap(out))
}
}
}
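The key idea in the Map1Any change above is that a reduction no longer has to return the input element type: argmin/argmax allocate a u32 index buffer and return it as S::U32, while min/max/sum keep the element type and hand the buffer back through the `wrap` closure. A simplified CPU-side sketch of that dispatch, using stand-in types rather than the real candle storage enum:

// Simplified stand-ins for the storage enum and the reduce dispatch; illustration only.
enum Out {
    U32(Vec<u32>),
    F32(Vec<f32>),
}

fn reduce_f32(src: &[f32], return_index: bool, wrap: impl Fn(Vec<f32>) -> Out) -> Out {
    if return_index {
        // argmin-style: the output dtype is u32 regardless of the input dtype.
        let (idx, _) = src
            .iter()
            .enumerate()
            .fold((0u32, f32::INFINITY), |acc, (i, &v)| {
                if v < acc.1 { (i as u32, v) } else { acc }
            });
        Out::U32(vec![idx])
    } else {
        // min-style: keep the element type and let the caller re-wrap it, like `wrap(out)` above.
        let m = src.iter().copied().fold(f32::INFINITY, f32::min);
        wrap(vec![m])
    }
}

Called as reduce_f32(&data, true, Out::F32) it plays the role of fast_argmin, and as reduce_f32(&data, false, Out::F32) the role of fast_min, mirroring the return_index branch in the impl above.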
@@ -164,6 +164,278 @@ fn sum(device: &Device) -> Result<()> {
Ok(())
}

fn min(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.min_keepdim(2)?.to_vec3::<u32>()?,
&[[[1], [1]], [[1], [2]]]
);
assert_eq!(
tensor.min_keepdim(0)?.to_vec3::<u32>()?,
&[[[2, 1, 4], [1, 2, 8]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(
tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);

// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(
tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);

let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.min_keepdim(0)?
.min_keepdim(2)?
.min_keepdim(1)?
.to_vec3::<u32>()?,
&[[[200]]]
);
assert_eq!(
tensor.min_keepdim(0)?.to_vec3::<u32>()?,
&[[
[200, 201, 202, 203],
[204, 205, 206, 207],
[208, 209, 210, 211],
[212, 213, 214, 215],
[216, 217, 218, 219]
]]
);
}
Ok(())
}

fn max(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.max_keepdim(2)?.to_vec3::<u32>()?,
&[[[4], [9]], [[7], [8]]]
);
assert_eq!(
tensor.max_keepdim(0)?.to_vec3::<u32>()?,
&[[[3, 1, 7], [8, 5, 9]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(
tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);

// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(
tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);

let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.max_keepdim(0)?
.max_keepdim(2)?
.max_keepdim(1)?
.to_vec3::<u32>()?,
&[[[3999]]]
);
assert_eq!(
tensor.max_keepdim(0)?.to_vec3::<u32>()?,
&[[
[3980, 3981, 3982, 3983],
[3984, 3985, 3986, 3987],
[3988, 3989, 3990, 3991],
[3992, 3993, 3994, 3995],
[3996, 3997, 3998, 3999]
]]
);
}
Ok(())
}

fn argmin(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.argmin_keepdim(2)?.to_vec3::<u32>()?,
&[[[1], [0]], [[1], [1]]]
);
assert_eq!(
tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
&[[[1, 0, 0], [0, 1, 1]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmin_keepdim(1)?
.argmin_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);

// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmin_keepdim(1)?
.argmin_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);

let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(2)?
.argmin_keepdim(1)?
.to_vec3::<u32>()?,
&[[[0]]]
);
assert_eq!(
tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
&[[
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
]]
);
}
Ok(())
}

fn argmax(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.argmax_keepdim(2)?.to_vec3::<u32>()?,
&[[[2], [2]], [[2], [0]]]
);
assert_eq!(
tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
&[[[0, 0, 1], [1, 0, 0]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmax_keepdim(1)?
.argmax_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);

// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmax_keepdim(1)?
.argmax_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);

let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(2)?
.argmax_keepdim(1)?
.to_vec3::<u32>()?,
&[[[0]]]
);
assert_eq!(
tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
&[[
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
]]
);
}
Ok(())
}

fn narrow(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;

@@ -581,6 +853,10 @@ test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
@@ -142,17 +142,20 @@ fn training_loop<M: Model>(
let dev = candle::Device::cuda_if_available(0)?;

let train_labels = m.train_labels;
let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels
.to_dtype(DType::U32)?
.unsqueeze(1)?
.to_device(&dev)?;

let vs = VarStore::new(DType::F32, dev);
let vs = VarStore::new(DType::F32, dev.clone());
let model = M::new(vs.clone())?;

let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..200 {
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
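The training-loop change boils down to a device-placement rule: the input tensors and the VarStore holding the model weights must live on the same device, so the inputs are moved with to_device(&dev) and the device handle is cloned into the VarStore. A condensed sketch of that pattern (the VarStore import path and helper function are assumptions for illustration, not repository code):

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarStore;

fn place_on_device(images: Tensor, labels: Tensor) -> Result<(Tensor, Tensor, VarStore)> {
    let dev = Device::cuda_if_available(0)?;
    // The inputs move to whichever device was picked at runtime...
    let images = images.to_device(&dev)?;
    let labels = labels.to_dtype(DType::U32)?.to_device(&dev)?;
    // ...and the VarStore takes a clone of the same device handle, so any
    // weights it creates are allocated there as well.
    let vs = VarStore::new(DType::F32, dev.clone());
    Ok((images, labels, vs))
}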
@@ -144,7 +144,8 @@ __device__ __forceinline__ double copysigng(double a, double b) { return copysig

__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }

__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
@@ -125,7 +125,116 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}

#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
template <typename T>
__device__ void
fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block,
const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;

__shared__ T shr[BLOCK_SIZE];
__shared__ uint32_t shr_index[BLOCK_SIZE];
size_t tid = threadIdx.x;
size_t dst_id = blockIdx.x;

// Not sure how that works on uint32_t and uint8_t but it seems to do ok.
shr[tid] = INFINITY;
shr_index[tid] = 0xFFFFFFFF;
bool not_set = true;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;

while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (not_set || src[strided_i] < shr[tid]) {
shr[tid] = src[strided_i];
// Assume that the reduction takes place over the last dimension which is contiguous.
shr_index[tid] = idx % dims[num_dims - 1];
not_set = false;
}
idx += blockDim.x;
}

// Parallel reduction, see the slides:
// https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
if (tid < s && shr[tid + s] < shr[tid]) {
shr[tid] = shr[tid + s];
shr_index[tid] = shr_index[tid + s];
}
}

if (tid == 0)
dst[dst_id] = shr_index[0];
}

template <typename T>
__device__ void
fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;

__shared__ T shr[BLOCK_SIZE];
__shared__ uint32_t shr_index[BLOCK_SIZE];
size_t tid = threadIdx.x;
size_t dst_id = blockIdx.x;

shr[tid] = -INFINITY;
shr_index[tid] = 0xFFFFFFFF;
bool not_set = true;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;

while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (not_set || src[strided_i] > shr[tid]) {
shr[tid] = src[strided_i];
// Assume that the reduction takes place over the last dimension which is contiguous.
shr_index[tid] = idx % dims[num_dims - 1];
not_set = false;
}
idx += blockDim.x;
}

// Parallel reduction, see the slides:
// https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
if (tid < s && shr[tid + s] > shr[tid]) {
shr[tid] = shr[tid + s];
shr_index[tid] = shr_index[tid + s];
}
}

if (tid == 0)
dst[dst_id] = shr_index[0];
}

#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, ARGMIN_NAME, ARGMAX_NAME, SUM_NAME) \
extern "C" __global__ void ARGMIN_NAME( \
const size_t src_numel, const size_t el_to_sum_per_block, \
const size_t num_dims, const size_t *info, const TYPENAME *src, \
uint32_t *dst) { \
fast_argmin(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void ARGMAX_NAME( \
const size_t src_numel, const size_t el_to_sum_per_block, \
const size_t num_dims, const size_t *info, const TYPENAME *src, \
uint32_t *dst) { \
fast_argmax(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
} \
extern "C" __global__ void MIN_NAME( \
const size_t src_numel, const size_t el_to_sum_per_block, \
const size_t num_dims, const size_t *info, const TYPENAME *src, \
@@ -183,18 +292,19 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block,

#if __CUDA_ARCH__ >= 800
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif

SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)

FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_sum_f64)
FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_sum_u32)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32)
FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8)
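For reference, the value each fast_argmin block writes out is the position of the minimum within the reduced last dimension, recovered in the kernel via idx % dims[num_dims - 1] under the kernel's stated assumption that this dimension is contiguous. A plain-Rust CPU equivalent of that per-row computation (illustration only, not repository code):

// CPU reference for one fast_argmin pass over a contiguous last dimension:
// every `last_dim`-sized row produces one u32 index.
fn argmin_last_dim(src: &[f32], last_dim: usize) -> Vec<u32> {
    src.chunks(last_dim)
        .map(|row| {
            let mut best = f32::INFINITY;
            let mut best_idx = 0u32;
            for (i, &v) in row.iter().enumerate() {
                if v < best {
                    best = v;
                    // Same quantity as `idx % dims[num_dims - 1]` in the kernel.
                    best_idx = i as u32;
                }
            }
            best_idx
        })
        .collect()
}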