mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add an abstract backprop op type (#240)
* Start adding the backprop op type. * More backprop ops. * Finish the backprop op.
This commit is contained in:
@ -58,7 +58,7 @@ pub enum UnaryOp {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum Op {
|
||||
pub enum Op {
|
||||
Binary(Tensor, Tensor, BinaryOp),
|
||||
Unary(Tensor, UnaryOp),
|
||||
Cmp(Tensor, CmpOp),
|
||||
@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
|
||||
/// properly checked when creating a new value
|
||||
#[derive(Clone)]
|
||||
pub struct BackpropOp(Option<Op>);
|
||||
|
||||
impl BackpropOp {
|
||||
pub(crate) fn none() -> Self {
|
||||
BackpropOp(None)
|
||||
}
|
||||
|
||||
pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
|
||||
let op = if arg.track_op() {
|
||||
Some(f(arg.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
|
||||
let op = if arg1.track_op() || arg2.track_op() {
|
||||
Some(f(arg1.clone(), arg2.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn new3(
|
||||
arg1: &Tensor,
|
||||
arg2: &Tensor,
|
||||
arg3: &Tensor,
|
||||
f: impl Fn(Tensor, Tensor, Tensor) -> Op,
|
||||
) -> Self {
|
||||
let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
|
||||
Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
|
||||
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
|
||||
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
|
||||
Some(f(args))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for BackpropOp {
|
||||
type Target = Option<Op>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user