diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 42d04e80..ea611eba 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -27,7 +27,7 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16) -CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8) +// CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8) CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32) // CAST_OP(__nv_bfloat16, __half, cast_bf16_f16) CAST_OP(__nv_bfloat16, float, cast_bf16_f32)