mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add cuda support for unary ops.
This commit is contained in:
@ -8,12 +8,18 @@ pub enum Storage {
|
||||
|
||||
pub(crate) trait UnaryOp {
|
||||
const NAME: &'static str;
|
||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
fn f32(v1: f32) -> f32;
|
||||
fn f64(v1: f64) -> f64;
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
const NAME: &'static str;
|
||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
fn f32(v1: f32, v2: f32) -> f32;
|
||||
@ -84,6 +90,8 @@ impl UnaryOp for Neg {
|
||||
fn f64(v1: f64) -> f64 {
|
||||
-v1
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uneg_f32";
|
||||
const KERNEL_F64: &'static str = "uneg_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqr {
|
||||
@ -94,6 +102,8 @@ impl UnaryOp for Sqr {
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1 * v1
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usqr_f32";
|
||||
const KERNEL_F64: &'static str = "usqr_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqrt {
|
||||
@ -104,6 +114,8 @@ impl UnaryOp for Sqrt {
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.sqrt()
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usqrt_f32";
|
||||
const KERNEL_F64: &'static str = "usqrt_f64";
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
@ -168,7 +180,10 @@ impl Storage {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda { .. } => todo!(),
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -269,7 +284,14 @@ impl Storage {
|
||||
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
_ => todo!(),
|
||||
(Self::Cuda(_), Self::Cuda(_)) => {
|
||||
todo!()
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "matmul",
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user