Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Remove one level of indirection for the binary and unary ops.
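The change collapses the call chain Tensor op -> Storage::{add,sub,mul,div,neg,...}_impl -> Storage::binary_impl/unary_impl into a direct generic call, moving the op traits and types from src/storage.rs into src/op.rs. Below is a minimal, self-contained sketch of the indirection being removed; `Storage` and the signatures are toy stand-ins, not the crate's real types.

```rust
// Toy sketch of the refactor: the per-op wrapper layer disappears and
// callers invoke the generic entry point directly.
trait BinaryOp {
    fn f32(v1: f32, v2: f32) -> f32;
}

struct Add;

impl BinaryOp for Add {
    fn f32(v1: f32, v2: f32) -> f32 {
        v1 + v2
    }
}

struct Storage(Vec<f32>);

impl Storage {
    // The generic entry point, which the commit keeps.
    fn binary_impl<B: BinaryOp>(&self, rhs: &Self) -> Storage {
        Storage(
            self.0
                .iter()
                .zip(rhs.0.iter())
                .map(|(a, b)| B::f32(*a, *b))
                .collect(),
        )
    }

    // The per-op wrapper layer that the commit deletes: callers now
    // write `binary_impl::<Add>` directly instead of `add_impl`.
    fn add_impl(&self, rhs: &Self) -> Storage {
        self.binary_impl::<Add>(rhs)
    }
}

fn main() {
    let (a, b) = (Storage(vec![1.0, 2.0]), Storage(vec![3.0, 4.0]));
    let before = a.add_impl(&b); // old call path: one extra hop
    let after = a.binary_impl::<Add>(&b); // new call path: direct
    assert_eq!(before.0, after.0);
}
```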
@@ -1,4 +1,4 @@
-use crate::storage::{BinaryOp, UnaryOp};
+use crate::op::{BinaryOp, UnaryOp};
 use crate::{DType, Error, Result, Shape, StridedIndex};
 use gemm::{gemm, Parallelism};
@@ -164,7 +164,7 @@ impl CudaStorage
         }
     }

-    pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
+    pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
         &self,
         shape: &Shape,
         stride: &[usize],
@@ -198,7 +198,7 @@ impl CudaStorage
         }
     }

-    pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
+    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         rhs: &Self,
         shape: &Shape,
@@ -54,15 +54,11 @@ impl CudaStorage
         Err(Error::NotCompiledWithCudaSupport)
     }

-    pub(crate) fn unary_impl<B: crate::storage::UnaryOp>(
-        &self,
-        _: &Shape,
-        _: &[usize],
-    ) -> Result<Self> {
+    pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
+    pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         _: &Self,
         _: &Shape,
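The hunk above touches the stub backend used when CUDA support is compiled out: the same method signatures exist, but every call fails. A rough, self-contained sketch of that pattern, with toy names and simplified signatures, assuming the real backend is feature-gated:

```rust
// Sketch of a compiled-out backend: the generic parameter is kept so
// call sites typecheck, but the body only reports the missing feature.
#[derive(Debug)]
enum Error {
    NotCompiledWithCudaSupport,
}

type Result<T> = std::result::Result<T, Error>;

trait UnaryOp {
    fn f32(v1: f32) -> f32;
}

#[derive(Debug)]
struct CudaStorage;

impl CudaStorage {
    fn unary_impl<U: UnaryOp>(&self) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
}

fn main() {
    struct Sqr;
    impl UnaryOp for Sqr {
        fn f32(v1: f32) -> f32 {
            v1 * v1
        }
    }
    let err = CudaStorage.unary_impl::<Sqr>().unwrap_err();
    println!("{err:?}"); // NotCompiledWithCudaSupport
}
```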
112 src/op.rs
@@ -18,3 +18,115 @@ pub(crate) enum Op {
     Sqrt(Tensor),
     // TODO: Support for custom ops.
 }
+
+pub(crate) trait UnaryOp {
+    const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
+    const KERNEL_F32: &'static str;
+    const KERNEL_F64: &'static str;
+    fn f32(v1: f32) -> f32;
+    fn f64(v1: f64) -> f64;
+}
+
+pub(crate) trait BinaryOp {
+    const NAME: &'static str;
+    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
+    // contiguous case separately as it's easy to optimize things out there.
+    const KERNEL_F32: &'static str;
+    const KERNEL_F64: &'static str;
+    fn f32(v1: f32, v2: f32) -> f32;
+    fn f64(v1: f64, v2: f64) -> f64;
+}
+
+pub(crate) struct Add;
+pub(crate) struct Div;
+pub(crate) struct Mul;
+pub(crate) struct Sub;
+pub(crate) struct Neg;
+pub(crate) struct Sqr;
+pub(crate) struct Sqrt;
+
+impl BinaryOp for Add {
+    const NAME: &'static str = "add";
+    const KERNEL_F32: &'static str = "badd_f32";
+    const KERNEL_F64: &'static str = "badd_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 + v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 + v2
+    }
+}
+
+impl BinaryOp for Sub {
+    const NAME: &'static str = "sub";
+    const KERNEL_F32: &'static str = "bsub_f32";
+    const KERNEL_F64: &'static str = "bsub_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 - v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 - v2
+    }
+}
+
+impl BinaryOp for Mul {
+    const NAME: &'static str = "mul";
+    const KERNEL_F32: &'static str = "bmul_f32";
+    const KERNEL_F64: &'static str = "bmul_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 * v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 * v2
+    }
+}
+
+impl BinaryOp for Div {
+    const NAME: &'static str = "div";
+    const KERNEL_F32: &'static str = "bdiv_f32";
+    const KERNEL_F64: &'static str = "bdiv_f64";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 / v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 / v2
+    }
+}
+
+impl UnaryOp for Neg {
+    const NAME: &'static str = "neg";
+    fn f32(v1: f32) -> f32 {
+        -v1
+    }
+    fn f64(v1: f64) -> f64 {
+        -v1
+    }
+    const KERNEL_F32: &'static str = "uneg_f32";
+    const KERNEL_F64: &'static str = "uneg_f64";
+}
+
+impl UnaryOp for Sqr {
+    const NAME: &'static str = "sqr";
+    fn f32(v1: f32) -> f32 {
+        v1 * v1
+    }
+    fn f64(v1: f64) -> f64 {
+        v1 * v1
+    }
+    const KERNEL_F32: &'static str = "usqr_f32";
+    const KERNEL_F64: &'static str = "usqr_f64";
+}
+
+impl UnaryOp for Sqrt {
+    const NAME: &'static str = "sqrt";
+    fn f32(v1: f32) -> f32 {
+        v1.sqrt()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.sqrt()
+    }
+    const KERNEL_F32: &'static str = "usqrt_f32";
+    const KERNEL_F64: &'static str = "usqrt_f64";
+}
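A self-contained sketch of why each op type carries both scalar functions and kernel names: the `f32`/`f64` functions back the CPU path, while `KERNEL_F32`/`KERNEL_F64` name GPU kernels (e.g. `usqrt_f32`). The two helper functions below are illustrative stand-ins, not the crate's real API.

```rust
// One op definition, two consumers: a CPU map and a (stubbed) kernel
// lookup by name.
trait UnaryOp {
    const KERNEL_F32: &'static str;
    fn f32(v1: f32) -> f32;
}

struct Sqrt;

impl UnaryOp for Sqrt {
    const KERNEL_F32: &'static str = "usqrt_f32";
    fn f32(v1: f32) -> f32 {
        v1.sqrt()
    }
}

// CPU path: apply the scalar function element-wise.
fn cpu_unary<U: UnaryOp>(data: &[f32]) -> Vec<f32> {
    data.iter().copied().map(U::f32).collect()
}

// CUDA path: the real backend would launch the named kernel; this stub
// only reports which kernel would run.
fn cuda_unary<U: UnaryOp>(_data: &[f32]) -> std::result::Result<Vec<f32>, String> {
    Err(format!("would launch kernel `{}`", U::KERNEL_F32))
}

fn main() {
    assert_eq!(cpu_unary::<Sqrt>(&[4.0, 9.0]), vec![2.0, 3.0]);
    println!("{:?}", cuda_unary::<Sqrt>(&[4.0]));
}
```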
174 src/storage.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
+use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};

 #[derive(Debug, Clone)]
 pub enum Storage {
@@ -6,118 +6,6 @@ pub enum Storage {
     Cuda(CudaStorage),
 }

-pub(crate) trait UnaryOp {
-    const NAME: &'static str;
-    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
-    // contiguous case separately as it's easy to optimize things out there.
-    const KERNEL_F32: &'static str;
-    const KERNEL_F64: &'static str;
-    fn f32(v1: f32) -> f32;
-    fn f64(v1: f64) -> f64;
-}
-
-pub(crate) trait BinaryOp {
-    const NAME: &'static str;
-    // TODO: These kernels are compatible with arbitrary strides. We should also consider the
-    // contiguous case separately as it's easy to optimize things out there.
-    const KERNEL_F32: &'static str;
-    const KERNEL_F64: &'static str;
-    fn f32(v1: f32, v2: f32) -> f32;
-    fn f64(v1: f64, v2: f64) -> f64;
-}
-
-struct Add;
-struct Div;
-struct Mul;
-struct Sub;
-struct Neg;
-struct Sqr;
-struct Sqrt;
-
-impl BinaryOp for Add {
-    const NAME: &'static str = "add";
-    const KERNEL_F32: &'static str = "badd_f32";
-    const KERNEL_F64: &'static str = "badd_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 + v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 + v2
-    }
-}
-
-impl BinaryOp for Sub {
-    const NAME: &'static str = "sub";
-    const KERNEL_F32: &'static str = "bsub_f32";
-    const KERNEL_F64: &'static str = "bsub_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 - v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 - v2
-    }
-}
-
-impl BinaryOp for Mul {
-    const NAME: &'static str = "mul";
-    const KERNEL_F32: &'static str = "bmul_f32";
-    const KERNEL_F64: &'static str = "bmul_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 * v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 * v2
-    }
-}
-
-impl BinaryOp for Div {
-    const NAME: &'static str = "div";
-    const KERNEL_F32: &'static str = "bdiv_f32";
-    const KERNEL_F64: &'static str = "bdiv_f64";
-    fn f32(v1: f32, v2: f32) -> f32 {
-        v1 / v2
-    }
-    fn f64(v1: f64, v2: f64) -> f64 {
-        v1 / v2
-    }
-}
-
-impl UnaryOp for Neg {
-    const NAME: &'static str = "neg";
-    fn f32(v1: f32) -> f32 {
-        -v1
-    }
-    fn f64(v1: f64) -> f64 {
-        -v1
-    }
-    const KERNEL_F32: &'static str = "uneg_f32";
-    const KERNEL_F64: &'static str = "uneg_f64";
-}
-
-impl UnaryOp for Sqr {
-    const NAME: &'static str = "sqr";
-    fn f32(v1: f32) -> f32 {
-        v1 * v1
-    }
-    fn f64(v1: f64) -> f64 {
-        v1 * v1
-    }
-    const KERNEL_F32: &'static str = "usqr_f32";
-    const KERNEL_F64: &'static str = "usqr_f64";
-}
-
-impl UnaryOp for Sqrt {
-    const NAME: &'static str = "sqrt";
-    fn f32(v1: f32) -> f32 {
-        v1.sqrt()
-    }
-    fn f64(v1: f64) -> f64 {
-        v1.sqrt()
-    }
-    const KERNEL_F32: &'static str = "usqrt_f32";
-    const KERNEL_F64: &'static str = "usqrt_f64";
-}
-
 impl Storage {
     pub fn device(&self) -> Device {
         match self {
@@ -173,7 +61,11 @@ impl Storage {
         }
     }

-    fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+    pub(crate) fn unary_impl<B: op::UnaryOp>(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+    ) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
             Storage::Cpu(storage) => {
@@ -188,7 +80,7 @@ impl Storage {
     }

     // TODO: Support broadcasting?
-    fn binary_impl<B: BinaryOp>(
+    pub(crate) fn binary_impl<B: op::BinaryOp>(
         &self,
         rhs: &Self,
         shape: &Shape,
@@ -218,58 +110,6 @@ impl Storage {
         }
     }

-    pub(crate) fn add_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn sub_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn mul_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn div_impl(
-        &self,
-        rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
-    ) -> Result<Self> {
-        self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
-    }
-
-    pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        self.unary_impl::<Neg>(shape, stride)
-    }
-
-    pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        self.unary_impl::<Sqr>(shape, stride)
-    }
-
-    pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        self.unary_impl::<Sqrt>(shape, stride)
-    }
-
     pub(crate) fn matmul_impl(
         &self,
         rhs: &Self,
src/tensor.rs
@@ -43,10 +43,12 @@ impl std::fmt::Debug for Tensor {
 }

 macro_rules! unary_op {
-    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+    ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self) -> Result<Self> {
             let shape = self.shape();
-            let storage = self.storage.$impl_name(self.shape(), self.stride())?;
+            let storage = self
+                .storage
+                .unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage,
@@ -61,12 +63,15 @@ macro_rules! unary_op {
 }

 macro_rules! binary_op {
-    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+    ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
             let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage =
-                self.storage
-                    .$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
+            let storage = self.storage.binary_impl::<crate::op::$op_name>(
+                &rhs.storage,
+                shape,
+                self.stride(),
+                rhs.stride(),
+            )?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage,
@@ -211,14 +216,14 @@ impl Tensor {

     // TODO: Also make an inplace version or a pre-allocated? This could be tricky
     // if this can create cycles in the compute graph.
-    binary_op!(add, Add, add_impl);
-    binary_op!(mul, Mul, mul_impl);
-    binary_op!(sub, Sub, sub_impl);
-    binary_op!(div, Div, div_impl);
+    binary_op!(add, Add);
+    binary_op!(mul, Mul);
+    binary_op!(sub, Sub);
+    binary_op!(div, Div);

-    unary_op!(neg, Neg, neg_impl);
-    unary_op!(sqr, Sqr, sqr_impl);
-    unary_op!(sqrt, Sqrt, sqrt_impl);
+    unary_op!(neg, Neg);
+    unary_op!(sqr, Sqr);
+    unary_op!(sqrt, Sqrt);
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
             return Err(Error::UnexpectedNumberOfDims {
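For reference, a toy sketch of the macro simplification above: with the op type passed as a generic parameter to the generic impl, the third `$impl_name` argument is no longer needed. `Tensor` and `unary_impl` here are simplified stand-ins for the crate's types.

```rust
// The macro now takes only the method name and the op type, and
// hard-codes the call to the generic `unary_impl`.
trait UnaryOp {
    fn f32(v1: f32) -> f32;
}

struct Sqrt;

impl UnaryOp for Sqrt {
    fn f32(v1: f32) -> f32 {
        v1.sqrt()
    }
}

struct Tensor(Vec<f32>);

impl Tensor {
    fn unary_impl<U: UnaryOp>(&self) -> Tensor {
        Tensor(self.0.iter().copied().map(U::f32).collect())
    }
}

macro_rules! unary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self) -> Tensor {
            self.unary_impl::<$op_name>()
        }
    };
}

impl Tensor {
    unary_op!(sqrt, Sqrt);
}

fn main() {
    let t = Tensor(vec![4.0, 16.0]);
    assert_eq!(t.sqrt().0, vec![2.0, 4.0]);
}
```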