mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add more binary kernels.
This commit is contained in:
@ -14,15 +14,10 @@ pub(crate) trait UnaryOp {
|
||||
|
||||
/// Describes a two-operand operation: a name, per-dtype kernel name strings,
/// and scalar `f32` / `f64` functions combining two values into one.
// NOTE(review): this text is a rendered diff with old and new lines interleaved.
// The hunk header (15 lines -> 10 lines) suggests `cuda_impl` below is the
// removed side and `KERNEL_F32` / `KERNEL_F64` are the added side -- confirm
// against the repository before relying on both being present.
pub(crate) trait BinaryOp {
|
||||
/// Short textual name of the operation (e.g. "add" in the `Add` impl).
const NAME: &'static str;
|
||||
/// Kernel name string for `f32` operands (e.g. "badd_f32" in the `Add` impl).
// presumably the CUDA kernel looked up by `binary_impl::<B>` -- TODO confirm
const KERNEL_F32: &'static str;
|
||||
/// Kernel name string for `f64` operands (e.g. "badd_f64" in the `Add` impl).
const KERNEL_F64: &'static str;
|
||||
/// Applies the operation to a pair of `f32` values.
fn f32(v1: f32, v2: f32) -> f32;
|
||||
/// Applies the operation to a pair of `f64` values.
fn f64(v1: f64, v2: f64) -> f64;
|
||||
/// Applies the operation to two CUDA storages given a shape and one stride
/// slice per operand, returning a new `CudaStorage` or an error.
fn cuda_impl(
|
||||
lhs: &CudaStorage,
|
||||
rhs: &CudaStorage,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<CudaStorage>;
|
||||
}
|
||||
|
||||
/// Field-less marker type; the `BinaryOp` implementation for addition is attached to it.
struct Add;
|
||||
@ -35,78 +30,50 @@ struct Sqrt;
|
||||
|
||||
/// Addition: the scalar functions return `v1 + v2`.
// NOTE(review): rendered diff with old and new lines interleaved; `cuda_impl`
// is likely the removed side and the `KERNEL_*` constants the added side -- confirm.
impl BinaryOp for Add {
|
||||
const NAME: &'static str = "add";
|
||||
// Kernel name strings for the f32 and f64 variants of this operation.
const KERNEL_F32: &'static str = "badd_f32";
|
||||
const KERNEL_F64: &'static str = "badd_f64";
|
||||
/// Returns the sum of the two `f32` operands.
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 + v2
|
||||
}
|
||||
/// Returns the sum of the two `f64` operands.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 + v2
|
||||
}
|
||||
/// Forwards to `lhs.add_impl(...)` with the same arguments, propagating its
/// error through `?` and re-wrapping the success value in `Ok`.
fn cuda_impl(
|
||||
lhs: &CudaStorage,
|
||||
rhs: &CudaStorage,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtraction: the scalar functions return `v1 - v2`.
// NOTE(review): rendered diff with old and new lines interleaved; `cuda_impl`
// is likely the removed side and the `KERNEL_*` constants the added side -- confirm.
impl BinaryOp for Sub {
|
||||
const NAME: &'static str = "sub";
|
||||
// Kernel name strings for the f32 and f64 variants of this operation.
const KERNEL_F32: &'static str = "bsub_f32";
|
||||
const KERNEL_F64: &'static str = "bsub_f64";
|
||||
/// Returns the first `f32` operand minus the second.
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 - v2
|
||||
}
|
||||
/// Returns the first `f64` operand minus the second.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 - v2
|
||||
}
|
||||
/// Unimplemented stub: all arguments are ignored and `todo!()` panics when
/// this is called.
fn cuda_impl(
|
||||
_: &CudaStorage,
|
||||
_: &CudaStorage,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiplication: the scalar functions return `v1 * v2`.
// NOTE(review): rendered diff with old and new lines interleaved; `cuda_impl`
// is likely the removed side and the `KERNEL_*` constants the added side -- confirm.
impl BinaryOp for Mul {
|
||||
const NAME: &'static str = "mul";
|
||||
// Kernel name strings for the f32 and f64 variants of this operation.
const KERNEL_F32: &'static str = "bmul_f32";
|
||||
const KERNEL_F64: &'static str = "bmul_f64";
|
||||
/// Returns the product of the two `f32` operands.
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 * v2
|
||||
}
|
||||
/// Returns the product of the two `f64` operands.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 * v2
|
||||
}
|
||||
/// Unimplemented stub: all arguments are ignored and `todo!()` panics when
/// this is called.
fn cuda_impl(
|
||||
_: &CudaStorage,
|
||||
_: &CudaStorage,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Division: the scalar functions return `v1 / v2`.
// NOTE(review): rendered diff with old and new lines interleaved; `cuda_impl`
// is likely the removed side and the `KERNEL_*` constants the added side -- confirm.
impl BinaryOp for Div {
|
||||
const NAME: &'static str = "div";
|
||||
// Kernel name strings for the f32 and f64 variants of this operation.
const KERNEL_F32: &'static str = "bdiv_f32";
|
||||
const KERNEL_F64: &'static str = "bdiv_f64";
|
||||
/// Returns the first `f32` operand divided by the second.
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
v1 / v2
|
||||
}
|
||||
/// Returns the first `f64` operand divided by the second.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 / v2
|
||||
}
|
||||
/// Unimplemented stub: all arguments are ignored and `todo!()` panics when
/// this is called.
fn cuda_impl(
|
||||
_: &CudaStorage,
|
||||
_: &CudaStorage,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Neg {
|
||||
@ -221,7 +188,7 @@ impl Storage {
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
|
Reference in New Issue
Block a user