mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some cmp tests. (#233)
* Add some cmp tests. * Add the cuda kernels for comparison operations.
This commit is contained in:
@ -406,6 +406,30 @@ trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<S>;
|
||||
|
||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Clone;
|
||||
impl Map1 for Clone {
|
||||
fn f<T: DeviceRepr>(
|
||||
@ -747,6 +771,43 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
|
||||
}
|
||||
}
|
||||
|
||||
struct Cmp(CmpOp);
|
||||
impl Map2Any for Cmp {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
lhs: &CudaSlice<T>,
|
||||
lhs_l: &Layout,
|
||||
rhs: &CudaSlice<T>,
|
||||
rhs_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<S> {
|
||||
let shape = lhs_l.shape();
|
||||
let dims = shape.dims();
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dims_and_strides = dev
|
||||
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||
.w()?;
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let name = match self.0 {
|
||||
CmpOp::Eq => "eq",
|
||||
CmpOp::Ne => "ne",
|
||||
CmpOp::Lt => "lt",
|
||||
CmpOp::Le => "le",
|
||||
CmpOp::Gt => "gt",
|
||||
CmpOp::Ge => "ge",
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(S::U8(out))
|
||||
}
|
||||
}
|
||||
|
||||
fn slice_src_and_dst<'a, T>(
|
||||
src: &'a CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
@ -1015,8 +1076,10 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement cmp").into())
|
||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Cmp(op).map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
|
@ -304,6 +304,18 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmp(device: &Device) -> Result<()> {
|
||||
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
|
||||
assert_eq!(t1.eq(&t2)?.to_vec2::<u8>()?, &[[0, 0], [0, 1], [1, 0]]);
|
||||
assert_eq!(t1.ne(&t2)?.to_vec2::<u8>()?, &[[1, 1], [1, 0], [0, 1]]);
|
||||
assert_eq!(t1.le(&t2)?.to_vec2::<u8>()?, &[[1, 0], [1, 1], [1, 1]]);
|
||||
assert_eq!(t1.lt(&t2)?.to_vec2::<u8>()?, &[[1, 0], [1, 0], [0, 1]]);
|
||||
assert_eq!(t1.gt(&t2)?.to_vec2::<u8>()?, &[[0, 1], [0, 0], [0, 0]]);
|
||||
assert_eq!(t1.ge(&t2)?.to_vec2::<u8>()?, &[[0, 1], [0, 1], [1, 0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select() -> Result<()> {
|
||||
// TODO: Test on cuda once the kernel is available.
|
||||
@ -498,5 +510,6 @@ test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(softmax, softmax_cpu, softmax_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
||||
|
@ -6,6 +6,12 @@ BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
|
||||
BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
|
||||
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
|
||||
BINARY_OP(__nv_bfloat16, bsub_bf16, x - y)
|
||||
BINARY_OP_OUT(__nv_bfloat16, uint8_t, eq_bf16, x == y)
|
||||
BINARY_OP_OUT(__nv_bfloat16, uint8_t, ne_bf16, x != y)
|
||||
BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y)
|
||||
BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y)
|
||||
BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y)
|
||||
BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -13,6 +19,12 @@ BINARY_OP(__half, badd_f16, x + y)
|
||||
BINARY_OP(__half, bdiv_f16, x / y)
|
||||
BINARY_OP(__half, bmul_f16, x * y)
|
||||
BINARY_OP(__half, bsub_f16, x - y)
|
||||
BINARY_OP_OUT(__half, uint8_t, eq_f16, x == y)
|
||||
BINARY_OP_OUT(__half, uint8_t, ne_f16, x != y)
|
||||
BINARY_OP_OUT(__half, uint8_t, lt_f16, x < y)
|
||||
BINARY_OP_OUT(__half, uint8_t, le_f16, x <= y)
|
||||
BINARY_OP_OUT(__half, uint8_t, gt_f16, x > y)
|
||||
BINARY_OP_OUT(__half, uint8_t, ge_f16, x >= y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, badd_f32, x + y)
|
||||
@ -31,3 +43,33 @@ BINARY_OP(float, bsub_f32, x - y)
|
||||
BINARY_OP(double, bsub_f64, x - y);
|
||||
BINARY_OP(uint8_t, bsub_u8, x - y);
|
||||
BINARY_OP(uint32_t, bsub_u32, x - y);
|
||||
|
||||
BINARY_OP_OUT(float, uint8_t, eq_f32, x == y)
|
||||
BINARY_OP_OUT(double, uint8_t, eq_f64, x == y)
|
||||
BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y)
|
||||
BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y)
|
||||
|
||||
BINARY_OP_OUT(float, uint8_t, ne_f32, x != y)
|
||||
BINARY_OP_OUT(double, uint8_t, ne_f64, x != y)
|
||||
BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y)
|
||||
BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y)
|
||||
|
||||
BINARY_OP_OUT(float, uint8_t, lt_f32, x < y)
|
||||
BINARY_OP_OUT(double, uint8_t, lt_f64, x < y)
|
||||
BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y)
|
||||
BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y)
|
||||
|
||||
BINARY_OP_OUT(float, uint8_t, le_f32, x <= y)
|
||||
BINARY_OP_OUT(double, uint8_t, le_f64, x <= y)
|
||||
BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y)
|
||||
BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y)
|
||||
|
||||
BINARY_OP_OUT(float, uint8_t, gt_f32, x > y)
|
||||
BINARY_OP_OUT(double, uint8_t, gt_f64, x > y)
|
||||
BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y)
|
||||
BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y)
|
||||
|
||||
BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y)
|
||||
BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y)
|
||||
BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y)
|
||||
BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y)
|
||||
|
@ -1,13 +1,13 @@
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
#define BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *dims_and_strides, \
|
||||
const TYPENAME *lhs, \
|
||||
const TYPENAME *rhs, \
|
||||
TYPENAME *out \
|
||||
OUT_TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = dims_and_strides; \
|
||||
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
|
||||
@ -16,8 +16,8 @@ extern "C" __global__ void FN_NAME( \
|
||||
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
|
||||
if (lhs_cont && rhs_cont) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = lhs ? lhs[i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[i] : out[i]; \
|
||||
TYPENAME x = lhs[i]; \
|
||||
TYPENAME y = rhs[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} else if (lhs_cont) { \
|
||||
@ -29,8 +29,8 @@ extern "C" __global__ void FN_NAME( \
|
||||
rhs_i += i_dim * rhs_strides[d]; \
|
||||
tmp_i /= dims[d]; \
|
||||
} \
|
||||
TYPENAME x = lhs ? lhs[i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
||||
TYPENAME x = lhs[i]; \
|
||||
TYPENAME y = rhs[rhs_i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} else if (rhs_cont) { \
|
||||
@ -42,8 +42,8 @@ extern "C" __global__ void FN_NAME( \
|
||||
lhs_i += i_dim * lhs_strides[d]; \
|
||||
tmp_i /= dims[d]; \
|
||||
} \
|
||||
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[i] : out[i]; \
|
||||
TYPENAME x = lhs[lhs_i]; \
|
||||
TYPENAME y = rhs[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} else { \
|
||||
@ -57,9 +57,13 @@ extern "C" __global__ void FN_NAME( \
|
||||
rhs_i += i_dim * rhs_strides[d]; \
|
||||
tmp_i /= dims[d]; \
|
||||
} \
|
||||
TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
|
||||
TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
|
||||
TYPENAME x = lhs[lhs_i]; \
|
||||
TYPENAME y = rhs[rhs_i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)
|
||||
|
Reference in New Issue
Block a user