Add cuda support for unary ops.

This commit is contained in:
laurent
2023-06-22 15:12:59 +01:00
parent b8f514d9c6
commit 5276755fb3
7 changed files with 101 additions and 23 deletions

View File

@ -8,12 +8,18 @@ pub enum Storage {
/// An element-wise operation that takes a single operand.
///
/// Each implementor provides a scalar implementation per supported float type, plus the
/// names of the kernels associated with those types.
pub(crate) trait UnaryOp {
    /// Name identifying the operation.
    const NAME: &'static str;

    /// Applies the operation to a single `f32` value and returns the result.
    fn f32(v1: f32) -> f32;

    /// Applies the operation to a single `f64` value and returns the result.
    fn f64(v1: f64) -> f64;

    // TODO: The kernels named below work with arbitrary strides. The contiguous case should
    // get its own code path, since it is much simpler to optimize.

    /// Name of the kernel operating on `f32` data (e.g. `"uneg_f32"`).
    /// NOTE(review): presumably looked up by the CUDA backend — confirm against the caller.
    const KERNEL_F32: &'static str;

    /// Name of the kernel operating on `f64` data (e.g. `"uneg_f64"`).
    /// NOTE(review): presumably looked up by the CUDA backend — confirm against the caller.
    const KERNEL_F64: &'static str;
}
pub(crate) trait BinaryOp {
const NAME: &'static str;
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
// contiguous case separately as it's easy to optimize things out there.
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
fn f32(v1: f32, v2: f32) -> f32;
@ -84,6 +90,8 @@ impl UnaryOp for Neg {
fn f64(v1: f64) -> f64 {
-v1
}
const KERNEL_F32: &'static str = "uneg_f32";
const KERNEL_F64: &'static str = "uneg_f64";
}
impl UnaryOp for Sqr {
@ -94,6 +102,8 @@ impl UnaryOp for Sqr {
fn f64(v1: f64) -> f64 {
v1 * v1
}
const KERNEL_F32: &'static str = "usqr_f32";
const KERNEL_F64: &'static str = "usqr_f64";
}
impl UnaryOp for Sqrt {
@ -104,6 +114,8 @@ impl UnaryOp for Sqrt {
fn f64(v1: f64) -> f64 {
v1.sqrt()
}
const KERNEL_F32: &'static str = "usqrt_f32";
const KERNEL_F64: &'static str = "usqrt_f64";
}
impl Storage {
@ -168,7 +180,10 @@ impl Storage {
let storage = storage.unary_impl::<B>(shape, stride)?;
Ok(Self::Cpu(storage))
}
Self::Cuda { .. } => todo!(),
Self::Cuda(storage) => {
let storage = storage.unary_impl::<B>(shape, stride)?;
Ok(Self::Cuda(storage))
}
}
}
@ -269,7 +284,14 @@ impl Storage {
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
_ => todo!(),
(Self::Cuda(_), Self::Cuda(_)) => {
todo!()
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "matmul",
}),
}
}
}