Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)
Add the affine transformation.
@@ -2,6 +2,12 @@ use crate::Tensor;
 
 pub(crate) enum Op {
     Add(Tensor, Tensor),
+    #[allow(dead_code)] // add is currently unused.
+    Affine {
+        arg: Tensor,
+        mul: f64,
+        add: f64,
+    },
     Mul(Tensor, Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
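The new Op::Affine variant records the input tensor together with the two scalars of the elementwise map y = x * mul + add, so the backward pass can later replay the operation. As a standalone illustration (not part of the diff), the map it represents is just:

    // Illustration only: the elementwise map that Op::Affine records,
    // y[i] = x[i] * mul + add.
    fn affine_ref(xs: &[f64], mul: f64, add: f64) -> Vec<f64> {
        xs.iter().map(|&x| x * mul + add).collect()
    }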
@@ -172,6 +172,32 @@ impl Storage {
         }
     }
+
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Storage::Cpu(storage) => match storage {
+                CpuStorage::F32(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let mul = mul as f32;
+                    let add = add as f32;
+                    let data = index.map(|i| storage[i] * mul + add).collect();
+                    Ok(Storage::Cpu(CpuStorage::F32(data)))
+                }
+                CpuStorage::F64(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let data = index.map(|i| storage[i] * mul + add).collect();
+                    Ok(Storage::Cpu(CpuStorage::F64(data)))
+                }
+            },
+        }
+    }
 
     fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
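affine_impl walks the logical elements through StridedIndex, so it also handles non-contiguous layouts, at the cost of per-element index arithmetic (hence the TODO about a contiguous fast path). A minimal sketch of the offsets such an iterator yields, assuming (as the calls above suggest) flat buffer offsets in row-major order:

    // Sketch of what a strided iterator like StridedIndex produces,
    // under the assumption of row-major traversal of `dims` with each
    // coordinate weighted by its entry in `stride`.
    fn strided_offsets(dims: &[usize], stride: &[usize]) -> Vec<usize> {
        let n: usize = dims.iter().product();
        (0..n)
            .map(|mut i| {
                let mut offset = 0;
                // Peel coordinates off the linear index, fastest-varying
                // (last) dimension first, and weight each by its stride.
                for (&d, &s) in dims.iter().zip(stride.iter()).rev() {
                    offset += (i % d) * s;
                    i /= d;
                }
                offset
            })
            .collect()
    }

For a contiguous tensor the offsets are just 0..n, which is why a dedicated contiguous code path could skip the division/modulo work entirely.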
@@ -209,6 +209,26 @@ impl Tensor {
         }
     }
+
+    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
+        let shape = self.shape();
+        let storage = self
+            .storage
+            .affine_impl(self.shape(), self.stride(), mul, add)?;
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape: shape.clone(),
+            stride: shape.stride_contiguous(),
+            op: Some(Op::Affine {
+                arg: self.clone(),
+                mul,
+                add,
+            }),
+            is_variable: false,
+        };
+        Ok(Self(Arc::new(tensor_)))
+    }
 
     pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex {
         crate::storage::StridedIndex::new(self.dims(), self.stride())
     }
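Tensor::affine produces a fresh contiguous result (note stride_contiguous) and tags it with Op::Affine for autodiff. A hedged usage sketch; `Tensor::new` and `Device::Cpu` below are assumptions about the surrounding API at this point in the tree, not part of the diff:

    // Hypothetical usage; constructor and device names are assumptions.
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let y = t.affine(2., 1.)?; // elementwise 2 * t + 1; records Op::Affine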
@@ -307,6 +327,15 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Affine { arg, mul, .. } => {
+                    if *mul == 0. {
+                        nodes
+                    } else {
+                        let (tg, nodes) = walk(arg, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                }
                 Op::Sqr(node) | Op::Sqrt(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
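The `*mul == 0.` early-out in the graph walk is a small pruning optimization: since d(mul * x + add)/dx = mul, a zero mul means no gradient can flow back into arg through this op, so its subgraph never needs to be visited or tracked.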
@@ -357,16 +386,19 @@
                     let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Sqr(_arg) => {
-                    todo!()
-                    // TODO: Add scaling by a constant to enable the following.
-                    // let arg_grad = 2 * arg * grad;
-                    // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
-                    // *sum_grad = sum_grad.add(arg_grad)?
+                Op::Affine { arg, mul, .. } => {
+                    let arg_grad = grad.affine(*mul, 0.)?;
+                    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                    *sum_grad = sum_grad.add(&arg_grad)?
+                }
+                Op::Sqr(arg) => {
+                    let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
+                    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                    *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Sqrt(_arg) => {
                     todo!()
-                    // TODO: Add scaling by a constant and divide to enable the following.
+                    // TODO: Add div to enable the following.
                     // let arg_grad = grad / (2 * arg)
                     // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
                     // *sum_grad = sum_grad.add(arg_grad)?
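The new backward rules follow directly from the chain rule: for y = mul * x + add, dL/dx = mul * dL/dy, which is exactly grad.affine(*mul, 0.); for y = x^2, dL/dx = 2 * x * dL/dy, i.e. arg.mul(&grad)?.affine(2., 0.)?. This scaling-by-a-constant is what the old Sqr TODO was waiting for. A quick finite-difference sanity check of the affine rule on plain f64 values (illustration only, independent of the Tensor API):

    // Central-difference check that d(mul * x + add)/dx = mul.
    fn f(x: f64, mul: f64, add: f64) -> f64 {
        x * mul + add
    }

    fn main() {
        let (x, mul, add, eps) = (1.5f64, 3.0, 0.25, 1e-6);
        let numeric = (f(x + eps, mul, add) - f(x - eps, mul, add)) / (2.0 * eps);
        let analytic = mul; // the gradient rule used by Op::Affine above
        assert!((numeric - analytic).abs() < 1e-6);
    }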