mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Op refactor (#208)
* Add the binary and unary op enums to factorize some code. * Bugfix.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::op::{CmpOp, ReduceOp};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
|
||||
pub(crate) trait BackendStorage: Sized {
|
||||
@ -25,10 +25,9 @@ pub(crate) trait BackendStorage: Sized {
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
|
||||
|
||||
fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
|
||||
|
||||
fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
|
||||
-> Result<Self>;
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
|
||||
|
||||
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::op::{Op, ReduceOp};
|
||||
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::{Error, Result, Tensor, TensorId};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -39,10 +39,7 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Add(lhs, rhs)
|
||||
| Op::Mul(lhs, rhs)
|
||||
| Op::Sub(lhs, rhs)
|
||||
| Op::Div(lhs, rhs)
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::Embedding(lhs, rhs)
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
@ -74,17 +71,8 @@ impl Tensor {
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Sqr(node)
|
||||
| Op::Sqrt(node)
|
||||
| Op::Gelu(node)
|
||||
| Op::Relu(node)
|
||||
| Op::Elu(node, _)
|
||||
| Op::Exp(node)
|
||||
| Op::Log(node)
|
||||
| Op::Sin(node)
|
||||
| Op::Cos(node)
|
||||
| Op::Abs(node)
|
||||
| Op::Neg(node) => {
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
@ -118,19 +106,19 @@ impl Tensor {
|
||||
// this is out of scope.
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Add(lhs, rhs) => {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
|
||||
}
|
||||
Op::Sub(lhs, rhs) => {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Sub) => {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
|
||||
}
|
||||
Op::Mul(lhs, rhs) => {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Mul) => {
|
||||
let lhs_grad = grad.mul(rhs)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
@ -138,7 +126,7 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Div(lhs, rhs) => {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Div) => {
|
||||
let lhs_grad = grad.div(rhs)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
@ -221,24 +209,26 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Log(arg) => {
|
||||
Op::Unary(arg, UnaryOp::Log) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * *node)?)?
|
||||
}
|
||||
Op::Sin(arg) => {
|
||||
Op::Unary(arg, UnaryOp::Sin) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
|
||||
}
|
||||
Op::Cos(arg) => {
|
||||
Op::Unary(arg, UnaryOp::Cos) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
|
||||
}
|
||||
Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
|
||||
Op::Exp(arg) => {
|
||||
Op::Unary(_, UnaryOp::Abs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "abs" })
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Exp) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&(&grad / arg)?)?
|
||||
}
|
||||
Op::Neg(arg) => {
|
||||
Op::Unary(arg, UnaryOp::Neg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&grad)?
|
||||
}
|
||||
@ -276,15 +266,19 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
|
||||
Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
|
||||
Op::Unary(_, UnaryOp::Gelu) => {
|
||||
return Err(Error::BackwardNotSupported { op: "gelu" })
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Relu) => {
|
||||
return Err(Error::BackwardNotSupported { op: "relu" })
|
||||
}
|
||||
Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
|
||||
Op::Sqr(arg) => {
|
||||
Op::Unary(arg, UnaryOp::Sqr) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Sqrt(arg) => {
|
||||
Op::Unary(arg, UnaryOp::Sqrt) => {
|
||||
let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{DType, Error, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
@ -1158,7 +1158,7 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
if B::BF16_VEC {
|
||||
@ -1207,7 +1207,12 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOp>(&self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
||||
let data = if B::BF16_VEC {
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{CmpOp, ReduceOp};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
@ -573,7 +573,7 @@ impl<'a> Map1 for FastReduce<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<U: crate::op::UnaryOp> Map1 for U {
|
||||
impl<U: UnaryOpT> Map1 for U {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -716,7 +716,7 @@ impl<'a> Map2 for WhereCond<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<U: crate::op::BinaryOp> Map2 for U {
|
||||
impl<U: crate::op::BinaryOpT> Map2 for U {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
lhs: &CudaSlice<T>,
|
||||
@ -976,13 +976,13 @@ impl BackendStorage for CudaStorage {
|
||||
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
|
||||
}
|
||||
|
||||
fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = U::V.map(&self.slice, &device, layout)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn binary_impl<B: crate::op::BinaryOp>(
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
|
@ -1,5 +1,5 @@
|
||||
#![allow(dead_code)]
|
||||
use crate::op::{CmpOp, ReduceOp};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -57,16 +57,11 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn binary_impl<B: crate::op::BinaryOp>(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Layout,
|
||||
) -> Result<Self> {
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
|
@ -19,12 +19,34 @@ pub enum ReduceOp {
|
||||
Max,
|
||||
}
|
||||
|
||||
// These ops return the same type as their input type.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum BinaryOp {
|
||||
Add,
|
||||
Mul,
|
||||
Sub,
|
||||
Div,
|
||||
}
|
||||
|
||||
// Unary ops with no argument
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum UnaryOp {
|
||||
Exp,
|
||||
Log,
|
||||
Sin,
|
||||
Cos,
|
||||
Abs,
|
||||
Neg,
|
||||
Sqr,
|
||||
Sqrt,
|
||||
Gelu,
|
||||
Relu,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum Op {
|
||||
Add(Tensor, Tensor),
|
||||
Mul(Tensor, Tensor),
|
||||
Sub(Tensor, Tensor),
|
||||
Div(Tensor, Tensor),
|
||||
Binary(Tensor, Tensor, BinaryOp),
|
||||
Unary(Tensor, UnaryOp),
|
||||
Cmp(Tensor, CmpOp),
|
||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||
Matmul(Tensor, Tensor),
|
||||
@ -49,26 +71,16 @@ pub(crate) enum Op {
|
||||
},
|
||||
ToDType(Tensor),
|
||||
Broadcast(Tensor),
|
||||
Exp(Tensor),
|
||||
Log(Tensor),
|
||||
Sin(Tensor),
|
||||
Cos(Tensor),
|
||||
Abs(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
Neg(Tensor),
|
||||
Reshape(Tensor),
|
||||
Softmax(Tensor, usize),
|
||||
Sqr(Tensor),
|
||||
Sqrt(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Gelu(Tensor),
|
||||
Relu(Tensor),
|
||||
Elu(Tensor, f64),
|
||||
// TODO: Support for custom ops.
|
||||
}
|
||||
|
||||
pub(crate) trait UnaryOp {
|
||||
pub(crate) trait UnaryOpT {
|
||||
const NAME: &'static str;
|
||||
const KERNEL: &'static str;
|
||||
const V: Self;
|
||||
@ -91,7 +103,7 @@ pub(crate) trait UnaryOp {
|
||||
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
pub(crate) trait BinaryOpT {
|
||||
const NAME: &'static str;
|
||||
const KERNEL: &'static str;
|
||||
const V: Self;
|
||||
@ -133,7 +145,7 @@ pub(crate) struct Relu;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||
impl BinaryOp for $op {
|
||||
impl BinaryOpT for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("b", $name);
|
||||
const V: Self = $op;
|
||||
@ -187,7 +199,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
|
||||
|
||||
macro_rules! unary_op {
|
||||
($op: ident, $name: literal, $a: ident, $e: expr) => {
|
||||
impl UnaryOp for $op {
|
||||
impl UnaryOpT for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("u", $name);
|
||||
const V: Self = $op;
|
||||
@ -219,7 +231,7 @@ macro_rules! unary_op {
|
||||
};
|
||||
|
||||
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
|
||||
impl UnaryOp for $op {
|
||||
impl UnaryOpT for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("u", $name);
|
||||
const V: Self = $op;
|
||||
@ -277,7 +289,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
/// `gelu` operation
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
impl UnaryOp for Gelu {
|
||||
impl UnaryOpT for Gelu {
|
||||
const NAME: &'static str = "gelu";
|
||||
const V: Self = Gelu;
|
||||
#[inline(always)]
|
||||
@ -343,7 +355,7 @@ impl UnaryOp for Gelu {
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOp for Relu {
|
||||
impl UnaryOpT for Relu {
|
||||
const NAME: &'static str = "relu";
|
||||
const KERNEL: &'static str = "urelu";
|
||||
const V: Self = Relu;
|
||||
|
@ -147,7 +147,7 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
@ -161,7 +161,7 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: op::BinaryOp>(
|
||||
pub(crate) fn binary_impl<B: op::BinaryOpT>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_layout: &Layout,
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{CmpOp, Op, ReduceOp};
|
||||
use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -80,7 +80,7 @@ macro_rules! unary_op {
|
||||
.storage()
|
||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::$op_name(self.clone()))
|
||||
Some(Op::Unary(self.clone(), UnaryOp::$op_name))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -99,7 +99,7 @@ macro_rules! binary_op {
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::$op_name(self.clone(), rhs.clone()))
|
||||
Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
Reference in New Issue
Block a user