mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Remove one level of indirection for the binary and unary ops.
This commit is contained in:
@ -43,10 +43,12 @@ impl std::fmt::Debug for Tensor {
|
||||
}
|
||||
|
||||
macro_rules! unary_op {
|
||||
($fn_name:ident, $op_name:ident, $impl_name:ident) => {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.$impl_name(self.shape(), self.stride())?;
|
||||
let storage = self
|
||||
.storage
|
||||
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
@ -61,12 +63,15 @@ macro_rules! unary_op {
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $op_name:ident, $impl_name:ident) => {
|
||||
($fn_name:ident, $op_name:ident) => {
|
||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||
let storage =
|
||||
self.storage
|
||||
.$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
||||
let storage = self.storage.binary_impl::<crate::op::$op_name>(
|
||||
&rhs.storage,
|
||||
shape,
|
||||
self.stride(),
|
||||
rhs.stride(),
|
||||
)?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
@ -211,14 +216,14 @@ impl Tensor {
|
||||
|
||||
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
|
||||
// if this can create cycles in the compute graph.
|
||||
binary_op!(add, Add, add_impl);
|
||||
binary_op!(mul, Mul, mul_impl);
|
||||
binary_op!(sub, Sub, sub_impl);
|
||||
binary_op!(div, Div, div_impl);
|
||||
binary_op!(add, Add);
|
||||
binary_op!(mul, Mul);
|
||||
binary_op!(sub, Sub);
|
||||
binary_op!(div, Div);
|
||||
|
||||
unary_op!(neg, Neg, neg_impl);
|
||||
unary_op!(sqr, Sqr, sqr_impl);
|
||||
unary_op!(sqrt, Sqrt, sqrt_impl);
|
||||
unary_op!(neg, Neg);
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
if self.rank() != 0 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
|
Reference in New Issue
Block a user