Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
More realistic training setup. (#210)
* More realistic training setup.
* Compute the model accuracy.
* Very inefficient backprop for index select.
* More backprop.
* Fix some backprop issues.
* Backprop fix.
* Another broadcasting backprop fix.
* Better backprop for reducing ops.
* Training again.
* Add some gradient tests.
* Get the training to work.
@@ -2,6 +2,19 @@ use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
+// arg has been reduced to node via reduce_dims, expand it back to arg.
+// This has to handle keepdims.
+fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
+    if arg.rank() == node.rank() {
+        // keepdim = true
+        node.broadcast_as(arg.shape())
+    } else {
+        // keepdim = false
+        // first expand the reduced dims.
+        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
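Note: `broadcast_back` undoes a reduction during backprop. When the reduction kept the reduced dims (arg and node have the same rank) the gradient can be broadcast directly; otherwise it is first reshaped to the recorded keepdim shape and then broadcast. Below is a minimal sketch of the same expand-back idea in plain Rust over flat `Vec<f32>` buffers rather than candle tensors; the helper name and shapes are purely illustrative.

    // Expand a gradient that was reduced over `dim` back to the argument's shape:
    // every element of a reduced slice gets the slice's gradient value.
    fn expand_reduced(grad: &[f32], arg_shape: &[usize], dim: usize) -> Vec<f32> {
        let outer: usize = arg_shape[..dim].iter().product();
        let reduced = arg_shape[dim];
        let inner: usize = arg_shape[dim + 1..].iter().product();
        let mut out = vec![0f32; outer * reduced * inner];
        for o in 0..outer {
            for r in 0..reduced {
                for i in 0..inner {
                    out[(o * reduced + r) * inner + i] = grad[o * inner + i];
                }
            }
        }
        out
    }

    fn main() {
        // Summing a [2, 3] tensor over dim 1 leaves one gradient value per row...
        let grad = vec![1f32, 2.];
        // ...and broadcasting back repeats that value across the reduced dim.
        assert_eq!(expand_reduced(&grad, &[2, 3], 1), vec![1., 1., 1., 2., 2., 2.]);
    }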
@@ -145,8 +158,26 @@ impl Tensor {
             *f_sum_grad = f_sum_grad.add(&f_grad)?;
         }
         Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
-        Op::IndexSelect(_lhs, _rhs, _) => {
-            Err(Error::BackwardNotSupported { op: "index-select" })?
+        Op::IndexSelect(arg, indexes, dim) => {
+            let dim = *dim;
+            let sum_grad = grads.or_insert(arg)?;
+            // TODO: This is very very very inefficient, have some dedicated kernel for this.
+            let indexes = indexes.to_vec1::<u32>()?;
+            for (dst_index, src_index) in indexes.iter().enumerate() {
+                let src_index = *src_index as usize;
+                let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
+                let mut pre_dims = arg.dims().to_vec();
+                pre_dims[dim] = src_index;
+                let pre_zeros =
+                    Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
+                let mut post_dims = arg.dims().to_vec();
+                post_dims[dim] = post_dims[dim] - src_index - 1;
+                let post_zeros =
+                    Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
+                let src_grad =
+                    Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
+                *sum_grad = sum_grad.add(&src_grad)?;
+            }
         }
         Op::Embedding(_lhs, _rhs) => {
             Err(Error::BackwardNotSupported { op: "embedding" })?
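Note: the new `Op::IndexSelect` backward pads each selected gradient slice with zeros and adds it into the argument's gradient, one index at a time, which is why the TODO asks for a dedicated kernel. The net effect is a scatter-add: the gradient of output row i is accumulated into input row indexes[i], and rows selected several times accumulate. A minimal sketch of that equivalent computation on a 1D buffer, in plain Rust with illustrative names only:

    fn index_select_backward_1d(grad_out: &[f32], indexes: &[u32], src_len: usize) -> Vec<f32> {
        let mut grad_src = vec![0f32; src_len];
        for (dst_index, &src_index) in indexes.iter().enumerate() {
            // Scatter-add: each output position sends its gradient to the row it selected.
            grad_src[src_index as usize] += grad_out[dst_index];
        }
        grad_src
    }

    fn main() {
        // Selecting indices [2, 0, 2] from a length-4 input and backpropagating [1, 10, 100]
        // sends 10 to slot 0 and 1 + 100 to slot 2.
        assert_eq!(
            index_select_backward_1d(&[1., 10., 100.], &[2, 0, 2], 4),
            vec![10., 0., 101., 0.]
        );
    }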
@@ -189,20 +220,32 @@ impl Tensor {
                 }
             }
 
-            let arg_grad = grad.sum(sum_dims.as_slice())?;
+            let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
+            for _i in 0..left_dims {
+                arg_grad = arg_grad.squeeze(0)?
+            }
             let sum_grad = grads.or_insert(arg)?;
-            *sum_grad = sum_grad.broadcast_add(&arg_grad)?
+            *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
         }
-        Op::Reduce(arg, ReduceOp::Sum, _) => {
-            let sum_grad = grads.or_insert(arg)?;
-            *sum_grad = sum_grad.broadcast_add(&grad)?
+        Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
+            let grad = broadcast_back(arg, &grad, reduced_dims)?;
+            let sum_grad = grads.or_insert(arg)?;
+            *sum_grad = sum_grad.add(&grad)?;
         }
         Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
-        Op::Reduce(_args, ReduceOp::Max, _) => {
-            Err(Error::BackwardNotSupported { op: "max" })?
+        Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
+            let node = broadcast_back(arg, node, reduced_dims)?;
+            let grad = broadcast_back(arg, &grad, reduced_dims)?;
+            let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+            let sum_grad = grads.or_insert(arg)?;
+            *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
         }
-        Op::Reduce(_args, ReduceOp::Min, _) => {
-            Err(Error::BackwardNotSupported { op: "min" })?
+        Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
+            let node = broadcast_back(arg, node, reduced_dims)?;
+            let grad = broadcast_back(arg, &grad, reduced_dims)?;
+            let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+            let sum_grad = grads.or_insert(arg)?;
+            *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
        }
         Op::ToDType(arg) => {
             let sum_grad = grads.or_insert(arg)?;
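Note: for the max and min reductions, the gradient only flows to positions that attain the extremum. The reduced value is broadcast back to the argument's shape, compared element-wise with `eq`, and the resulting 0/1 mask multiplies the incoming gradient, so every tied position receives the full gradient. A plain-Rust sketch of that masking for a single max reduction, with illustrative names:

    fn max_backward(arg: &[f32], max_val: f32, grad: f32) -> Vec<f32> {
        // Only elements equal to the max receive the incoming gradient.
        arg.iter()
            .map(|&v| if v == max_val { grad } else { 0.0 })
            .collect()
    }

    fn main() {
        // max of [1, 5, 3, 5] is 5; with an upstream gradient of 2.0 both maxima receive 2.0.
        assert_eq!(max_backward(&[1., 5., 3., 5.], 5., 2.0), vec![0., 2., 0., 2.]);
    }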
@@ -215,7 +258,7 @@ impl Tensor {
         }
         Op::Unary(arg, UnaryOp::Log) => {
             let sum_grad = grads.or_insert(arg)?;
-            *sum_grad = sum_grad.add(&(&grad * *node)?)?
+            *sum_grad = sum_grad.add(&(grad / arg)?)?
         }
         Op::Unary(arg, UnaryOp::Sin) => {
             let sum_grad = grads.or_insert(arg)?;
@@ -228,7 +271,7 @@ impl Tensor {
         Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
         Op::Unary(arg, UnaryOp::Exp) => {
             let sum_grad = grads.or_insert(arg)?;
-            *sum_grad = sum_grad.add(&(&grad / arg)?)?
+            *sum_grad = sum_grad.add(&(&grad * *node)?)?
         }
         Op::Unary(arg, UnaryOp::Neg) => {
             let sum_grad = grads.or_insert(arg)?;
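Note: the two hunks above swap formulas that were previously inverted. Since d/dx log(x) = 1/x, the log backward divides the incoming gradient by the argument, and since d/dx exp(x) = exp(x), which is exactly the node's own output, the exp backward multiplies the gradient by the node. A quick plain-Rust finite-difference check of both formulas:

    fn main() {
        let x = 3f64;
        let eps = 1e-6;
        let d_log = ((x + eps).ln() - x.ln()) / eps;
        let d_exp = ((x + eps).exp() - x.exp()) / eps;
        // d/dx log(x) = 1/x and d/dx exp(x) = exp(x).
        assert!((d_log - 1.0 / x).abs() < 1e-4);
        assert!((d_exp - x.exp()).abs() < 1e-2);
        println!("d log = {d_log:.6}, d exp = {d_exp:.3}");
    }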
@@ -52,7 +52,7 @@ mod mkl;
 pub mod npy;
 mod op;
 pub mod safetensors;
-mod shape;
+pub mod shape;
 mod storage;
 mod strided_index;
 mod tensor;
@@ -48,6 +48,7 @@ pub(crate) enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
+    // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
@@ -633,15 +633,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &max_dim in max_dims.iter() {
             dims[max_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec()))
+        } else {
+            None
+        };
         let max = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(max)
@@ -655,15 +655,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &min_dim in min_dims.iter() {
             dims[min_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec()))
+        } else {
+            None
+        };
         let min = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(min)
@@ -677,15 +677,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
+        } else {
+            None
+        };
         let sum = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(sum)
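Note: in the three hunks above the `op` is now built after `dims` has had its reduced entries set to 1, so the `Reduce` op records the keepdim=true shape rather than the list of reduced dims (matching the new comment on the Op enum), which is exactly what `broadcast_back` needs to reshape a keepdim=false gradient. A small plain-Rust sketch of that bookkeeping, with an illustrative helper name:

    fn keepdim_shape(shape: &[usize], reduce_dims: &[usize]) -> Vec<usize> {
        // Keep the original shape but collapse every reduced dim to 1.
        let mut dims = shape.to_vec();
        for &d in reduce_dims {
            dims[d] = 1;
        }
        dims
    }

    fn main() {
        // Summing a [2, 3, 4] tensor over dims 0 and 2 records [1, 3, 1] in the op.
        assert_eq!(keepdim_shape(&[2, 3, 4], &[0, 2]), vec![1, 3, 1]);
    }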
@@ -79,7 +79,42 @@ fn grad_descent(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn unary_grad(device: &Device) -> Result<()> {
+    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
+    let x = x.as_tensor();
+    let y = (x.log()? + 1.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    let y = x.exp()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
+    );
+    let y = x.exp()?.sqr()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [403.4288, 7.3890557, 2980.9578, 1.3498588]
+    );
+    // exp(x)^2 = exp(2*x)
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [806.8576, 14.778111, 5961.9155, 2.6997175]
+    );
+    Ok(())
+}
+
 test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
 test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
 test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
 test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
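Note: in the squared-exp test above, y = exp(x)^2 = exp(2x), so dy/dx = 2*exp(2x) = 2*y and the expected gradients are exactly twice the expected outputs. A plain-Rust check of those expected values:

    fn main() {
        for &x in &[3f32, 1., 4., 0.15] {
            let y = (2.0 * x).exp();
            let dy = 2.0 * y;
            println!("x = {x}, y = {y}, dy/dx = {dy}");
        }
        // Prints y close to [403.43, 7.389, 2980.96, 1.3499] and dy/dx close to
        // [806.86, 14.778, 5961.92, 2.6997], matching the asserts in unary_grad.
    }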
@@ -3,11 +3,26 @@
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle::{DType, Var, D};
+use candle::{DType, Tensor, Var, D};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
 
+fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> candle::Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let sum_exp = diff.exp()?.sum_keepdim(d)?;
+    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
+    Ok(log_sm)
+}
+
+// TODO: Once the index_select backprop is efficient enough, switch to using this.
+fn _nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
+    let b_sz = target.shape().r1()?;
+    inp.index_select(target, 0)?.sum_all()? / b_sz as f64
+}
+
 pub fn main() -> Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
     let m = candle_nn::vision::mnist::load_dir("data")?;
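Note: the `log_softmax` helper above subtracts the per-row max before exponentiating so that `exp` cannot overflow, computing log_softmax(x) = (x - max) - log(sum(exp(x - max))). A plain-Rust sketch of the same math on a single row, with an illustrative helper name (not the candle API):

    fn log_softmax(xs: &[f32]) -> Vec<f32> {
        let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = xs.iter().map(|&x| (x - max).exp()).sum();
        xs.iter().map(|&x| x - max - sum_exp.ln()).collect()
    }

    fn main() {
        // Without the max subtraction exp(1000) would overflow; with it the result is finite
        // and exponentiating the output sums to 1.
        let ls = log_softmax(&[1000f32, 1001., 1002.]);
        let total: f32 = ls.iter().map(|&v| v.exp()).sum();
        assert!((total - 1.0).abs() < 1e-5);
        println!("{ls:?}");
    }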
@@ -15,25 +30,50 @@ pub fn main() -> Result<()> {
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());
     println!("test-labels: {:?}", m.test_labels.shape());
+    let train_labels = m.train_labels;
+    let train_images = m.train_images;
+    let train_labels = train_labels.to_vec1::<u8>()?;
+    let train_label_mask = train_labels
+        .iter()
+        .flat_map(|l| (0..LABELS).map(|i| f32::from(i == *l as usize)))
+        .collect::<Vec<_>>();
+    let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
     let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
     let bs = Var::zeros(LABELS, DType::F32, &dev)?;
-    let sgd = candle_nn::SGD::new(&[&ws, &bs], 0.1);
+    let sgd = candle_nn::SGD::new(&[&ws, &bs], 3e-1);
+    let test_images = m.test_images;
+    let test_labels = m.test_labels.to_vec1::<u8>()?;
     for epoch in 1..200 {
-        let logits = m.train_images.matmul(&ws)?.broadcast_add(&bs)?;
-        let loss = logits.softmax(D::Minus1)?;
-        // TODO: log_softmax + let loss = loss.nll_loss(&m.train_labels);
+        let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
+        let log_sm = log_softmax(&logits, D::Minus1)?;
+        let loss = (&log_sm * &train_label_mask)?
+            .sum_all()?
+            .affine(-1f64 / train_images.dim(0)? as f64, 0f64)?;
         sgd.backward_step(&loss)?;
 
-        let _test_logits = m.test_images.matmul(&ws)?.broadcast_add(&bs)?;
-        /* TODO
+        let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
+        /* TODO: Add argmax so that the following can be computed within candle.
         let test_accuracy = test_logits
             .argmax(Some(-1), false)
-            .eq_tensor(&m.test_labels)
+            .eq_tensor(&test_labels)
             .to_kind(Kind::Float)
             .mean(Kind::Float)
             .double_value(&[]);
         */
-        let test_accuracy = 0.;
+        let test_logits = test_logits.to_vec2::<f32>()?;
+        let sum_ok = test_logits
+            .iter()
+            .zip(test_labels.iter())
+            .map(|(logits, label)| {
+                let arg_max = logits
+                    .iter()
+                    .enumerate()
+                    .max_by(|(_, v1), (_, v2)| v1.total_cmp(v2))
+                    .map(|(idx, _)| idx);
+                f64::from(arg_max == Some(*label as usize))
+            })
+            .sum::<f64>();
+        let test_accuracy = sum_ok / test_labels.len() as f64;
         println!(
             "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
             loss.to_scalar::<f32>()?,
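Note: with the one-hot `train_label_mask`, `(&log_sm * &train_label_mask)?.sum_all()?` picks out the log-probability of the correct class for every sample, and `.affine(-1/batch, 0)` flips the sign and divides by the batch size, giving the average negative log-likelihood (cross-entropy against one-hot targets). A plain-Rust sketch of the same loss on small literal values, with illustrative names:

    fn nll_from_log_softmax(log_sm: &[Vec<f32>], labels: &[usize]) -> f32 {
        // Average of -log p(correct class) over the batch.
        let total: f32 = log_sm
            .iter()
            .zip(labels)
            .map(|(row, &label)| -row[label])
            .sum();
        total / labels.len() as f32
    }

    fn main() {
        let log_sm = vec![vec![-0.1f32, -2.5], vec![-3.0, -0.05]];
        let labels = vec![0usize, 1];
        // Average of 0.1 and 0.05.
        assert!((nll_from_log_softmax(&log_sm, &labels) - 0.075).abs() < 1e-6);
    }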