mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More u32 support.
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
#include "binary_op_macros.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, badd_f16, x + y)
|
||||
@ -8,10 +9,14 @@ BINARY_OP(__half, bsub_f16, x - y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, badd_f32, x + y)
|
||||
BINARY_OP(double, badd_fwd_f64, x + y);
|
||||
BINARY_OP(double, badd_f64, x + y);
|
||||
BINARY_OP(uint32_t, badd_u32, x + y);
|
||||
BINARY_OP(float, bdiv_f32, x / y)
|
||||
BINARY_OP(double, bdiv_f64, x / y);
|
||||
BINARY_OP(uint32_t, bdiv_u32, x / y);
|
||||
BINARY_OP(float, bmul_f32, x * y)
|
||||
BINARY_OP(double, bmul_f64, x * y);
|
||||
BINARY_OP(uint32_t, bmul_u32, x * y);
|
||||
BINARY_OP(float, bsub_f32, x - y)
|
||||
BINARY_OP(double, bsub_f64, x - y);
|
||||
BINARY_OP(uint32_t, bsub_u32, x - y);
|
||||
|
@ -358,6 +358,15 @@ impl CudaStorage {
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
|
||||
// SAFETY: (for the `alloc` call below) the buffer contents are set later by running the kernel.
|
||||
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
|
||||
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
// The dtypes should have been checked at this point so this is an internal error.
|
||||
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||
};
|
||||
|
@ -47,6 +47,7 @@ pub(crate) trait BinaryOp {
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
const KERNEL_U32: &'static str;
|
||||
fn f32(v1: f32, v2: f32) -> f32;
|
||||
fn f64(v1: f64, v2: f64) -> f64;
|
||||
fn u32(v1: u32, v2: u32) -> u32;
|
||||
@ -65,6 +66,7 @@ impl BinaryOp for Add {
|
||||
const NAME: &'static str = "add";
|
||||
const KERNEL_F32: &'static str = "badd_f32";
|
||||
const KERNEL_F64: &'static str = "badd_f64";
|
||||
const KERNEL_U32: &'static str = "badd_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 + v2
|
||||
}
|
||||
@ -80,6 +82,7 @@ impl BinaryOp for Sub {
|
||||
const NAME: &'static str = "sub";
|
||||
const KERNEL_F32: &'static str = "bsub_f32";
|
||||
const KERNEL_F64: &'static str = "bsub_f64";
|
||||
const KERNEL_U32: &'static str = "bsub_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 - v2
|
||||
}
|
||||
@ -95,6 +98,7 @@ impl BinaryOp for Mul {
|
||||
const NAME: &'static str = "mul";
|
||||
const KERNEL_F32: &'static str = "bmul_f32";
|
||||
const KERNEL_F64: &'static str = "bmul_f64";
|
||||
const KERNEL_U32: &'static str = "bmul_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 * v2
|
||||
}
|
||||
@ -110,6 +114,7 @@ impl BinaryOp for Div {
|
||||
const NAME: &'static str = "div";
|
||||
const KERNEL_F32: &'static str = "bdiv_f32";
|
||||
const KERNEL_F64: &'static str = "bdiv_f64";
|
||||
const KERNEL_U32: &'static str = "bdiv_u32";
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 / v2
|
||||
}
|
||||
|
Reference in New Issue
Block a user