// Mirror of https://github.com/huggingface/candle.git
// (synced 2025-06-19 03:54:56 +00:00; original listing: 21 lines, 1.1 KiB, Plaintext)
#include<stdint.h>
|
|
#include "cuda_fp16.h"
|
|
|
|
// Writes `value` into every element buf[0], ..., buf[numel - 1].
//
// Each thread starts at its flat 1-D global index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total number of
// launched threads (blockDim.x * gridDim.x). Any 1-D launch configuration
// therefore covers the whole buffer, and each element is written exactly once.
//
// Params:
//   buf   - device pointer to at least `numel` elements of T.
//   value - value stored into every element.
//   numel - number of elements to write.
//
// Indices are size_t rather than unsigned int because `numel` is a size_t.
// With a 32-bit counter, a buffer of more than UINT_MAX elements would make
// `i` wrap around before reaching `numel`, so the loop would never terminate
// and elements beyond UINT_MAX would never be written. The products below are
// widened before multiplying so large grids cannot overflow 32 bits either.
template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < numel; i += stride) {
        buf[i] = value;
    }
}
|
|
// Kernel entry point: sets every one of the `numel` uint8_t elements of `buf` to `value`.
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) {
    fill_with<uint8_t>(buf, value, numel);
}
|
|
// Kernel entry point: sets every one of the `numel` uint32_t elements of `buf` to `value`.
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) {
    fill_with<uint32_t>(buf, value, numel);
}
|
|
// Kernel entry point: sets every one of the `numel` int64_t elements of `buf` to `value`.
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) {
    fill_with<int64_t>(buf, value, numel);
}
|
|
// Kernel entry point: sets every one of the `numel` __half elements of `buf` to `value`.
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) {
    fill_with<__half>(buf, value, numel);
}
|
|
// Kernel entry point: sets every one of the `numel` float elements of `buf` to `value`.
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) {
    fill_with<float>(buf, value, numel);
}
|
|
// Kernel entry point: sets every one of the `numel` double elements of `buf` to `value`.
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) {
    fill_with<double>(buf, value, numel);
}
|
|
|
|
// bf16 fill kernel. It is only compiled when __CUDA_ARCH__ is defined and is
// at least 800, i.e. device compilation for compute capability 8.0 or newer.
// In every other compilation pass, neither the header nor the kernel is seen.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#include <cuda_bf16.h>

// Kernel entry point: sets every one of the `numel` __nv_bfloat16 elements of `buf` to `value`.
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) {
    fill_with<__nv_bfloat16>(buf, value, numel);
}
#endif
|