From 790037390ca81a9cf32f35c03f514452d1366e4f Mon Sep 17 00:00:00 2001 From: yinqiwen Date: Sat, 23 Mar 2024 20:44:10 +0800 Subject: [PATCH] Add cast_bf16_x/cast_x_bf16 when CUDA_ARCH<800 but CUDA_VERSION >= 11000 (#1919) - it makes it possible to load bf16 models on T4(sm75) --- candle-kernels/src/cast.cu | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 2fe85e1c..90f5e7ba 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -83,6 +83,18 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +#else +#include <cuda.h> +#if CUDA_VERSION >= 11000 +CAST_OP(__nv_bfloat16, float, cast_bf16_f32) +CAST_OP(float, __nv_bfloat16, cast_f32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) +CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) +CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) +CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) +CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +#endif #endif #if __CUDA_ARCH__ >= 530