Add the affine transformation.

Author: laurent
Date: 2023-06-20 21:51:35 +01:00
parent e04d0ac1ac
commit f1f372b13e
3 changed files with 71 additions and 7 deletions


@@ -209,6 +209,26 @@ impl Tensor {
}
}
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
let shape = self.shape();
let storage = self
.storage
.affine_impl(self.shape(), self.stride(), mul, add)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: Some(Op::Affine {
arg: self.clone(),
mul,
add,
}),
is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
}
pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex {
crate::storage::StridedIndex::new(self.dims(), self.stride())
}
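
For reference, `affine` applies an elementwise y = mul * x + add over the tensor and records the result as Op::Affine for backprop. A minimal sketch of that semantics on a plain slice, not the Tensor API (the real op goes through the storage/stride machinery above):

fn affine_ref(xs: &[f64], mul: f64, add: f64) -> Vec<f64> {
    // Elementwise y = mul * x + add, the same transform Tensor::affine computes.
    xs.iter().map(|x| x * mul + add).collect()
}

// affine_ref(&[1., 2., 3.], 2., 0.5) == vec![2.5, 4.5, 6.5]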
@@ -307,6 +327,15 @@ impl Tensor {
track_grad |= tg;
nodes
}
Op::Affine { arg, mul, .. } => {
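// When mul is 0 the output no longer depends on arg, so no gradient can
// flow back through this op and the graph walk can skip arg entirely.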
if *mul == 0. {
nodes
} else {
let (tg, nodes) = walk(arg, nodes, already_seen);
track_grad |= tg;
nodes
}
}
Op::Sqr(node) | Op::Sqrt(node) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
@@ -357,16 +386,19 @@ impl Tensor {
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Sqr(_arg) => {
todo!()
// TODO: Add scaling by a constant to enable the following.
// let arg_grad = 2 * arg * grad;
// let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
// *sum_grad = sum_grad.add(arg_grad)?
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqr(arg) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqrt(_arg) => {
todo!()
// TODO: Add scaling by a constant and divide to enable the following.
// TODO: Add div to enable the following.
// let arg_grad = grad / (2 * arg)
// let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
// *sum_grad = sum_grad.add(arg_grad)?
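
A short note on the two new backward rules (not part of the diff): for y = mul * x + add the local derivative is dy/dx = mul, so Op::Affine scales the incoming gradient with grad.affine(*mul, 0.); for y = x * x it is dy/dx = 2 * x, so Op::Sqr computes arg.mul(&grad) and scales it with .affine(2., 0.). A self-contained finite-difference check of those two derivatives on plain f64 values (a hypothetical standalone sketch, not the Tensor API):

fn main() {
    let (x, mul, add, eps) = (3.0_f64, 2.5, 0.7, 1e-6);
    let affine = |x: f64| mul * x + add;
    let sqr = |x: f64| x * x;
    // Analytic derivatives matching the Op::Affine and Op::Sqr backward rules.
    let d_affine = mul;
    let d_sqr = 2.0 * x;
    // Central finite differences for comparison.
    let n_affine = (affine(x + eps) - affine(x - eps)) / (2.0 * eps);
    let n_sqr = (sqr(x + eps) - sqr(x - eps)) / (2.0 * eps);
    assert!((d_affine - n_affine).abs() < 1e-4);
    assert!((d_sqr - n_sqr).abs() < 1e-4);
}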