Add cuda support for unary ops.

commit 5276755fb3 (parent b8f514d9c6)
Author: laurent
Date: 2023-06-22 15:12:59 +01:00

7 changed files with 101 additions and 23 deletions


@@ -8,6 +8,7 @@ fn main() -> Result<()> {
     let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
     let z = (y + x * 3.)?;
     println!("{:?}", z.to_vec1::<f32>()?);
+    println!("{:?}", z.sqrt()?.to_vec1::<f32>()?);
     let x = Tensor::ones((3, 2), DType::F32, &device)?;
     println!("{:?}", x.to_vec2::<f32>()?);
     Ok(())
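
With the unary kernels wired up, the same pattern should work end to end on a GPU. A minimal sketch, assuming the crate is importable as candle and exposes a Device::new_cuda constructor (neither appears in this diff):

// Hypothetical check of the new CUDA sqrt path.
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Assumption: a CUDA device constructor of this shape exists.
    let device = Device::new_cuda(0)?;
    let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
    // On CUDA storage, sqrt should now dispatch to the usqrt_f32 kernel.
    println!("{:?}", x.sqrt()?.to_vec1::<f32>()?);
    Ok(())
}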


@@ -2,28 +2,16 @@
 #if __CUDA_ARCH__ >= 530
 BINARY_OP(__half, badd_f16, x + y)
+BINARY_OP(__half, bdiv_f16, x / y)
+BINARY_OP(__half, bmul_f16, x * y)
+BINARY_OP(__half, bsub_f16, x - y)
 #endif
 BINARY_OP(float, badd_f32, x + y)
 BINARY_OP(double, badd_fwd_f64, x + y);
-#if __CUDA_ARCH__ >= 530
-BINARY_OP(__half, bdiv_f16, x / y)
-#endif
 BINARY_OP(float, bdiv_f32, x / y)
 BINARY_OP(double, bdiv_f64, x / y);
-#if __CUDA_ARCH__ >= 530
-BINARY_OP(__half, bmul_f16, x * y)
-#endif
 BINARY_OP(float, bmul_f32, x * y)
 BINARY_OP(double, bmul_f64, x * y);
-#if __CUDA_ARCH__ >= 530
-BINARY_OP(__half, bsub_f16, x - y)
-#endif
 BINARY_OP(float, bsub_f32, x - y)
 BINARY_OP(double, bsub_f64, x - y);


@@ -1,3 +1,4 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
+pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
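
The OUT_DIR .ptx files have to come from a build script, which is not part of this diff. A rough sketch of what such a script could look like, assuming nvcc is on the PATH and the .cu sources live in src/ (both assumptions):

// build.rs (hypothetical): compile each .cu kernel to PTX in OUT_DIR so that
// include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")) and friends resolve.
use std::{env, path::PathBuf, process::Command};

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    for name in ["affine", "binary", "fill", "unary"] {
        let src = format!("src/{name}.cu");
        let dst = out_dir.join(format!("{name}.ptx"));
        println!("cargo:rerun-if-changed={src}");
        let status = Command::new("nvcc")
            .args(["--ptx", src.as_str(), "-o"])
            .arg(&dst)
            .status()
            .expect("failed to run nvcc");
        assert!(status.success(), "nvcc failed on {src}");
    }
}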

kernels/src/unary.cu (new file, 26 lines)

@@ -0,0 +1,26 @@
+#include "cuda_utils.cuh"
+
+#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const TYPENAME *inp, \
+    TYPENAME *out \
+) { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+        TYPENAME x = inp ? inp[i] : out[i]; \
+        out[i] = FUNC; \
+    } \
+}
+
+#if __CUDA_ARCH__ >= 530
+UNARY_OP(__half, uneg_f16, -x)
+UNARY_OP(__half, usqr_f16, x*x)
+UNARY_OP(__half, usqrt_f16, sqrtg(x))
+#endif
+UNARY_OP(float, uneg_f32, -x)
+UNARY_OP(double, uneg_f64, -x)
+UNARY_OP(float, usqr_f32, x*x)
+UNARY_OP(double, usqr_f64, x*x)
+UNARY_OP(float, usqrt_f32, sqrtg(x))
+UNARY_OP(double, usqrt_f64, sqrtg(x))
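
The kernel body is a grid-stride loop, so a fixed launch shape covers any numel, and passing a null inp makes the op run in place on out. A CPU model of the same indexing in Rust, with made-up grid dimensions, just to show which elements each (block, thread) pair touches:

// CPU model of the grid-stride loop in UNARY_OP (illustrative only).
// block_dim and grid_dim stand in for blockDim.x and gridDim.x.
fn unary_grid_stride(inp: Option<&[f32]>, out: &mut [f32], f: impl Fn(f32) -> f32) {
    let numel = out.len();
    let (block_dim, grid_dim) = (256, 4); // arbitrary example launch shape
    for block in 0..grid_dim {
        for thread in 0..block_dim {
            // Matches: i = blockIdx.x * blockDim.x + threadIdx.x,
            // stepping by blockDim.x * gridDim.x.
            let mut i = block * block_dim + thread;
            while i < numel {
                // Matches: TYPENAME x = inp ? inp[i] : out[i];
                let x = inp.map_or(out[i], |inp| inp[i]);
                out[i] = f(x);
                i += block_dim * grid_dim;
            }
        }
    }
}

fn main() {
    let inp: Vec<f32> = (0..1000).map(|v| v as f32).collect();
    let mut out = vec![0.0f32; 1000];
    unary_grid_stride(Some(&inp), &mut out, |x| x * x); // usqr_f32 semantics
    assert_eq!(out[7], 49.0);
}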


@@ -145,21 +145,53 @@ impl CudaStorage {
         match self {
             Self::F32(arg) => {
                 let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
-                // SAFETY: if this function returns Ok(..), the kernel has been applied
-                // and has set the initially unset memory.
+                // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
                 let params = (elem_count, arg, &out, mul as f32, add as f32);
-                // SAFETY: well, well, well...
+                // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F32(out))
             }
             Self::F64(arg) => {
                 let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
-                // SAFETY: if this function returns Ok(..), the kernel has been applied
-                // and has set the initially unset memory.
+                // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
                 let params = (elem_count, arg, &out, mul, add);
-                // SAFETY: well, well, well...
+                // SAFETY: ffi.
                 unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F64(out))
             }
         }
     }
+
+    pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+    ) -> Result<Self> {
+        if !shape.is_contiguous(stride) {
+            return Err(CudaError::RequiresContiguous { op: "unary" });
+        }
+        let elem_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+        let dev = self.device();
+        match self {
+            Self::F32(arg) => {
+                let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
+                let params = (elem_count, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F32(out))
+            }
+            Self::F64(arg) => {
+                let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
+                let params = (elem_count, arg, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F64(out))
+            }
+        }
+    }


@@ -54,6 +54,14 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn unary_impl<B: crate::storage::UnaryOp>(
+        &self,
+        _: &Shape,
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
         &self,
         _: &Self,
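
The dummy backend mirrors the CUDA backend's signatures so the rest of the crate type-checks without a GPU toolchain; every method just returns Error::NotCompiledWithCudaSupport. A sketch of the feature switch this pairing implies; the feature name "cuda" and the inline modules are assumptions:

// Hypothetical feature switch between the two backends.
#[cfg(feature = "cuda")]
mod cuda_backend {
    pub struct CudaStorage; // real implementation lives here
}

#[cfg(not(feature = "cuda"))]
mod dummy_cuda_backend {
    pub struct CudaStorage; // stub that only returns NotCompiledWithCudaSupport
}

#[cfg(feature = "cuda")]
pub use cuda_backend::CudaStorage;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::CudaStorage;

fn main() {
    let _storage = CudaStorage; // same name and API either way
}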


@@ -8,12 +8,18 @@ pub enum Storage {
 pub(crate) trait UnaryOp {
     const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
+    const KERNEL_F32: &'static str;
+    const KERNEL_F64: &'static str;
     fn f32(v1: f32) -> f32;
     fn f64(v1: f64) -> f64;
 }
 
 pub(crate) trait BinaryOp {
     const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
     const KERNEL_F32: &'static str;
     const KERNEL_F64: &'static str;
     fn f32(v1: f32, v2: f32) -> f32;

@@ -84,6 +90,8 @@ impl UnaryOp for Neg {
     fn f64(v1: f64) -> f64 {
         -v1
     }
+    const KERNEL_F32: &'static str = "uneg_f32";
+    const KERNEL_F64: &'static str = "uneg_f64";
 }

@@ -94,6 +102,8 @@ impl UnaryOp for Sqr {
     fn f64(v1: f64) -> f64 {
         v1 * v1
     }
+    const KERNEL_F32: &'static str = "usqr_f32";
+    const KERNEL_F64: &'static str = "usqr_f64";
 }

@@ -104,6 +114,8 @@ impl UnaryOp for Sqrt {
     fn f64(v1: f64) -> f64 {
         v1.sqrt()
     }
+    const KERNEL_F32: &'static str = "usqrt_f32";
+    const KERNEL_F64: &'static str = "usqrt_f64";
 }

@@ -168,7 +180,10 @@ impl Storage {
                 let storage = storage.unary_impl::<B>(shape, stride)?;
                 Ok(Self::Cpu(storage))
             }
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => {
+                let storage = storage.unary_impl::<B>(shape, stride)?;
+                Ok(Self::Cuda(storage))
+            }
         }
     }

@@ -269,7 +284,14 @@ impl Storage {
                 let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
                 Ok(Self::Cpu(storage))
             }
-            _ => todo!(),
+            (Self::Cuda(_), Self::Cuda(_)) => {
+                todo!()
+            }
+            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "matmul",
+            }),
         }
     }
 }
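
The matmul dispatch now matches on the pair of storages, so mixing a CPU and a CUDA tensor surfaces as a typed DeviceMismatchBinaryOp error instead of hitting todo!(). A toy version of that pair-match, using stand-in types:

// Toy model of the (lhs, rhs) device check added to matmul.
#[derive(Debug)]
enum Location {
    Cpu,
    Cuda,
}

enum Storage {
    Cpu,
    Cuda,
}

impl Storage {
    fn location(&self) -> Location {
        match self {
            Storage::Cpu => Location::Cpu,
            Storage::Cuda => Location::Cuda,
        }
    }
}

fn matmul_check(lhs: &Storage, rhs: &Storage) -> Result<(), String> {
    match (lhs, rhs) {
        (Storage::Cpu, Storage::Cpu) | (Storage::Cuda, Storage::Cuda) => Ok(()),
        (lhs, rhs) => Err(format!(
            "device mismatch in matmul: lhs {:?}, rhs {:?}",
            lhs.location(),
            rhs.location()
        )),
    }
}

fn main() {
    assert!(matmul_check(&Storage::Cpu, &Storage::Cpu).is_ok());
    assert!(matmul_check(&Storage::Cpu, &Storage::Cuda).is_err());
}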