mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Make it easier to use whisper samples from the repo. (#112)
* Make it easier to use samples from the repo. * Use f32 for accumulation in the f16/bf16 kernels.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename A>
|
||||
__device__ void conv1d(
|
||||
const size_t src_numel,
|
||||
const size_t l_out,
|
||||
@ -30,7 +30,7 @@ __device__ void conv1d(
|
||||
const size_t dst_l = dst_i % l_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
T d = 0;
|
||||
A d = 0;
|
||||
for (size_t offset = 0; offset < k_size; ++offset) {
|
||||
const size_t src_l_plus = stride * dst_l + offset;
|
||||
if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
|
||||
@ -38,15 +38,15 @@ __device__ void conv1d(
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
|
||||
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
|
||||
d += src[src_idx] * kernel[k_idx];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[dst_i] = d;
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
|
||||
#define CONV1D_OP(TYPENAME, FN_NAME) \
|
||||
#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t num_dims, \
|
||||
@ -56,19 +56,19 @@ extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \
|
||||
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CONV1D_OP(__nv_bfloat16, conv1d_bf16)
|
||||
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CONV1D_OP(__half, conv1d_f16)
|
||||
CONV1D_OP(__half, float, conv1d_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, conv1d_f32)
|
||||
CONV1D_OP(double, conv1d_f64)
|
||||
CONV1D_OP(uint8_t, conv1d_u8)
|
||||
CONV1D_OP(uint32_t, conv1d_u32)
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
CONV1D_OP(double, double, conv1d_f64)
|
||||
CONV1D_OP(uint8_t, uint8_t, conv1d_u8)
|
||||
CONV1D_OP(uint32_t, uint32_t, conv1d_u32)
|
||||
|
||||
|
Reference in New Issue
Block a user