mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
285 lines
8.4 KiB
Rust
285 lines
8.4 KiB
Rust
// With the default arguments (the linear model), this should reach about 91.5% test accuracy.
|
|
#[cfg(feature = "mkl")]
|
|
extern crate intel_mkl_src;
|
|
|
|
#[cfg(feature = "accelerate")]
|
|
extern crate accelerate_src;
|
|
|
|
use clap::{Parser, ValueEnum};
|
|
use rand::prelude::*;
|
|
use rand::rng;
|
|
|
|
use candle::{DType, Result, Tensor, D};
|
|
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
|
|
|
|
/// Number of input features per sample: one per pixel of a 28x28 MNIST image.
const IMAGE_DIM: usize = 28 * 28;
/// Number of output classes, one per digit.
const LABELS: usize = 10;
|
|
|
|
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
|
|
let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
|
let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
|
|
Ok(Linear::new(ws, Some(bs)))
|
|
}
|
|
|
|
trait Model: Sized {
|
|
fn new(vs: VarBuilder) -> Result<Self>;
|
|
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
|
}
|
|
|
|
struct LinearModel {
|
|
linear: Linear,
|
|
}
|
|
|
|
impl Model for LinearModel {
|
|
fn new(vs: VarBuilder) -> Result<Self> {
|
|
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
|
|
Ok(Self { linear })
|
|
}
|
|
|
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
self.linear.forward(xs)
|
|
}
|
|
}
|
|
|
|
struct Mlp {
|
|
ln1: Linear,
|
|
ln2: Linear,
|
|
}
|
|
|
|
impl Model for Mlp {
|
|
fn new(vs: VarBuilder) -> Result<Self> {
|
|
let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
|
|
let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
|
|
Ok(Self { ln1, ln2 })
|
|
}
|
|
|
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
let xs = self.ln1.forward(xs)?;
|
|
let xs = xs.relu()?;
|
|
self.ln2.forward(&xs)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct ConvNet {
|
|
conv1: Conv2d,
|
|
conv2: Conv2d,
|
|
fc1: Linear,
|
|
fc2: Linear,
|
|
dropout: candle_nn::Dropout,
|
|
}
|
|
|
|
impl ConvNet {
|
|
fn new(vs: VarBuilder) -> Result<Self> {
|
|
let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
|
|
let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
|
|
let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
|
|
let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
|
|
let dropout = candle_nn::Dropout::new(0.5);
|
|
Ok(Self {
|
|
conv1,
|
|
conv2,
|
|
fc1,
|
|
fc2,
|
|
dropout,
|
|
})
|
|
}
|
|
|
|
fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
|
|
let (b_sz, _img_dim) = xs.dims2()?;
|
|
let xs = xs
|
|
.reshape((b_sz, 1, 28, 28))?
|
|
.apply(&self.conv1)?
|
|
.max_pool2d(2)?
|
|
.apply(&self.conv2)?
|
|
.max_pool2d(2)?
|
|
.flatten_from(1)?
|
|
.apply(&self.fc1)?
|
|
.relu()?;
|
|
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
|
|
}
|
|
}
|
|
|
|
/// Hyper-parameters and checkpoint paths handed to the training loops.
struct TrainingArgs {
    /// Requested number of training epochs (from `--epochs`).
    epochs: usize,
    /// Step size given to the optimizer.
    learning_rate: f64,
    /// Optional safetensors file to load initial weights from.
    load: Option<String>,
    /// Optional safetensors file to write the trained weights to.
    save: Option<String>,
}
|
|
|
|
fn training_loop_cnn(
|
|
m: candle_datasets::vision::Dataset,
|
|
args: &TrainingArgs,
|
|
) -> anyhow::Result<()> {
|
|
const BSIZE: usize = 64;
|
|
|
|
let dev = candle::Device::cuda_if_available(0)?;
|
|
|
|
let train_labels = m.train_labels;
|
|
let train_images = m.train_images.to_device(&dev)?;
|
|
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
|
|
|
let mut varmap = VarMap::new();
|
|
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
|
let model = ConvNet::new(vs.clone())?;
|
|
|
|
if let Some(load) = &args.load {
|
|
println!("loading weights from {load}");
|
|
varmap.load(load)?
|
|
}
|
|
|
|
let adamw_params = candle_nn::ParamsAdamW {
|
|
lr: args.learning_rate,
|
|
..Default::default()
|
|
};
|
|
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
|
|
let test_images = m.test_images.to_device(&dev)?;
|
|
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
|
let n_batches = train_images.dim(0)? / BSIZE;
|
|
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
|
|
for epoch in 1..args.epochs {
|
|
let mut sum_loss = 0f32;
|
|
batch_idxs.shuffle(&mut rng());
|
|
for batch_idx in batch_idxs.iter() {
|
|
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
|
|
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
|
|
let logits = model.forward(&train_images, true)?;
|
|
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
|
let loss = loss::nll(&log_sm, &train_labels)?;
|
|
opt.backward_step(&loss)?;
|
|
sum_loss += loss.to_vec0::<f32>()?;
|
|
}
|
|
let avg_loss = sum_loss / n_batches as f32;
|
|
|
|
let test_logits = model.forward(&test_images, false)?;
|
|
let sum_ok = test_logits
|
|
.argmax(D::Minus1)?
|
|
.eq(&test_labels)?
|
|
.to_dtype(DType::F32)?
|
|
.sum_all()?
|
|
.to_scalar::<f32>()?;
|
|
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
|
println!(
|
|
"{epoch:4} train loss {:8.5} test acc: {:5.2}%",
|
|
avg_loss,
|
|
100. * test_accuracy
|
|
);
|
|
}
|
|
if let Some(save) = &args.save {
|
|
println!("saving trained weights in {save}");
|
|
varmap.save(save)?
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn training_loop<M: Model>(
|
|
m: candle_datasets::vision::Dataset,
|
|
args: &TrainingArgs,
|
|
) -> anyhow::Result<()> {
|
|
let dev = candle::Device::cuda_if_available(0)?;
|
|
|
|
let train_labels = m.train_labels;
|
|
let train_images = m.train_images.to_device(&dev)?;
|
|
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
|
|
|
let mut varmap = VarMap::new();
|
|
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
|
let model = M::new(vs.clone())?;
|
|
|
|
if let Some(load) = &args.load {
|
|
println!("loading weights from {load}");
|
|
varmap.load(load)?
|
|
}
|
|
|
|
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
|
|
let test_images = m.test_images.to_device(&dev)?;
|
|
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
|
for epoch in 1..args.epochs {
|
|
let logits = model.forward(&train_images)?;
|
|
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
|
let loss = loss::nll(&log_sm, &train_labels)?;
|
|
sgd.backward_step(&loss)?;
|
|
|
|
let test_logits = model.forward(&test_images)?;
|
|
let sum_ok = test_logits
|
|
.argmax(D::Minus1)?
|
|
.eq(&test_labels)?
|
|
.to_dtype(DType::F32)?
|
|
.sum_all()?
|
|
.to_scalar::<f32>()?;
|
|
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
|
println!(
|
|
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
|
loss.to_scalar::<f32>()?,
|
|
100. * test_accuracy
|
|
);
|
|
}
|
|
if let Some(save) = &args.save {
|
|
println!("saving trained weights in {save}");
|
|
varmap.save(save)?
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(ValueEnum, Clone)]
|
|
enum WhichModel {
|
|
Linear,
|
|
Mlp,
|
|
Cnn,
|
|
}
|
|
|
|
#[derive(Parser)]
|
|
struct Args {
|
|
#[clap(value_enum, default_value_t = WhichModel::Linear)]
|
|
model: WhichModel,
|
|
|
|
#[arg(long)]
|
|
learning_rate: Option<f64>,
|
|
|
|
#[arg(long, default_value_t = 200)]
|
|
epochs: usize,
|
|
|
|
/// The file where to save the trained weights, in safetensors format.
|
|
#[arg(long)]
|
|
save: Option<String>,
|
|
|
|
/// The file where to load the trained weights from, in safetensors format.
|
|
#[arg(long)]
|
|
load: Option<String>,
|
|
|
|
/// The directory where to load the dataset from, in ubyte format.
|
|
#[arg(long)]
|
|
local_mnist: Option<String>,
|
|
}
|
|
|
|
pub fn main() -> anyhow::Result<()> {
|
|
let args = Args::parse();
|
|
// Load the dataset
|
|
let m = if let Some(directory) = args.local_mnist {
|
|
candle_datasets::vision::mnist::load_dir(directory)?
|
|
} else {
|
|
candle_datasets::vision::mnist::load()?
|
|
};
|
|
println!("train-images: {:?}", m.train_images.shape());
|
|
println!("train-labels: {:?}", m.train_labels.shape());
|
|
println!("test-images: {:?}", m.test_images.shape());
|
|
println!("test-labels: {:?}", m.test_labels.shape());
|
|
|
|
let default_learning_rate = match args.model {
|
|
WhichModel::Linear => 1.,
|
|
WhichModel::Mlp => 0.05,
|
|
WhichModel::Cnn => 0.001,
|
|
};
|
|
let training_args = TrainingArgs {
|
|
epochs: args.epochs,
|
|
learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
|
|
load: args.load,
|
|
save: args.save,
|
|
};
|
|
match args.model {
|
|
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
|
|
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
|
|
WhichModel::Cnn => training_loop_cnn(m, &training_args),
|
|
}
|
|
}
|