Op refactor (#208)
* Add the binary and unary op enums to factorize some code.
* Bugfix.
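The commit folds the per-operation Op variants (Op::Add, Op::Mul, Op::Exp, ...) into two payload-carrying variants, Op::Binary and Op::Unary, backed by small BinaryOp / UnaryOp enums, and renames the existing traits to BinaryOpT / UnaryOpT. A self-contained sketch of why this factorizes code follows; the types below are illustrative stand-ins, not candle's actual definitions:

// Stand-in tensor type; only the shape of the enums matters for this sketch.
#[derive(Debug)]
#[allow(dead_code)]
struct Tensor(u32);

#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
enum BinaryOp { Add, Mul, Sub, Div }

#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
enum UnaryOp { Exp, Log, Neg, Sqr }

#[allow(dead_code)]
enum Op {
    // One variant per *shape* of operation, with the concrete op as payload.
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    Matmul(Tensor, Tensor),
}

// Code that only cares about the shape (like the graph walk in the diff)
// now needs one match arm per shape instead of one arm per op.
fn parents(op: &Op) -> Vec<&Tensor> {
    match op {
        Op::Binary(lhs, rhs, _) | Op::Matmul(lhs, rhs) => vec![lhs, rhs],
        Op::Unary(node, _) => vec![node],
    }
}

fn main() {
    let op = Op::Binary(Tensor(0), Tensor(1), BinaryOp::Add);
    assert_eq!(parents(&op).len(), 2);
    println!("ok");
}
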
@@ -1,4 +1,4 @@
-use crate::op::{CmpOp, ReduceOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 
 pub(crate) trait BackendStorage: Sized {
@@ -25,10 +25,9 @@ pub(crate) trait BackendStorage: Sized {
 
     fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
 
-    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
+    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
 
-    fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
-        -> Result<Self>;
+    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
 
     fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
 
@@ -1,4 +1,4 @@
-use crate::op::{Op, ReduceOp};
+use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
@@ -39,10 +39,7 @@ impl Tensor {
                         kernel: rhs,
                         ..
                     }
-                    | Op::Add(lhs, rhs)
-                    | Op::Mul(lhs, rhs)
-                    | Op::Sub(lhs, rhs)
-                    | Op::Div(lhs, rhs)
+                    | Op::Binary(lhs, rhs, _)
                     | Op::Embedding(lhs, rhs)
                     | Op::Matmul(lhs, rhs) => {
                         let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -74,17 +71,8 @@ impl Tensor {
                     | Op::Transpose(node, _, _)
                     | Op::Narrow(node, _, _, _)
                     | Op::Softmax(node, _)
-                    | Op::Sqr(node)
-                    | Op::Sqrt(node)
-                    | Op::Gelu(node)
-                    | Op::Relu(node)
-                    | Op::Elu(node, _)
-                    | Op::Exp(node)
-                    | Op::Log(node)
-                    | Op::Sin(node)
-                    | Op::Cos(node)
-                    | Op::Abs(node)
-                    | Op::Neg(node) => {
+                    | Op::Unary(node, _)
+                    | Op::Elu(node, _) => {
                         let (tg, nodes) = walk(node, nodes, already_seen);
                         track_grad |= tg;
                         nodes
@@ -118,19 +106,19 @@ impl Tensor {
             // this is out of scope.
             if let Some(op) = node.op() {
                 match op {
-                    Op::Add(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                     }
-                    Op::Sub(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                     }
-                    Op::Mul(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                         let lhs_grad = grad.mul(rhs)?;
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -138,7 +126,7 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::Div(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
                         let lhs_grad = grad.div(rhs)?;
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -221,24 +209,26 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Log(arg) => {
+                    Op::Unary(arg, UnaryOp::Log) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad * *node)?)?
                     }
-                    Op::Sin(arg) => {
+                    Op::Unary(arg, UnaryOp::Sin) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                     }
-                    Op::Cos(arg) => {
+                    Op::Unary(arg, UnaryOp::Cos) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                     }
-                    Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
-                    Op::Exp(arg) => {
+                    Op::Unary(_, UnaryOp::Abs) => {
+                        return Err(Error::BackwardNotSupported { op: "abs" })
+                    }
+                    Op::Unary(arg, UnaryOp::Exp) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad / arg)?)?
                     }
-                    Op::Neg(arg) => {
+                    Op::Unary(arg, UnaryOp::Neg) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&grad)?
                     }
@@ -276,15 +266,19 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
-                    Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
+                    Op::Unary(_, UnaryOp::Gelu) => {
+                        return Err(Error::BackwardNotSupported { op: "gelu" })
+                    }
+                    Op::Unary(_, UnaryOp::Relu) => {
+                        return Err(Error::BackwardNotSupported { op: "relu" })
+                    }
                     Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
-                    Op::Sqr(arg) => {
+                    Op::Unary(arg, UnaryOp::Sqr) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Sqrt(arg) => {
+                    Op::Unary(arg, UnaryOp::Sqrt) => {
                         let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
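The backward pass above now selects each gradient rule from the Op::Unary / Op::Binary payload instead of from a dedicated Op variant per operation. A scalar-level sketch of that dispatch, ignoring candle's tensor types and gradient accumulation (the helper name is made up for illustration):

#[derive(Clone, Copy)]
enum UnaryOp {
    Neg,
    Sqr,
    Sin,
}

// d(op(x))/dx multiplied by the upstream gradient, for a scalar input x.
fn unary_backward(op: UnaryOp, x: f64, grad: f64) -> f64 {
    match op {
        UnaryOp::Neg => -grad,
        UnaryOp::Sqr => 2.0 * x * grad,
        UnaryOp::Sin => x.cos() * grad,
    }
}

fn main() {
    // d(x^2)/dx at x = 3 is 6.
    assert_eq!(unary_backward(UnaryOp::Sqr, 3.0, 1.0), 6.0);
    println!("ok");
}
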
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 
@@ -1158,7 +1158,7 @@ impl BackendStorage for CpuStorage {
         }
     }
 
-    fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::BF16(storage) => {
                 if B::BF16_VEC {
@@ -1207,7 +1207,12 @@ impl BackendStorage for CpuStorage {
         }
     }
 
-    fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+    fn binary_impl<B: BinaryOpT>(
+        &self,
+        rhs: &Self,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+    ) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
                 let data = if B::BF16_VEC {
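In the CPU backend hunks above, unary_impl and binary_impl stay generic over the renamed UnaryOpT / BinaryOpT traits, so the per-dtype dispatch is written once per shape of op. A reduced sketch of that structure with stand-in types (only f32/f64 shown, no layouts or vectorized paths):

trait UnaryOpT {
    const NAME: &'static str;
    fn f32(v: f32) -> f32;
    fn f64(v: f64) -> f64;
}

struct Sqr;
impl UnaryOpT for Sqr {
    const NAME: &'static str = "sqr";
    fn f32(v: f32) -> f32 { v * v }
    fn f64(v: f64) -> f64 { v * v }
}

enum CpuStorage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

impl CpuStorage {
    // One generic function covers every unary op; the match is over dtypes,
    // not over operations.
    fn unary_impl<B: UnaryOpT>(&self) -> CpuStorage {
        match self {
            CpuStorage::F32(xs) => CpuStorage::F32(xs.iter().map(|&v| B::f32(v)).collect()),
            CpuStorage::F64(xs) => CpuStorage::F64(xs.iter().map(|&v| B::f64(v)).collect()),
        }
    }
}

fn main() {
    let s = CpuStorage::F32(vec![1.0, 2.0, 3.0]);
    if let CpuStorage::F32(ys) = s.unary_impl::<Sqr>() {
        println!("{} -> {:?}", Sqr::NAME, ys);
    }
}
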
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{CmpOp, ReduceOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -573,7 +573,7 @@ impl<'a> Map1 for FastReduce<'a> {
     }
 }
 
-impl<U: crate::op::UnaryOp> Map1 for U {
+impl<U: UnaryOpT> Map1 for U {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
@@ -716,7 +716,7 @@ impl<'a> Map2 for WhereCond<'a> {
     }
 }
 
-impl<U: crate::op::BinaryOp> Map2 for U {
+impl<U: crate::op::BinaryOpT> Map2 for U {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         lhs: &CudaSlice<T>,
@@ -976,13 +976,13 @@ impl BackendStorage for CudaStorage {
         Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
     }
 
-    fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+    fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = U::V.map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
-    fn binary_impl<B: crate::op::BinaryOp>(
+    fn binary_impl<B: BinaryOpT>(
         &self,
         rhs: &Self,
         lhs_l: &Layout,
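On the CUDA side, the blanket impls (impl<U: UnaryOpT> Map1 for U and the Map2 counterpart) are what let one implementation serve every op: each op type carries its kernel name as an associated const. A toy illustration with the launch machinery stripped out; Map1 here is a simplified stand-in, not candle's or cudarc's trait:

trait UnaryOpT {
    const KERNEL: &'static str;
}

struct Exp;
impl UnaryOpT for Exp {
    // Unary kernels use a "u" prefix, as in the macros in the diff.
    const KERNEL: &'static str = "uexp";
}

struct Neg;
impl UnaryOpT for Neg {
    const KERNEL: &'static str = "uneg";
}

trait Map1 {
    fn kernel_name(&self) -> &'static str;
}

// One blanket impl instead of one Map1 impl per unary op.
impl<U: UnaryOpT> Map1 for U {
    fn kernel_name(&self) -> &'static str {
        U::KERNEL
    }
}

fn main() {
    println!("{} {}", Exp.kernel_name(), Neg.kernel_name());
}
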
@@ -1,5 +1,5 @@
 #![allow(dead_code)]
-use crate::op::{CmpOp, ReduceOp};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
 
 #[derive(Debug, Clone)]
@@ -57,16 +57,11 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
+    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn binary_impl<B: crate::op::BinaryOp>(
-        &self,
-        _: &Self,
-        _: &Layout,
-        _: &Layout,
-    ) -> Result<Self> {
+    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
@@ -19,12 +19,34 @@ pub enum ReduceOp {
     Max,
 }
 
+// These ops return the same type as their input type.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BinaryOp {
+    Add,
+    Mul,
+    Sub,
+    Div,
+}
+
+// Unary ops with no argument
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum UnaryOp {
+    Exp,
+    Log,
+    Sin,
+    Cos,
+    Abs,
+    Neg,
+    Sqr,
+    Sqrt,
+    Gelu,
+    Relu,
+}
+
 #[derive(Clone)]
 pub(crate) enum Op {
-    Add(Tensor, Tensor),
-    Mul(Tensor, Tensor),
-    Sub(Tensor, Tensor),
-    Div(Tensor, Tensor),
+    Binary(Tensor, Tensor, BinaryOp),
+    Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
@@ -49,26 +71,16 @@ pub(crate) enum Op {
     },
     ToDType(Tensor),
     Broadcast(Tensor),
-    Exp(Tensor),
-    Log(Tensor),
-    Sin(Tensor),
-    Cos(Tensor),
-    Abs(Tensor),
     Narrow(Tensor, usize, usize, usize),
-    Neg(Tensor),
     Reshape(Tensor),
     Softmax(Tensor, usize),
-    Sqr(Tensor),
-    Sqrt(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
-    Gelu(Tensor),
-    Relu(Tensor),
     Elu(Tensor, f64),
     // TODO: Support for custom ops.
 }
 
-pub(crate) trait UnaryOp {
+pub(crate) trait UnaryOpT {
     const NAME: &'static str;
     const KERNEL: &'static str;
     const V: Self;
@@ -91,7 +103,7 @@ pub(crate) trait UnaryOp {
     fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
 }
 
-pub(crate) trait BinaryOp {
+pub(crate) trait BinaryOpT {
     const NAME: &'static str;
     const KERNEL: &'static str;
     const V: Self;
@@ -133,7 +145,7 @@ pub(crate) struct Relu;
 
 macro_rules! bin_op {
     ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
-        impl BinaryOp for $op {
+        impl BinaryOpT for $op {
             const NAME: &'static str = $name;
             const KERNEL: &'static str = concat!("b", $name);
             const V: Self = $op;
@@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
 
 macro_rules! unary_op {
     ($op: ident, $name: literal, $a: ident, $e: expr) => {
-        impl UnaryOp for $op {
+        impl UnaryOpT for $op {
             const NAME: &'static str = $name;
             const KERNEL: &'static str = concat!("u", $name);
             const V: Self = $op;
@@ -219,7 +231,7 @@ macro_rules! unary_op {
     };
 
     ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
-        impl UnaryOp for $op {
+        impl UnaryOpT for $op {
             const NAME: &'static str = $name;
             const KERNEL: &'static str = concat!("u", $name);
             const V: Self = $op;
@@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
 
 /// `gelu` operation
 /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
-impl UnaryOp for Gelu {
+impl UnaryOpT for Gelu {
     const NAME: &'static str = "gelu";
     const V: Self = Gelu;
     #[inline(always)]
@@ -343,7 +355,7 @@ impl UnaryOp for Gelu {
     }
 }
 
-impl UnaryOp for Relu {
+impl UnaryOpT for Relu {
     const NAME: &'static str = "relu";
     const KERNEL: &'static str = "urelu";
     const V: Self = Relu;
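The op definitions above rename the traits to UnaryOpT / BinaryOpT, freeing the UnaryOp / BinaryOp names for the new enums, while the marker structs and their trait impls are still generated by the bin_op! / unary_op! macros. A compressed sketch of that macro pattern; the macro body is illustrative rather than candle's exact expansion:

trait BinaryOpT {
    const NAME: &'static str;
    const KERNEL: &'static str;
    fn f64(v1: f64, v2: f64) -> f64;
}

struct Add;
struct Mul;

macro_rules! bin_op {
    ($op:ident, $name:literal, $e:expr) => {
        impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            // Binary kernels get a "b" prefix, matching concat!("b", $name) above.
            const KERNEL: &'static str = concat!("b", $name);
            fn f64(v1: f64, v2: f64) -> f64 {
                $e(v1, v2)
            }
        }
    };
}

bin_op!(Add, "add", |v1, v2| v1 + v2);
bin_op!(Mul, "mul", |v1, v2| v1 * v2);

fn main() {
    println!("{} ({}) = {}", Add::NAME, Add::KERNEL, Add::f64(2.0, 3.0));
    println!("{} ({}) = {}", Mul::NAME, Mul::KERNEL, Mul::f64(2.0, 3.0));
}
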
@@ -147,7 +147,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
             Storage::Cpu(storage) => {
@@ -161,7 +161,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn binary_impl<B: op::BinaryOp>(
+    pub(crate) fn binary_impl<B: op::BinaryOpT>(
         &self,
         rhs: &Self,
         lhs_layout: &Layout,
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{CmpOp, Op, ReduceOp};
+use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -80,7 +80,7 @@ macro_rules! unary_op {
                 .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
             let op = if self.track_op() {
-                Some(Op::$op_name(self.clone()))
+                Some(Op::Unary(self.clone(), UnaryOp::$op_name))
             } else {
                 None
             };
@@ -99,7 +99,7 @@ macro_rules! binary_op {
                 rhs.layout(),
             )?;
             let op = if self.track_op() || rhs.track_op() {
-                Some(Op::$op_name(self.clone(), rhs.clone()))
+                Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
             } else {
                 None
             };
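Finally, the tensor-side macros no longer need a dedicated Op variant per generated method: they record Op::Unary(..., UnaryOp::$op_name) or Op::Binary(..., ..., BinaryOp::$op_name) on the result. A self-contained miniature of that pattern with stand-in types (not candle's API):

#[derive(Clone, Debug)]
struct Tensor {
    data: Vec<f64>,
    op: Option<Box<Op>>,
}

#[derive(Clone, Copy, Debug)]
enum UnaryOp {
    Sqr,
    Sqrt,
}

#[derive(Clone, Debug)]
enum Op {
    Unary(Tensor, UnaryOp),
}

macro_rules! unary_method {
    ($fn_name:ident, $op_name:ident, $e:expr) => {
        impl Tensor {
            fn $fn_name(&self) -> Tensor {
                Tensor {
                    data: self.data.iter().map(|&v| $e(v)).collect(),
                    // Record which unary op produced this tensor so a backward
                    // pass can match on the enum payload later.
                    op: Some(Box::new(Op::Unary(self.clone(), UnaryOp::$op_name))),
                }
            }
        }
    };
}

unary_method!(sqr, Sqr, |v: f64| v * v);
unary_method!(sqrt, Sqrt, f64::sqrt);

fn main() {
    let t = Tensor { data: vec![4.0, 9.0], op: None };
    println!("{:?} {:?}", t.sqr().data, t.sqrt().data);
}
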