Op refactor (#208)

* Add the binary and unary op enums to factorize some code.

* Bugfix.
Laurent Mazare
2023-07-20 13:28:45 +02:00
committed by GitHub
parent e9c052bf94
commit 2a8f28d687
8 changed files with 81 additions and 76 deletions
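For orientation, a minimal, self-contained sketch of the factoring the commit message describes (the types below are illustrative stand-ins, not candle's actual definitions): the per-operation Op variants collapse into Op::Binary / Op::Unary variants that carry the concrete operation as an enum value, so code that only cares about arity matches a single pattern.

```rust
// Illustrative sketch only; a unit struct stands in for candle's Tensor.
#[derive(Debug, Clone)]
struct Tensor;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinaryOp { Add, Mul, Sub, Div }

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnaryOp { Exp, Log, Sin, Cos, Abs, Neg, Sqr, Sqrt, Gelu, Relu }

// Before the refactor the graph enum carried one variant per operation
// (Op::Add, Op::Mul, ..., Op::Exp, Op::Log, ...); afterwards a single
// Binary/Unary variant carries the concrete op as data.
enum Op {
    Binary(Tensor, Tensor, BinaryOp),
    Unary(Tensor, UnaryOp),
    Matmul(Tensor, Tensor),
    // ... the remaining variants are unchanged
}

// Code that only cares about operand count (e.g. the graph walk in
// backprop) now matches one pattern per arity instead of one per op.
fn operands(op: &Op) -> Vec<&Tensor> {
    match op {
        Op::Binary(lhs, rhs, _) | Op::Matmul(lhs, rhs) => vec![lhs, rhs],
        Op::Unary(arg, _) => vec![arg],
    }
}

fn main() {
    println!("{}", operands(&Op::Binary(Tensor, Tensor, BinaryOp::Mul)).len());
    println!("{}", operands(&Op::Unary(Tensor, UnaryOp::Sqrt)).len());
}
```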

View File

@@ -1,4 +1,4 @@
use crate::op::{CmpOp, ReduceOp};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub(crate) trait BackendStorage: Sized {
@@ -25,10 +25,9 @@ pub(crate) trait BackendStorage: Sized {
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
-> Result<Self>;
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;

View File

@@ -1,4 +1,4 @@
use crate::op::{Op, ReduceOp};
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@@ -39,10 +39,7 @@ impl Tensor {
kernel: rhs,
..
}
| Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs)
| Op::Binary(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -74,17 +71,8 @@ impl Tensor {
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Sqr(node)
| Op::Sqrt(node)
| Op::Gelu(node)
| Op::Relu(node)
| Op::Elu(node, _)
| Op::Exp(node)
| Op::Log(node)
| Op::Sin(node)
| Op::Cos(node)
| Op::Abs(node)
| Op::Neg(node) => {
| Op::Unary(node, _)
| Op::Elu(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@@ -118,19 +106,19 @@ impl Tensor {
// this is out of scope.
if let Some(op) = node.op() {
match op {
Op::Add(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
}
Op::Sub(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Sub) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
}
Op::Mul(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Mul) => {
let lhs_grad = grad.mul(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -138,7 +126,7 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Div(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Div) => {
let lhs_grad = grad.div(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -221,24 +209,26 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Log(arg) => {
Op::Unary(arg, UnaryOp::Log) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
}
Op::Sin(arg) => {
Op::Unary(arg, UnaryOp::Sin) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
}
Op::Cos(arg) => {
Op::Unary(arg, UnaryOp::Cos) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
Op::Exp(arg) => {
Op::Unary(_, UnaryOp::Abs) => {
return Err(Error::BackwardNotSupported { op: "abs" })
}
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad / arg)?)?
}
Op::Neg(arg) => {
Op::Unary(arg, UnaryOp::Neg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
@@ -276,15 +266,19 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
Op::Unary(_, UnaryOp::Gelu) => {
return Err(Error::BackwardNotSupported { op: "gelu" })
}
Op::Unary(_, UnaryOp::Relu) => {
return Err(Error::BackwardNotSupported { op: "relu" })
}
Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
Op::Sqr(arg) => {
Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqrt(arg) => {
Op::Unary(arg, UnaryOp::Sqrt) => {
let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
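As a sanity check on the Op::Binary match arms above: the gradient rules they encode are the standard ones. Restated for plain scalars (illustrative only, not candle code), with g the gradient flowing back into the op's output:

```rust
// Scalar restatement of the Add/Sub/Mul gradient rules from the match arms
// above; candle accumulates these into per-tensor gradient sums instead.
fn add_grads(g: f64) -> (f64, f64) {
    // d(l + r)/dl = 1 and d(l + r)/dr = 1
    (g, g)
}

fn sub_grads(g: f64) -> (f64, f64) {
    // d(l - r)/dl = 1 and d(l - r)/dr = -1
    (g, -g)
}

fn mul_grads(g: f64, l: f64, r: f64) -> (f64, f64) {
    // d(l * r)/dl = r and d(l * r)/dr = l
    (g * r, g * l)
}

fn main() {
    assert_eq!(add_grads(2.0), (2.0, 2.0));
    assert_eq!(sub_grads(2.0), (2.0, -2.0));
    assert_eq!(mul_grads(2.0, 3.0, 4.0), (8.0, 6.0));
}
```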

View File

@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
@@ -1158,7 +1158,7 @@ impl BackendStorage for CpuStorage {
}
}
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
match self {
Self::BF16(storage) => {
if B::BF16_VEC {
@@ -1207,7 +1207,12 @@ impl BackendStorage for CpuStorage {
}
}
fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
match (self, rhs) {
(Self::BF16(lhs), Self::BF16(rhs)) => {
let data = if B::BF16_VEC {

View File

@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{CmpOp, ReduceOp};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -573,7 +573,7 @@ impl<'a> Map1 for FastReduce<'a> {
}
}
impl<U: crate::op::UnaryOp> Map1 for U {
impl<U: UnaryOpT> Map1 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@@ -716,7 +716,7 @@ impl<'a> Map2 for WhereCond<'a> {
}
}
impl<U: crate::op::BinaryOp> Map2 for U {
impl<U: crate::op::BinaryOpT> Map2 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
lhs: &CudaSlice<T>,
@@ -976,13 +976,13 @@ impl BackendStorage for CudaStorage {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}
fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
fn binary_impl<B: crate::op::BinaryOp>(
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,

View File

@@ -1,5 +1,5 @@
#![allow(dead_code)]
use crate::op::{CmpOp, ReduceOp};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
@@ -57,16 +57,11 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn binary_impl<B: crate::op::BinaryOp>(
&self,
_: &Self,
_: &Layout,
_: &Layout,
) -> Result<Self> {
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@@ -19,12 +19,34 @@ pub enum ReduceOp {
Max,
}
// These ops return the same type as their input type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
Add,
Mul,
Sub,
Div,
}
// Unary ops with no argument
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
Exp,
Log,
Sin,
Cos,
Abs,
Neg,
Sqr,
Sqrt,
Gelu,
Relu,
}
#[derive(Clone)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
Binary(Tensor, Tensor, BinaryOp),
Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
@@ -49,26 +71,16 @@ pub(crate) enum Op {
},
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),
Log(Tensor),
Sin(Tensor),
Cos(Tensor),
Abs(Tensor),
Narrow(Tensor, usize, usize, usize),
Neg(Tensor),
Reshape(Tensor),
Softmax(Tensor, usize),
Sqr(Tensor),
Sqrt(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Gelu(Tensor),
Relu(Tensor),
Elu(Tensor, f64),
// TODO: Support for custom ops.
}
pub(crate) trait UnaryOp {
pub(crate) trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@@ -91,7 +103,7 @@ pub(crate) trait UnaryOp {
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub(crate) trait BinaryOp {
pub(crate) trait BinaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@@ -133,7 +145,7 @@ pub(crate) struct Relu;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
impl BinaryOp for $op {
impl BinaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
@@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@@ -219,7 +231,7 @@ macro_rules! unary_op {
};
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
impl UnaryOp for $op {
impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOp for Gelu {
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
#[inline(always)]
@@ -343,7 +355,7 @@ impl UnaryOp for Gelu {
}
}
impl UnaryOp for Relu {
impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;
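The UnaryOp/BinaryOp to UnaryOpT/BinaryOpT rename frees the shorter names for the new enums while keeping the traits that the backends are generic over. A rough sketch of that trait pattern, reduced to f64 only (candle's real traits also carry kernel names and vectorized per-dtype paths):

```rust
// Simplified op-trait pattern; not candle's actual trait definitions.
trait UnaryOpT {
    const NAME: &'static str;
    fn f64(v: f64) -> f64;
}

struct Sqr;
struct Sqrt;

impl UnaryOpT for Sqr {
    const NAME: &'static str = "sqr";
    fn f64(v: f64) -> f64 {
        v * v
    }
}

impl UnaryOpT for Sqrt {
    const NAME: &'static str = "sqrt";
    fn f64(v: f64) -> f64 {
        v.sqrt()
    }
}

// A backend needs a single generic entry point rather than one method per
// operation; this mirrors the generic unary_impl in the storage code above.
fn unary_impl<B: UnaryOpT>(xs: &[f64]) -> Vec<f64> {
    xs.iter().copied().map(B::f64).collect()
}

fn main() {
    let xs = [1.0, 4.0, 9.0];
    println!("{}: {:?}", Sqr::NAME, unary_impl::<Sqr>(&xs));
    println!("{}: {:?}", Sqrt::NAME, unary_impl::<Sqrt>(&xs));
}
```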

View File

@@ -147,7 +147,7 @@ impl Storage {
}
}
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
@@ -161,7 +161,7 @@ impl Storage {
}
}
pub(crate) fn binary_impl<B: op::BinaryOp>(
pub(crate) fn binary_impl<B: op::BinaryOpT>(
&self,
rhs: &Self,
lhs_layout: &Layout,

View File

@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{CmpOp, Op, ReduceOp};
use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -80,7 +80,7 @@ macro_rules! unary_op {
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
let op = if self.track_op() {
Some(Op::$op_name(self.clone()))
Some(Op::Unary(self.clone(), UnaryOp::$op_name))
} else {
None
};
@@ -99,7 +99,7 @@ macro_rules! binary_op {
rhs.layout(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::$op_name(self.clone(), rhs.clone()))
Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
} else {
None
};
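Finally, a rough sketch (toy Tensor and macro, hypothetical names, not candle's actual ones) of the pattern the unary_op! change in tensor.rs relies on: the generated method now records a single enum-parameterized variant instead of a dedicated Op variant per operation.

```rust
#[derive(Debug, Clone, Copy)]
enum UnaryOp {
    Sqr,
    Sqrt,
}

#[derive(Debug, Clone)]
struct Tensor {
    data: Vec<f64>,
    // candle stores a full Op graph node; a single enum is enough for the sketch.
    op: Option<UnaryOp>,
}

// Toy stand-in for tensor.rs's unary_op! macro: one generated method per op,
// each recording Some(UnaryOp::$op_name) rather than a per-op Op variant.
macro_rules! unary_method {
    ($fn_name:ident, $op_name:ident, $e:expr) => {
        impl Tensor {
            pub fn $fn_name(&self) -> Tensor {
                Tensor {
                    data: self.data.iter().map($e).collect(),
                    op: Some(UnaryOp::$op_name),
                }
            }
        }
    };
}

unary_method!(sqr, Sqr, |v| v * v);
unary_method!(sqrt, Sqrt, |v| v.sqrt());

fn main() {
    let t = Tensor { data: vec![4.0, 9.0], op: None };
    println!("{:?}", t.sqrt());
}
```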