mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add some first binary op (add).
This commit is contained in:
@ -16,6 +16,13 @@ pub(crate) trait BinaryOp {
|
||||
/// String identifier each implementor must supply for its op.
/// NOTE(review): presumably used for error/debug messages — confirm at the use sites.
const NAME: &'static str;
|
||||
/// Scalar form of the op on a pair of `f32` values.
/// NOTE(review): presumably the `f32` twin of `f64` below — implementations are not shown here.
fn f32(v1: f32, v2: f32) -> f32;
|
||||
/// Scalar form of the op on a pair of `f64` values (e.g. `v1 + v2` for `Add`).
fn f64(v1: f64, v2: f64) -> f64;
|
||||
/// CUDA entry point for the op: takes both operand storages, the shape and one
/// stride slice per operand, and returns a new `CudaStorage` or an error.
/// Called from the `(Cuda, Cuda)` arm of the binary dispatch in `Storage`.
fn cuda_impl(
|
||||
lhs: &CudaStorage,
|
||||
rhs: &CudaStorage,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<CudaStorage>;
|
||||
}
|
||||
|
||||
/// Unit type that carries the `BinaryOp` implementation for addition.
struct Add;
|
||||
@ -34,6 +41,15 @@ impl BinaryOp for Add {
|
||||
/// Returns the sum `v1 + v2`.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 + v2
|
||||
}
|
||||
/// CUDA addition: forwards `rhs`, `shape` and both stride slices to `add_impl`
/// called on `lhs`, and returns the storage it produces.
///
/// # Errors
/// Any error coming out of `add_impl` is propagated through `?`.
fn cuda_impl(
|
||||
lhs: &CudaStorage,
|
||||
rhs: &CudaStorage,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
// NOTE(review): `Ok(expr?)` is only needed when `?` converts the error type; if
// `add_impl` already returns this `Result`, the wrapper is redundant — confirm.
Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Sub {
|
||||
@ -44,6 +60,15 @@ impl BinaryOp for Sub {
|
||||
/// Returns the difference `v1 - v2`.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 - v2
|
||||
}
|
||||
/// CUDA subtraction: not implemented yet. All arguments are ignored.
///
/// # Panics
/// Always panics via `todo!()`; no `Err` is ever returned.
fn cuda_impl(
|
||||
_: &CudaStorage,
|
||||
_: &CudaStorage,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Mul {
|
||||
@ -54,6 +79,15 @@ impl BinaryOp for Mul {
|
||||
/// Returns the product `v1 * v2`.
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 * v2
|
||||
}
|
||||
/// CUDA multiplication: not implemented yet. All arguments are ignored.
///
/// # Panics
/// Always panics via `todo!()`; no `Err` is ever returned.
fn cuda_impl(
|
||||
_: &CudaStorage,
|
||||
_: &CudaStorage,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOp for Div {
|
||||
@ -64,6 +98,15 @@ impl BinaryOp for Div {
|
||||
/// Returns the quotient `v1 / v2` (plain `f64` division, no zero check).
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
v1 / v2
|
||||
}
|
||||
/// CUDA division: not implemented yet. All arguments are ignored.
///
/// # Panics
/// Always panics via `todo!()`; no `Err` is ever returned.
fn cuda_impl(
|
||||
_: &CudaStorage,
|
||||
_: &CudaStorage,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<CudaStorage> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Neg {
|
||||
@ -177,7 +220,10 @@ impl Storage {
|
||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
|
Reference in New Issue
Block a user