diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 958b07cf..cc4ffd49 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -2,6 +2,19 @@ use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
+// arg has been reduced to node via reduce_dims, expand it back to arg.
+// This has to handle keepdims.
+fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
+    if arg.rank() == node.rank() {
+        // keepdim = true
+        node.broadcast_as(arg.shape())
+    } else {
+        // keepdim = false
+        // first expand the reduced dims.
+        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
@@ -145,8 +158,26 @@ impl Tensor {
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
                 Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
-                Op::IndexSelect(_lhs, _rhs, _) => {
-                    Err(Error::BackwardNotSupported { op: "index-select" })?
+                Op::IndexSelect(arg, indexes, dim) => {
+                    let dim = *dim;
+                    let sum_grad = grads.or_insert(arg)?;
+                    // TODO: This is very very very inefficient, have some dedicated kernel for this.
+                    let indexes = indexes.to_vec1::<u32>()?;
+                    for (dst_index, src_index) in indexes.iter().enumerate() {
+                        let src_index = *src_index as usize;
+                        let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
+                        let mut pre_dims = arg.dims().to_vec();
+                        pre_dims[dim] = src_index;
+                        let pre_zeros =
+                            Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
+                        let mut post_dims = arg.dims().to_vec();
+                        post_dims[dim] = post_dims[dim] - src_index - 1;
+                        let post_zeros =
+                            Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
+                        let src_grad =
+                            Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
+                        *sum_grad = sum_grad.add(&src_grad)?;
+                    }
                 }
                 Op::Embedding(_lhs, _rhs) => {
                     Err(Error::BackwardNotSupported { op: "embedding" })?
@@ -189,20 +220,32 @@ impl Tensor {
                         }
                     }
 
-                    let arg_grad = grad.sum(sum_dims.as_slice())?;
+                    let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
+                    for _i in 0..left_dims {
+                        arg_grad = arg_grad.squeeze(0)?
+                    }
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.broadcast_add(&arg_grad)?
+                    *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                 }
-                Op::Reduce(arg, ReduceOp::Sum, _) => {
+                Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
+                    let grad = broadcast_back(arg, &grad, reduced_dims)?;
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.broadcast_add(&grad)?
+                    *sum_grad = sum_grad.add(&grad)?;
                 }
                 Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
-                Op::Reduce(_args, ReduceOp::Max, _) => {
-                    Err(Error::BackwardNotSupported { op: "max" })?
+                Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
+                    let node = broadcast_back(arg, node, reduced_dims)?;
+                    let grad = broadcast_back(arg, &grad, reduced_dims)?;
+                    let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                 }
-                Op::Reduce(_args, ReduceOp::Min, _) => {
-                    Err(Error::BackwardNotSupported { op: "min" })?
+                Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
+                    let node = broadcast_back(arg, node, reduced_dims)?;
+                    let grad = broadcast_back(arg, &grad, reduced_dims)?;
+                    let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
@@ -215,7 +258,7 @@ impl Tensor {
                 }
                 Op::Unary(arg, UnaryOp::Log) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&(&grad * *node)?)?
+                    *sum_grad = sum_grad.add(&(grad / arg)?)?
                 }
                 Op::Unary(arg, UnaryOp::Sin) => {
                     let sum_grad = grads.or_insert(arg)?;
@@ -228,7 +271,7 @@ impl Tensor {
                 Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
                 Op::Unary(arg, UnaryOp::Exp) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&(&grad / arg)?)?
+                    *sum_grad = sum_grad.add(&(&grad * *node)?)?
                 }
                 Op::Unary(arg, UnaryOp::Neg) => {
                     let sum_grad = grads.or_insert(arg)?;
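Not part of the patch, but as a reference for the new `ReduceOp::Max` / `ReduceOp::Min` arms above: a minimal plain-Rust sketch (hypothetical helper name, no candle types) of the eq-mask rule they implement. The upstream gradient is broadcast back to the argument's shape and only the positions equal to the reduced value receive it; ties each receive the full gradient rather than a split share.

// Sketch only: 1-D reduce-max gradient using the same eq-mask rule as the patch.
fn reduce_max_backward(arg: &[f32], max_value: f32, grad: f32) -> Vec<f32> {
    arg.iter()
        .map(|&v| if v == max_value { grad } else { 0.0 })
        .collect()
}

fn main() {
    let arg = [1.0f32, 3.0, 2.0, 3.0];
    // Both positions holding the max get the full upstream gradient,
    // mirroring `node.eq(arg)? * grad` in the backward pass.
    assert_eq!(reduce_max_backward(&arg, 3.0, 1.0), vec![0.0, 1.0, 0.0, 1.0]);
}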
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index a1f942d4..7760e2c7 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -52,7 +52,7 @@ mod mkl;
 pub mod npy;
 mod op;
 pub mod safetensors;
-mod shape;
+pub mod shape;
 mod storage;
 mod strided_index;
 mod tensor;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 0fe56840..4686e57e 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -48,6 +48,7 @@ pub(crate) enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
+    // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 28659061..f72404df 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -633,15 +633,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &max_dim in max_dims.iter() {
             dims[max_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec()))
+        } else {
+            None
+        };
         let max = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(max)
@@ -655,15 +655,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &min_dim in min_dims.iter() {
             dims[min_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec()))
+        } else {
+            None
+        };
         let min = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(min)
@@ -677,15 +677,15 @@ impl Tensor {
         let storage = self
             .storage()
             .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
-        } else {
-            None
-        };
         let mut dims = self.dims().to_vec();
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
+        let op = if self.track_op() {
+            Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
+        } else {
+            None
+        };
         let sum = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(sum)
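Not part of the patch: a small plain-Rust sketch (hypothetical helper, no candle types) of the shape bookkeeping behind the tensor.rs change above. `Op::Reduce` now records the keepdim=true shape, so `broadcast_back` can reshape a keepdim=false gradient to that shape and then broadcast it back to the argument's shape.

// Sketch only: compute the keepdim=true shape that is now stored in Op::Reduce.
fn reduced_shape(arg: &[usize], reduce_dims: &[usize]) -> Vec<usize> {
    let mut out = arg.to_vec();
    for &d in reduce_dims {
        out[d] = 1;
    }
    out
}

fn main() {
    // Reducing a [2, 3, 4] tensor over dim 1 keeps the rank with keepdim=true: [2, 1, 4].
    assert_eq!(reduced_shape(&[2, 3, 4], &[1]), vec![2, 1, 4]);
    // With keepdim=false the result is [2, 4]; the backward pass reshapes the gradient
    // to the stored [2, 1, 4] and then broadcasts it back to [2, 3, 4].
}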
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index a7f6725d..0ceab1de 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -79,7 +79,42 @@ fn grad_descent(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn unary_grad(device: &Device) -> Result<()> {
+    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
+    let x = x.as_tensor();
+    let y = (x.log()? + 1.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    let y = x.exp()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
+    );
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [20.085537, 2.7182817, 54.59815, 1.1618342]
+    );
+    let y = x.exp()?.sqr()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    assert_eq!(
+        y.to_vec1::<f32>()?,
+        [403.4288, 7.3890557, 2980.9578, 1.3498588]
+    );
+    // exp(x)^2 = exp(2*x)
+    assert_eq!(
+        grad_x.to_vec1::<f32>()?,
+        [806.8576, 14.778111, 5961.9155, 2.6997175]
+    );
+    Ok(())
+}
+
 test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
 test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
 test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
 test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
+test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
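Not part of the patch: a standalone sketch (hypothetical helper, plain Vec<f32> instead of tensors) of the numerically-stabilized log-softmax and the negative log-likelihood that the updated training loop below computes with a one-hot mask.

// Sketch only: log_softmax(x) = x - max(x) - ln(sum(exp(x - max(x)))) for one row.
fn log_softmax_row(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = xs.iter().map(|x| (x - max).exp()).sum();
    xs.iter().map(|x| x - max - sum_exp.ln()).collect()
}

fn main() {
    let logits = [1.0f32, 2.0, 3.0];
    let log_sm = log_softmax_row(&logits);
    // The per-sample loss is -log_sm[target]; the training loop gets the same value
    // by multiplying with a one-hot mask, summing, and averaging over the batch.
    let target = 2usize;
    println!("log_softmax = {log_sm:?}, nll for class {target} = {}", -log_sm[target]);
}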
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index df67f741..bf7385ac 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -3,11 +3,26 @@
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle::{DType, Var, D};
+use candle::{DType, Tensor, Var, D};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
 
+fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> candle::Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let sum_exp = diff.exp()?.sum_keepdim(d)?;
+    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
+    Ok(log_sm)
+}
+
+// TODO: Once the index_select backprop is efficient enough, switch to using this.
+fn _nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
+    let b_sz = target.shape().r1()?;
+    inp.index_select(target, 0)?.sum_all()? / b_sz as f64
+}
+
 pub fn main() -> Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
     let m = candle_nn::vision::mnist::load_dir("data")?;
@@ -15,25 +30,50 @@
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());
     println!("test-labels: {:?}", m.test_labels.shape());
+    let train_labels = m.train_labels;
+    let train_images = m.train_images;
+    let train_labels = train_labels.to_vec1::<u8>()?;
+    let train_label_mask = train_labels
+        .iter()
+        .flat_map(|l| (0..LABELS).map(|i| f32::from(i == *l as usize)))
+        .collect::<Vec<_>>();
+    let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
     let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
     let bs = Var::zeros(LABELS, DType::F32, &dev)?;
-    let sgd = candle_nn::SGD::new(&[&ws, &bs], 0.1);
+    let sgd = candle_nn::SGD::new(&[&ws, &bs], 3e-1);
+    let test_images = m.test_images;
+    let test_labels = m.test_labels.to_vec1::<u8>()?;
     for epoch in 1..200 {
-        let logits = m.train_images.matmul(&ws)?.broadcast_add(&bs)?;
-        let loss = logits.softmax(D::Minus1)?;
-        // TODO: log_softmax + let loss = loss.nll_loss(&m.train_labels);
+        let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
+        let log_sm = log_softmax(&logits, D::Minus1)?;
+        let loss = (&log_sm * &train_label_mask)?
+            .sum_all()?
+            .affine(-1f64 / train_images.dim(0)? as f64, 0f64)?;
         sgd.backward_step(&loss)?;
-        let _test_logits = m.test_images.matmul(&ws)?.broadcast_add(&bs)?;
-        /* TODO
+        let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
+        /* TODO: Add argmax so that the following can be computed within candle.
         let test_accuracy = test_logits
             .argmax(Some(-1), false)
-            .eq_tensor(&m.test_labels)
+            .eq_tensor(&test_labels)
             .to_kind(Kind::Float)
             .mean(Kind::Float)
             .double_value(&[]);
         */
-        let test_accuracy = 0.;
+        let test_logits = test_logits.to_vec2::<f32>()?;
+        let sum_ok = test_logits
+            .iter()
+            .zip(test_labels.iter())
+            .map(|(logits, label)| {
+                let arg_max = logits
+                    .iter()
+                    .enumerate()
+                    .max_by(|(_, v1), (_, v2)| v1.total_cmp(v2))
+                    .map(|(idx, _)| idx);
+                f64::from(arg_max == Some(*label as usize))
+            })
+            .sum::<f64>();
+        let test_accuracy = sum_ok / test_labels.len() as f64;
         println!(
             "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
             loss.to_scalar::<f32>()?,