Mirror of https://github.com/huggingface/candle.git, synced 2025-06-14 01:48:08 +00:00
Updated candle-book: Introduction, Installation, MNIST guide, and added CONTRIBUTING.md (#2897)
* Added CONTRIBUTING.md to candle-book
* Added description to candle-book introduction
* Updated formatting and added different features to candle-book installation
* MNIST guide first draft for candle-book
* Updated MNIST guide syntax and grammar for candle-book
* Changed "HelloWorld - Mnist" to "Tutorial - Mnist" in SUMMARY.md
* Updated intro to MNIST guide in candle-book
@@ -290,6 +290,8 @@ Cheatsheet:
### Why should I use Candle?

<!--- ANCHOR: goals --->

Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
binaries.
@@ -299,6 +301,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).

<!--- ANCHOR_END: goals --->
### Other ML frameworks
13 candle-book/CONTRIBUTING.md Normal file
@@ -0,0 +1,13 @@
# Candle Book

The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.

## Installation

To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).

## Viewing the book

To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).
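In short, assuming the commands are run from the repository root:

```bash
cargo install mdbook
mdbook serve --open candle-book
```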
The book is built automatically in GitHub CI.
@@ -1,6 +1,7 @@
# Introduction

{{#include ../../README.md:goals}}

{{#include ../../README.md:features}}

This book will introduce step by step how to use `candle`.
@@ -5,7 +5,10 @@
# User Guide

- [Installation](guide/installation.md)
- [Hello World - MNIST](guide/hello_world.md)
- [Tutorial - MNIST](guide/mnist/intro.md)
  - [Modeling](guide/mnist/modeling.md)
  - [Training](guide/mnist/training.md)
  - [Saving And Loading](guide/mnist/saving_loading.md)
- [PyTorch cheatsheet](guide/cheatsheet.md)

# Reference Guide
@@ -1,8 +1,23 @@
# Installation

**With Cuda support**:

## 1. Create a new Rust app or library

1. First, make sure that Cuda is correctly installed.

```bash
cargo new myapp
cd myapp
```

## 2. Add the correct candle version

### Standard

```bash
cargo add --git https://github.com/huggingface/candle.git candle-core
```
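This adds a git dependency to your `Cargo.toml`, roughly along these lines (illustrative; the exact entry depends on your cargo version):

```toml
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
```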
### CUDA

First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPU's compute capability, e.g. something like:
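```bash
$ nvidia-smi --query-gpu=compute_cap --format=csv

> compute_cap
> 8.9
```

(The value shown is illustrative; it depends on your GPU model.)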
@@ -17,43 +32,36 @@ You can also compile the Cuda kernels for a specific compute cap using the
If any of the above commands errors out, please make sure to update your Cuda version.

2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.

Start by creating a new cargo project:

```bash
cargo new myapp
cd myapp
```

Add the `candle-core` crate with the cuda feature:

```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```

### MKL

You can also use the `mkl` feature, which can speed up inference on CPU.

Add the `candle-core` crate with the mkl feature:

```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
```

### Metal

Metal is exclusive to macOS.

Add the `candle-core` crate with the metal feature:

```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
```

## 3. Building

Run `cargo build` to make sure everything builds correctly.

```bash
cargo build
```

**Without Cuda support**:

Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:

```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```

Finally, run `cargo build` to make sure everything builds correctly.

```bash
cargo build
```

**With mkl support**

You can also use the `mkl` feature, which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
17 candle-book/src/guide/mnist/intro.md Normal file
@@ -0,0 +1,17 @@
# Candle MNIST Tutorial

## Introduction

This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.

Throughout this tutorial, you will learn the basics of:

- Tensor operations and model construction
- Creating and implementing neural network layers
- Parameter initialization
- Training loop implementation
- Saving and loading trained models

## Getting Started

Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.
172 candle-book/src/guide/mnist/modeling.md Normal file
@@ -0,0 +1,172 @@
# Candle MNIST Tutorial

## Modeling

Open `src/main.rs` in your project folder and insert the following code:

```rust
use candle_core::{Device, Result, Tensor};

struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?;
        let x = x.relu()?;
        x.matmul(&self.second)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to utilize GPU acceleration.
    let device = Device::Cpu;

    let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Execute the program with:

```bash
$ cargo run --release

> Digit Tensor[dims 1, 10; f32] digit
```

Since random inputs are provided, expect an incoherent output.

## Implementing a `Linear` Layer

To create a more sophisticated layer type, pair the weight with a `bias` to construct the standard `Linear` layer.

Replace the entire content of `src/main.rs` with:

```rust
use candle_core::{Device, Result, Tensor};

struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?;
        x.broadcast_add(&self.bias)
    }
}

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; for GPU acceleration.
    // Use Device::Cpu; for CPU computation.
    let device = Device::cuda_if_available(0)?;

    // Initialize model parameters
    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100,), &device)?;
    let first = Linear { weight, bias };
    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (10,), &device)?;
    let second = Linear { weight, bias };
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    // Perform inference
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Execute again with:

```bash
$ cargo run --release

> Digit Tensor[dims 1, 10; f32] digit
```

## Utilizing `candle_nn`

Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).

This `Linear` implementation follows PyTorch conventions for improved compatibility with existing models, utilizing the transpose of the weight matrix rather than the weight directly.
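As a minimal sketch of that convention (illustrative, not part of the tutorial code): a weight stored as `(out_dim, in_dim)` is applied through its transpose, which gives the same result as our earlier manual layer with a `(784, 100)` weight.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // PyTorch-style layout: the weight is (out_dim, in_dim) = (100, 784).
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let x = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    // Multiplying by the transpose reproduces x.matmul(w) for a (784, 100) w.
    let y = x.matmul(&weight.t()?)?;
    println!("{:?}", y.shape()); // expected: (1, 100)
    Ok(())
}
```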
Let's simplify our implementation. First, add `candle-nn` as a dependency:

```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-nn
```

Now, replace the entire content of `src/main.rs` with:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; for GPU acceleration.
    let device = Device::Cpu;

    // Note the dimension change: (784, 100) -> (100, 784)
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100,), &device)?;
    let first = Linear::new(weight, Some(bias));
    let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (10,), &device)?;
    let second = Linear::new(weight, Some(bias));
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Execute the final version:

```bash
$ cargo run --release

> Digit Tensor[dims 1, 10; f32] digit
```
158 candle-book/src/guide/mnist/saving_loading.md Normal file
@@ -0,0 +1,158 @@
# Candle MNIST Tutorial

## Saving and Loading Models

After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.
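As a minimal sketch of just the persistence API (detached from any model; the file name is arbitrary):

```rust
use candle_nn::VarMap;

fn main() -> candle_core::Result<()> {
    // In the tutorial, the VarMap is populated through a VarBuilder.
    let mut varmap = VarMap::new();
    // Write every registered variable to a safetensors file...
    varmap.save("model_weights.safetensors")?;
    // ...and read them back; `load` needs a mutable binding because it
    // overwrites the stored tensors in place.
    varmap.load("model_weights.safetensors")?;
    Ok(())
}
```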
### Saving Model Parameters

Let's modify our `training_loop` function to include functionality for saving weights:

```rust
fn training_loop(
    m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    // Initialize a VarMap for trainable parameters
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Model::new(vs.clone())?;

    let learning_rate = 0.05;
    let epochs = 10;

    // Initialize stochastic gradient descent optimizer
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    for epoch in 1..epochs {
        // Standard MNIST forward pass
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;

        // Compute Negative Log Likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;

        // Perform backward pass and update weights
        sgd.backward_step(&loss)?;

        // Evaluate model on test set
        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            test_accuracy
        );
    }

    // Save model weights to disk
    varmap.save("model_weights.safetensors")?;
    Ok(())
}
```

```bash
$ cargo run --release

> 1 train loss: 2.40485 test acc: 0.11%
> 2 train loss: 2.34161 test acc: 0.14%
> 3 train loss: 2.28841 test acc: 0.17%
> 4 train loss: 2.24158 test acc: 0.19%
> 5 train loss: 2.19898 test acc: 0.23%
> 6 train loss: 2.15927 test acc: 0.26%
> 7 train loss: 2.12161 test acc: 0.29%
> 8 train loss: 2.08549 test acc: 0.32%
> 9 train loss: 2.05053 test acc: 0.35%
```
### Loading Model Parameters

Now that we have saved our model parameters, we can modify the code to load them. The primary change required is to make the `varmap` variable mutable, since `VarMap::load` updates the stored tensors in place:

```rust
fn training_loop(
    m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    // Create a mutable VarMap for trainable parameters
    let mut varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Model::new(vs.clone())?;

    // Load pre-trained weights from file
    varmap.load("model_weights.safetensors")?;

    let learning_rate = 0.05;
    let epochs = 10;

    // Initialize stochastic gradient descent optimizer
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    for epoch in 1..epochs {
        // Standard MNIST forward pass
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;

        // Compute Negative Log Likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;

        // Perform backward pass and update weights
        sgd.backward_step(&loss)?;

        // Evaluate model on test set
        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            test_accuracy
        );
    }

    // Save updated weights back to disk
    varmap.save("model_weights.safetensors")?;
    Ok(())
}
```

```bash
$ cargo run --release

> 1 train loss: 2.01645 test acc: 0.38%
> 2 train loss: 1.98300 test acc: 0.41%
> 3 train loss: 1.95008 test acc: 0.44%
> 4 train loss: 1.91754 test acc: 0.47%
> 5 train loss: 1.88534 test acc: 0.50%
> 6 train loss: 1.85349 test acc: 0.53%
> 7 train loss: 1.82198 test acc: 0.56%
> 8 train loss: 1.79077 test acc: 0.59%
> 9 train loss: 1.75989 test acc: 0.61%
```

Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
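A minimal sketch of such a guard (using the same `varmap` as above; training simply starts from freshly initialized parameters when no checkpoint exists):

```rust
// Only load weights when a checkpoint file is already present on disk.
// Note that `load` can still fail if the stored tensors do not match
// the model's parameter names and shapes.
let weights_file = "model_weights.safetensors";
if std::path::Path::new(weights_file).exists() {
    varmap.load(weights_file)?;
}
```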
134 candle-book/src/guide/mnist/training.md Normal file
@@ -0,0 +1,134 @@
# Candle MNIST Tutorial

## Training Implementation

First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` registers each parameter it creates in a `VarMap`, which is the data structure that stores our trainable parameters.

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder, VarMap};

fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
    let ws = vs.get_with_hints(
        (out_dim, in_dim),
        "weight",
        candle_nn::init::DEFAULT_KAIMING_NORMAL,
    )?;
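    // Uniform bias init bound of 1/sqrt(in_dim), matching the common
    // default used by e.g. PyTorch's nn.Linear.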
    let bound = 1. / (in_dim as f64).sqrt();
    let bs = vs.get_with_hints(
        out_dim,
        "bias",
        candle_nn::Init::Uniform {
            lo: -bound,
            up: bound,
        },
    )?;
    Ok(Linear::new(ws, Some(bs)))
}
```

Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer's parameters as `first.weight` and `first.bias`, and the second layer's as `second.weight` and `second.bias`.

```rust
impl Model {
    fn new(vs: VarBuilder) -> Result<Self> {
        const IMAGE_DIM: usize = 784;
        const HIDDEN_DIM: usize = 100;
        const LABELS: usize = 10;

        let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
        let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;

        Ok(Self { first, second })
    }

    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}
```

Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:

```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
```

With the dataset available, we can implement our training loop:

```rust
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};

fn training_loop(
    m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    // Initialize a VarMap to store trainable parameters
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Model::new(vs.clone())?;

    let learning_rate = 0.05;
    let epochs = 10;

    // Initialize a stochastic gradient descent optimizer to update parameters
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
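    // `1..epochs` is exclusive, so with epochs = 10 this runs 9 iterations,
    // matching the epoch numbers in the output below.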
    for epoch in 1..epochs {
        // Perform forward pass on MNIST data
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;

        // Compute Negative Log Likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;

        // Perform backward pass and update weights
        sgd.backward_step(&loss)?;

        // Evaluate model on test set
        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            test_accuracy
        );
    }
    Ok(())
}
```

Finally, let's implement our main function:

```rust
pub fn main() -> anyhow::Result<()> {
    let m = candle_datasets::vision::mnist::load()?;
    training_loop(m)
}
```

Let's execute the training process:

```bash
$ cargo run --release

> 1 train loss: 2.35449 test acc: 0.12%
> 2 train loss: 2.30760 test acc: 0.15%
> ...
```