From 1f2604269350cbba2409e54d413d68718ef36858 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 22 Jul 2023 14:25:11 +0200
Subject: [PATCH] Move some shared functions to the nn module. (#221)

---
 .../examples/simple-training/main.rs | 23 ++++-------------------
 candle-nn/src/lib.rs                 |  2 ++
 candle-nn/src/loss.rs                |  8 ++++++++
 candle-nn/src/ops.rs                 | 10 ++++++++++
 4 files changed, 24 insertions(+), 19 deletions(-)
 create mode 100644 candle-nn/src/loss.rs
 create mode 100644 candle-nn/src/ops.rs

diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 60f2281b..edec2e92 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -3,27 +3,12 @@
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle::{DType, Tensor, Var, D};
+use candle::{DType, Var, D};
+use candle_nn::{loss, ops};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
 
-fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> candle::Result<Tensor> {
-    let d = d.to_index(xs.shape(), "log-softmax")?;
-    let max = xs.max_keepdim(d)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let sum_exp = diff.exp()?.sum_keepdim(d)?;
-    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
-    Ok(log_sm)
-}
-
-fn nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
-    let b_sz = target.dim(0)?;
-    inp.gather(target, 1)?
-        .sum_all()?
-        .affine(-1f64 / b_sz as f64, 0.)
-}
-
 pub fn main() -> Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
     let m = candle_nn::vision::mnist::load_dir("data")?;
@@ -41,8 +26,8 @@ pub fn main() -> Result<()> {
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
         let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
-        let log_sm = log_softmax(&logits, D::Minus1)?;
-        let loss = nll_loss(&log_sm, &train_labels)?;
+        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+        let loss = loss::nll(&log_sm, &train_labels)?;
         sgd.backward_step(&loss)?;
 
         let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 2738e95f..db01b067 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -6,6 +6,8 @@ pub mod embedding;
 pub mod init;
 pub mod layer_norm;
 pub mod linear;
+pub mod loss;
+pub mod ops;
 pub mod optim;
 pub mod var_builder;
 pub mod vision;
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
new file mode 100644
index 00000000..d388af6c
--- /dev/null
+++ b/candle-nn/src/loss.rs
@@ -0,0 +1,8 @@
+use candle::{Result, Tensor};
+
+pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    let b_sz = target.dim(0)?;
+    inp.gather(target, 1)?
+        .sum_all()?
+        .affine(-1f64 / b_sz as f64, 0.)
+}
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
new file mode 100644
index 00000000..88196aa7
--- /dev/null
+++ b/candle-nn/src/ops.rs
@@ -0,0 +1,10 @@
+use candle::{Result, Tensor};
+
+pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let sum_exp = diff.exp()?.sum_keepdim(d)?;
+    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
+    Ok(log_sm)
+}
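
After this move, downstream code reaches the helpers through candle_nn::ops and candle_nn::loss rather than local copies, as the updated example does. Below is a minimal usage sketch, assuming the candle and candle-nn crates as of this commit; the device, tensor values, shapes, and the trailing unit dimension on the label tensor (to line up with the gather along dim 1 inside nll) are illustrative assumptions, not taken from the patch.

use candle::{Device, Tensor, D};
use candle_nn::{loss, ops};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Made-up logits for a batch of 2 samples over 3 classes (assumed shapes).
    let logits = Tensor::new(&[[0.5f32, 1.0, -0.5], [0.0, 0.2, 0.3]], &dev)?;
    // Class indices as u32, one per sample, with a trailing unit dim so that
    // nll's gather along dim 1 sees an index tensor of matching rank (assumption).
    let targets = Tensor::new(&[[2u32], [1u32]], &dev)?;
    // Log-softmax over the class dimension, then the averaged negative log-likelihood.
    let log_sm = ops::log_softmax(&logits, D::Minus1)?;
    let loss_value = loss::nll(&log_sm, &targets)?;
    println!("nll: {loss_value}");
    Ok(())
}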