mirror of https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00

Add a convnet training example. (#661)

* Add a convnet example.
* Dataset fix.
* Randomize batches.
@@ -101,10 +101,10 @@ pub fn load() -> Result<crate::vision::Dataset> {
     );
     let repo = api.repo(repo);
     let test_parquet_filename = repo
-        .get("mnist/mnist-test.parquet")
+        .get("mnist/test/0000.parquet")
         .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
     let train_parquet_filename = repo
-        .get("mnist/mnist-train.parquet")
+        .get("mnist/train/0000.parquet")
         .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
     let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
         .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
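
For context on what these parquet files hold: each row in the HF hub mnist
dataset is an encoded image plus an integer label. The decoding step is not
changed by this commit; a minimal sketch of it (not part of this diff; the
helper name load_parquet and the exact row/field nesting are assumptions)
looks roughly like this:

    // Sketch only: decode one MNIST parquet split into flat u8 buffers.
    // Assumes each row is { image: { bytes: PNG }, label: int64 }, the usual
    // HF hub layout for mnist; field nesting and iterator API are hedged.
    use parquet::file::reader::{FileReader, SerializedFileReader};
    use parquet::record::Field;

    fn load_parquet(
        parquet: SerializedFileReader<std::fs::File>,
    ) -> anyhow::Result<(Vec<u8>, Vec<u8>)> {
        let samples = parquet.metadata().file_metadata().num_rows() as usize;
        let mut images: Vec<u8> = Vec::with_capacity(samples * 784);
        let mut labels: Vec<u8> = Vec::with_capacity(samples);
        for row in parquet.into_iter().flatten() {
            for (_name, field) in row.get_column_iter() {
                match field {
                    // The image is stored as encoded bytes inside a group.
                    Field::Group(subrow) => {
                        for (_name, field) in subrow.get_column_iter() {
                            if let Field::Bytes(value) = field {
                                let img = image::load_from_memory(value.data())?;
                                images.extend(img.to_luma8().as_raw());
                            }
                        }
                    }
                    Field::Long(label) => labels.push(*label as u8),
                    _ => {}
                }
            }
        }
        Ok((images, labels))
    }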
@@ -6,9 +6,10 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;

 use clap::{Parser, ValueEnum};
+use rand::prelude::*;

 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, VarBuilder, VarMap};

 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -58,6 +59,40 @@ impl Model for Mlp {
     }
 }

+#[derive(Debug)]
+struct ConvNet {
+    conv1: Conv2d,
+    conv2: Conv2d,
+    fc1: Linear,
+    fc2: Linear,
+}
+
+impl Model for ConvNet {
+    fn new(vs: VarBuilder) -> Result<Self> {
+        let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
+        let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
+        let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
+        let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
+        Ok(Self {
+            conv1,
+            conv2,
+            fc1,
+            fc2,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_sz, _img_dim) = xs.dims2()?;
+        let xs = xs.reshape((b_sz, 1, 28, 28))?;
+        let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
+        let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
+        let xs = xs.flatten_from(1)?;
+        let xs = self.fc1.forward(&xs)?;
+        let xs = xs.relu()?;
+        self.fc2.forward(&xs)
+    }
+}
+
 struct TrainingArgs {
     learning_rate: f64,
     load: Option<String>,
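
A quick shape check on the numbers hard-coded above (a sanity sketch, not
part of the commit): with stride-1, padding-0 convolutions and 2x2/stride-2
max-pools, each 5x5 conv trims 4 pixels per side and each pool halves the
side, so 28 -> 24 -> 12 -> 8 -> 4, which makes the flattened input to fc1
exactly 64 channels * 4 * 4 = 1024.

    // Sanity sketch: verify fc1's hard-coded input width of 1024.
    fn conv_out(side: usize, k: usize) -> usize {
        side - k + 1 // valid convolution, stride 1, no padding
    }

    fn main() {
        let side = conv_out(conv_out(28, 5) / 2, 5) / 2; // 28 -> 24 -> 12 -> 8 -> 4
        assert_eq!(64 * side * side, 1024);
    }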
@@ -65,6 +100,71 @@ struct TrainingArgs {
     epochs: usize,
 }

+fn training_loop_cnn(
+    m: candle_datasets::vision::Dataset,
+    args: &TrainingArgs,
+) -> anyhow::Result<()> {
+    const BSIZE: usize = 64;
+
+    let dev = candle::Device::cuda_if_available(0)?;
+
+    let train_labels = m.train_labels;
+    let train_images = m.train_images.to_device(&dev)?;
+    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+
+    let mut varmap = VarMap::new();
+    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
+    let model = ConvNet::new(vs.clone())?;
+
+    if let Some(load) = &args.load {
+        println!("loading weights from {load}");
+        varmap.load(load)?
+    }
+
+    let adamw_params = candle_nn::ParamsAdamW {
+        lr: args.learning_rate,
+        ..Default::default()
+    };
+    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
+    let test_images = m.test_images.to_device(&dev)?;
+    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
+    let n_batches = train_images.dim(0)? / BSIZE;
+    let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
+    for epoch in 1..args.epochs {
+        let mut sum_loss = 0f32;
+        batch_idxs.shuffle(&mut thread_rng());
+        for batch_idx in batch_idxs.iter() {
+            let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
+            let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
+            let logits = model.forward(&train_images)?;
+            let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+            let loss = loss::nll(&log_sm, &train_labels)?;
+            opt.backward_step(&loss)?;
+            sum_loss += loss.to_vec0::<f32>()?;
+        }
+        let avg_loss = sum_loss / n_batches as f32;
+
+        let test_logits = model.forward(&test_images)?;
+        let sum_ok = test_logits
+            .argmax(D::Minus1)?
+            .eq(&test_labels)?
+            .to_dtype(DType::F32)?
+            .sum_all()?
+            .to_scalar::<f32>()?;
+        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
+        println!(
+            "{epoch:4} train loss {:8.5} test acc: {:5.2}%",
+            avg_loss,
+            100. * test_accuracy
+        );
+    }
+    if let Some(save) = &args.save {
+        println!("saving trained weights in {save}");
+        varmap.save(save)?
+    }
+    Ok(())
+}
+
 fn training_loop<M: Model>(
     m: candle_datasets::vision::Dataset,
     args: &TrainingArgs,
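
One note on the loop above: log_softmax followed by loss::nll is the
standard decomposition of cross-entropy on logits, so the two lines could
plausibly be fused (hedged; assumes candle_nn ships such a helper):

    // Assumed one-liner equivalent of the log_softmax + nll pair above:
    let loss = candle_nn::loss::cross_entropy(&logits, &train_labels)?;

Shuffling batch indices rather than the data itself is also a deliberate
choice: train_images stays contiguous on the device and each narrow() call
is a cheap view into the same buffer.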
@@ -118,6 +218,7 @@ fn training_loop<M: Model>(
 enum WhichModel {
     Linear,
     Mlp,
+    Cnn,
 }

 #[derive(Parser)]
@@ -160,6 +261,7 @@ pub fn main() -> anyhow::Result<()> {
     let default_learning_rate = match args.model {
        WhichModel::Linear => 1.,
        WhichModel::Mlp => 0.05,
+       WhichModel::Cnn => 0.001,
     };
     let training_args = TrainingArgs {
         epochs: args.epochs,
@@ -170,5 +272,6 @@ pub fn main() -> anyhow::Result<()> {
     match args.model {
         WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
         WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
+        WhichModel::Cnn => training_loop_cnn(m, &training_args),
     }
 }
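
With this arm in place the CNN can be selected from the command line. An
illustrative invocation (the example name and flag spellings follow the
clap derive and candle's usual example layout, but are not shown in this
diff):

    cargo run --release --example mnist-training -- cnn --epochs 10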