mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Custom ops with a single argument (#214)
* Add the CustomOp1 trait. * Add an example of custom op. * Polish the custom op example. * Add some backward pass test for custom ops.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
|
||||
pub(crate) trait BackendStorage: Sized {
|
||||
pub trait BackendStorage: Sized {
|
||||
type Device: BackendDevice;
|
||||
|
||||
fn try_clone(&self, _: &Layout) -> Result<Self>;
|
||||
@ -53,7 +53,7 @@ pub(crate) trait BackendStorage: Sized {
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
|
||||
}
|
||||
|
||||
pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
type Storage: BackendStorage;
|
||||
|
||||
// TODO: Make the usize generic and part of a generic DeviceLocation.
|
||||
|
@ -86,7 +86,8 @@ impl Tensor {
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _) => {
|
||||
| Op::Elu(node, _)
|
||||
| Op::CustomOp1(node, _) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
@ -319,6 +320,11 @@ impl Tensor {
|
||||
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
|
||||
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::CustomOp1(arg, c) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let arg_grad = c.bwd(arg, node, &grad)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Sqr) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -18,7 +18,7 @@ pub enum CpuStorage {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CpuDevice;
|
||||
|
||||
trait Map1 {
|
||||
pub trait Map1 {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||
@ -33,7 +33,7 @@ trait Map1 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map1Any {
|
||||
pub trait Map1Any {
|
||||
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
||||
&self,
|
||||
vs: &[T],
|
||||
@ -54,7 +54,7 @@ trait Map1Any {
|
||||
}
|
||||
|
||||
type C = CpuStorage;
|
||||
trait Map2 {
|
||||
pub trait Map2 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
@ -82,7 +82,7 @@ trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2U8 {
|
||||
pub trait Map2U8 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||
|
||||
@ -348,7 +348,11 @@ impl<'a> Map1 for Reduce<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
||||
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
mut f: F,
|
||||
) -> Vec<U> {
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
||||
[start_offset..start_offset + len]
|
||||
@ -380,7 +384,7 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
mut f: F,
|
||||
|
@ -33,13 +33,13 @@
|
||||
//!
|
||||
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
|
||||
|
||||
mod backend;
|
||||
pub mod backend;
|
||||
mod backprop;
|
||||
mod conv;
|
||||
mod convert;
|
||||
mod cpu_backend;
|
||||
pub mod cpu_backend;
|
||||
#[cfg(feature = "cuda")]
|
||||
mod cuda_backend;
|
||||
pub mod cuda_backend;
|
||||
mod device;
|
||||
pub mod display;
|
||||
mod dtype;
|
||||
@ -65,6 +65,7 @@ pub use dtype::{DType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
pub use layout::Layout;
|
||||
pub use op::CustomOp1;
|
||||
pub use shape::{Shape, D};
|
||||
pub use storage::Storage;
|
||||
pub use strided_index::{StridedBlocks, StridedIndex};
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::Tensor;
|
||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
@ -93,10 +93,35 @@ pub(crate) enum Op {
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Elu(Tensor, f64),
|
||||
// TODO: Support for custom ops.
|
||||
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
|
||||
}
|
||||
|
||||
pub(crate) trait UnaryOpT {
|
||||
/// Unary ops that can be defined in user-land.
|
||||
pub trait CustomOp1: Send + Sync {
|
||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(&self, _: &CudaStorage, _: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Tensor> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait UnaryOpT {
|
||||
const NAME: &'static str;
|
||||
const KERNEL: &'static str;
|
||||
const V: Self;
|
||||
@ -119,7 +144,7 @@ pub(crate) trait UnaryOpT {
|
||||
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOpT {
|
||||
pub trait BinaryOpT {
|
||||
const NAME: &'static str;
|
||||
const KERNEL: &'static str;
|
||||
const V: Self;
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, ReduceOp};
|
||||
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
@ -147,6 +147,19 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let (storage, shape) = c.cpu_fwd(storage, l)?;
|
||||
Ok((Self::Cpu(storage), shape))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let (storage, shape) = c.cuda_fwd(storage, l)?;
|
||||
Ok((Self::Cuda(storage), shape))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::op::{BinaryOp, CmpOp, CustomOp1, Op, ReduceOp, UnaryOp};
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -1688,6 +1688,23 @@ impl Tensor {
|
||||
let rhs: &RwLock<Storage> = rhs.storage.as_ref();
|
||||
std::ptr::eq(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Applies a unary custom op.
|
||||
pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
|
||||
let (storage, shape) = self
|
||||
.storage()
|
||||
.custom_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::CustomOp1(self.clone(), c))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
|
||||
self.custom_op1_arc(Arc::new(Box::new(c)))
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
157
candle-core/tests/custom_op_tests.rs
Normal file
157
candle-core/tests/custom_op_tests.rs
Normal file
@ -0,0 +1,157 @@
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cpu_backend;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
|
||||
mod test_utils;
|
||||
use test_utils::to_vec1_round;
|
||||
|
||||
fn fwd<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
if v.is_sign_positive() {
|
||||
v
|
||||
} else {
|
||||
(v.exp() - T::one()) * alpha
|
||||
}
|
||||
}
|
||||
|
||||
struct Elu {
|
||||
alpha: f64,
|
||||
}
|
||||
|
||||
impl CustomOp1 for Elu {
|
||||
fn name(&self) -> &'static str {
|
||||
"elu"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
use CpuStorage::*;
|
||||
|
||||
// In this example, we pattern match over the different dtypes. Some helper functions and
|
||||
// traits from the `cpu_backend` module can be used to avoid this in some common cases, see
|
||||
// e.g. `Map1`.
|
||||
let storage = match s {
|
||||
BF16(s) => {
|
||||
let alpha = bf16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
|
||||
BF16(data)
|
||||
}
|
||||
F16(s) => {
|
||||
let alpha = f16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
|
||||
F16(data)
|
||||
}
|
||||
F32(s) => {
|
||||
let alpha = self.alpha as f32;
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, alpha));
|
||||
F32(data)
|
||||
}
|
||||
F64(s) => {
|
||||
let data = cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha));
|
||||
F64(data)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
|
||||
};
|
||||
Ok((storage, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_op1_no_backward() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
|
||||
let t = (t - 5.)?;
|
||||
let elu_t = t.custom_op1(Elu { alpha: 1. })?;
|
||||
assert_eq!(
|
||||
to_vec1_round(&elu_t, 4)?,
|
||||
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Define a similar struct as Elu but with backward support.
|
||||
fn bwd<T: num_traits::Float>(v: T, alpha: T) -> T {
|
||||
if v.is_sign_positive() {
|
||||
T::one()
|
||||
} else {
|
||||
v.exp() * alpha
|
||||
}
|
||||
}
|
||||
|
||||
struct EluBackward {
|
||||
alpha: f64,
|
||||
}
|
||||
|
||||
impl CustomOp1 for EluBackward {
|
||||
fn name(&self) -> &'static str {
|
||||
"elu-bwd"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
use CpuStorage::*;
|
||||
|
||||
// In this example, we pattern match over the different dtypes. Some helper functions and
|
||||
// traits from the `cpu_backend` module can be used to avoid this in some common cases, see
|
||||
// e.g. `Map1`.
|
||||
let storage = match s {
|
||||
BF16(s) => {
|
||||
let alpha = bf16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
|
||||
BF16(data)
|
||||
}
|
||||
F16(s) => {
|
||||
let alpha = f16::from_f64(self.alpha);
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
|
||||
F16(data)
|
||||
}
|
||||
F32(s) => {
|
||||
let alpha = self.alpha as f32;
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, alpha));
|
||||
F32(data)
|
||||
}
|
||||
F64(s) => {
|
||||
let data = cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha));
|
||||
F64(data)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(s.dtype(), "elu").bt())?,
|
||||
};
|
||||
Ok((storage, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
struct EluWithBackward(Elu);
|
||||
|
||||
impl EluWithBackward {
|
||||
fn new(alpha: f64) -> Self {
|
||||
Self(Elu { alpha })
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomOp1 for EluWithBackward {
|
||||
fn name(&self) -> &'static str {
|
||||
"elu"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
self.0.cpu_fwd(s, l)
|
||||
}
|
||||
|
||||
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Tensor> {
|
||||
let alpha = self.0.alpha;
|
||||
let bwd = arg.custom_op1(EluBackward { alpha })?;
|
||||
grad_res.mul(&bwd)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_op1_with_backward() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let t = candle::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
|
||||
let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
|
||||
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
|
||||
|
||||
let grads = elu_t.backward()?;
|
||||
let grad_x = grads.get(&t).unwrap();
|
||||
assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]);
|
||||
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user