Add more binary kernels.

This commit is contained in:
laurent
2023-06-22 14:07:02 +01:00
parent 97fe1fac85
commit b8f514d9c6
9 changed files with 54 additions and 87 deletions

29
kernels/src/binary.cu Normal file
View File

@ -0,0 +1,29 @@
#include "binary_op_macros.cuh"
// Element-wise binary kernels: one BINARY_OP instantiation per (operation, dtype).
// The second macro argument is the exported kernel name. The host code resolves
// kernels by the strings "b<op>_<dtype>" (e.g. "badd_f64"), so every name below
// must follow that pattern exactly; a mismatch only shows up as a runtime
// lookup failure.
// The __half variants are only emitted when compiling for __CUDA_ARCH__ >= 530.

// Addition.
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, badd_f16, x + y)
#endif
BINARY_OP(float, badd_f32, x + y)
// Was `badd_fwd_f64`, which did not match the "badd_f64" name the host requests.
BINARY_OP(double, badd_f64, x + y)

// Division.
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bdiv_f16, x / y)
#endif
BINARY_OP(float, bdiv_f32, x / y)
BINARY_OP(double, bdiv_f64, x / y)

// Multiplication.
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bmul_f16, x * y)
#endif
BINARY_OP(float, bmul_f32, x * y)
BINARY_OP(double, bmul_f64, x * y)

// Subtraction.
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bsub_f16, x - y)
#endif
BINARY_OP(float, bsub_f32, x - y)
BINARY_OP(double, bsub_f64, x - y)

View File

@ -1,8 +0,0 @@
#include "binary_op_macros.cuh"
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, badd_f16, x + y)
#endif
BINARY_OP(float, badd_f32, x + y)
BINARY_OP(double, badd_fwd_f64, x + y);

View File

@ -1,8 +0,0 @@
#include "binary_op_macros.cuh"
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bdiv_f16, x / y)
#endif
BINARY_OP(float, bdiv_f32, x / y)
BINARY_OP(double, bdiv_f64, x / y);

View File

@ -1,8 +0,0 @@
#include "binary_op_macros.cuh"
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bmul_f16, x * y)
#endif
BINARY_OP(float, bmul_f32, x * y)
BINARY_OP(double, bmul_f64, x * y);

View File

@ -1,8 +0,0 @@
#include "binary_op_macros.cuh"
#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, bsub_f16, x - y)
#endif
BINARY_OP(float, bsub_f32, x - y)
BINARY_OP(double, bsub_f64, x - y);

View File

@ -1,6 +1,3 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY_ADD: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_add.ptx"));
pub const BINARY_DIV: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_div.ptx"));
pub const BINARY_MUL: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_mul.ptx"));
pub const BINARY_SUB: &str = include_str!(concat!(env!("OUT_DIR"), "/binary_sub.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));

View File

@ -166,7 +166,7 @@ impl CudaStorage {
}
}
pub(crate) fn add_impl(
pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
@ -180,8 +180,8 @@ impl CudaStorage {
let dims_and_strides = [dims, lhs_stride, rhs_stride].concat();
match (self, rhs) {
(Self::F32(lhs), Self::F32(rhs)) => {
let func = dev.get_or_load_func("badd_f32", kernels::BINARY_ADD)?;
// SAFETY: Set later by running the add kernel.
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
@ -190,8 +190,8 @@ impl CudaStorage {
Ok(Self::F32(out))
}
(Self::F64(lhs), Self::F64(rhs)) => {
// SAFETY: Set later by running the add kernel.
let func = dev.get_or_load_func("badd_f64", kernels::BINARY_ADD)?;
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
@ -200,7 +200,7 @@ impl CudaStorage {
Ok(Self::F64(out))
}
// The dtypes should have been checked at this point so this is an internal error.
_ => Err(CudaError::InternalError("dtype mismatch in add")),
_ => Err(CudaError::InternalError("dtype mismatch in binary op")),
}
}

View File

@ -54,7 +54,13 @@ impl CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn add_impl(&self, _: &Self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
/// Stand-in for the binary-op entry point in the build without CUDA support.
///
/// Every argument is ignored and the type parameter `B` is unused; the call
/// always fails.
///
/// # Errors
///
/// Always returns `Error::NotCompiledWithCudaSupport`.
pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
&self,
_: &Self,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}

View File

@ -14,15 +14,10 @@ pub(crate) trait UnaryOp {
pub(crate) trait BinaryOp {
const NAME: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn cuda_impl(
lhs: &CudaStorage,
rhs: &CudaStorage,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<CudaStorage>;
}
struct Add;
@ -35,78 +30,50 @@ struct Sqrt;
impl BinaryOp for Add {
const NAME: &'static str = "add";
const KERNEL_F32: &'static str = "badd_f32";
const KERNEL_F64: &'static str = "badd_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 + v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 + v2
}
fn cuda_impl(
lhs: &CudaStorage,
rhs: &CudaStorage,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<CudaStorage> {
Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?)
}
}
impl BinaryOp for Sub {
const NAME: &'static str = "sub";
const KERNEL_F32: &'static str = "bsub_f32";
const KERNEL_F64: &'static str = "bsub_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 - v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
fn cuda_impl(
_: &CudaStorage,
_: &CudaStorage,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<CudaStorage> {
todo!()
}
}
impl BinaryOp for Mul {
const NAME: &'static str = "mul";
const KERNEL_F32: &'static str = "bmul_f32";
const KERNEL_F64: &'static str = "bmul_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 * v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 * v2
}
fn cuda_impl(
_: &CudaStorage,
_: &CudaStorage,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<CudaStorage> {
todo!()
}
}
impl BinaryOp for Div {
const NAME: &'static str = "div";
const KERNEL_F32: &'static str = "bdiv_f32";
const KERNEL_F64: &'static str = "bdiv_f64";
fn f32(v1: f32, v2: f32) -> f32 {
v1 / v2
}
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
fn cuda_impl(
_: &CudaStorage,
_: &CudaStorage,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<CudaStorage> {
todo!()
}
}
impl UnaryOp for Neg {
@ -221,7 +188,7 @@ impl Storage {
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?;
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {