fix: add missing gpu fill_* (#996)

This commit is contained in:
Gonzalo
2023-09-29 11:49:30 -03:00
committed by GitHub
parent 03348e2e6f
commit fc59bc31bf
2 changed files with 35 additions and 0 deletions

View File

@ -1,3 +1,4 @@
#include<stdint.h>
#include "cuda_fp16.h"
template<typename T>
@ -6,6 +7,14 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
buf[i] = value;
}
}
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
#endif