mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Merge pull request #258 from LaurentMazare/start_book
Starting the book.
.github/workflows/book.yml (vendored, 2 lines changed)
@@ -24,6 +24,6 @@ jobs:
       curl -sSL $url | tar -xz --directory=bin
       echo "$(pwd)/bin" >> $GITHUB_PATH
     - name: Run tests
-      run: cd candle-book && mdbook test
+      run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
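The added `cargo build` plus the `-L` flag are what let `mdbook test` compile the book's Rust snippets: rustdoc needs the already-built candle artifacts on its library path. A local equivalent, sketched under the assumption that it runs from the repository root:

```bash
# Build the book crate (and candle) so the doctests can link against them.
cd candle-book
cargo build
# Point mdbook's rustdoc invocations at the compiled dependencies.
mdbook test -L ../target/debug/deps/
```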
README.md
@@ -48,6 +48,8 @@ trunk serve --release --public-url /candle-llama2/ --port 8081

 And then browse to
 [http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).

+<!--- ANCHOR: features --->
+
 ## Features

 - Simple syntax, looks and feels like PyTorch.
@@ -60,8 +62,11 @@ And then browse to
 - Embed user-defined ops/kernels, such as [flash-attention
   v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).

+<!--- ANCHOR_END: features --->

 ## How to use?

+<!--- ANCHOR: cheatsheet --->
 Cheatsheet:

 | | Using PyTorch | Using Candle |
@@ -76,6 +81,8 @@ Cheatsheet:
 | Saving | `torch.save({"A": A}, "model.bin")` | `tensor.save_safetensors("A", "model.safetensors")?` |
 | Loading | `weights = torch.load("model.bin")` | TODO (see the examples for now) |

+<!--- ANCHOR_END: cheatsheet --->

 ## Structure
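To make the Saving row concrete, here is a minimal sketch of the candle side of that cheatsheet entry, using the in-repo `candle` crate:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // Stores the tensor under the key "A" in a safetensors file.
    a.save_safetensors("A", "model.safetensors")?;
    Ok(())
}
```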
candle-book/src/README.md
@@ -1 +1,6 @@
 # Introduction
+
+{{#include ../../README.md:features}}
+
+This book will introduce, step by step, how to use `candle`.
candle-book/src/SUMMARY.md
@@ -6,13 +6,13 @@

 - [Installation](guide/installation.md)
 - [Hello World - MNIST](guide/hello_world.md)
-- [PyTorch cheatsheet](guide/hello_world.md)
+- [PyTorch cheatsheet](guide/cheatsheet.md)

 # Reference Guide

 - [Running a model](inference/README.md)
-  - [Serialization](inference/serialization.md)
   - [Using the hub](inference/hub.md)
+  - [Serialization](inference/serialization.md)
   - [Advanced Cuda usage](inference/cuda/README.md)
     - [Writing a custom kernel](inference/cuda/writing.md)
     - [Porting a custom kernel](inference/cuda/porting.md)
@@ -24,3 +24,4 @@
 - [Training](training/README.md)
   - [MNIST](training/mnist.md)
   - [Fine-tuning](training/finetuning.md)
+- [Using MKL](advanced/mkl.md)
candle-book/src/advanced/mkl.md (new file, 1 line)
@@ -0,0 +1 @@
+# Using MKL
candle-book/src/guide/cheatsheet.md (new file, 3 lines)
@@ -0,0 +1,3 @@
+# PyTorch cheatsheet
+
+{{#include ../../../README.md:cheatsheet}}
candle-book/src/guide/hello_world.md
@@ -1 +1,195 @@
-# PyTorch cheatsheet
+# Hello world!

We will now create the hello world of the ML world, a model capable of solving the MNIST dataset.

Open `src/main.rs` and fill in this content:
```rust
# extern crate candle;
use candle::{DType, Device, Result, Tensor};

struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?;
        let x = x.relu()?;
        x.matmul(&self.second)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    let device = Device::Cpu;

    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
    let model = Model { first, second };

    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Everything should now run with:

```bash
cargo run --release
```
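Since both weight matrices are all zeros here, the printed tensor will be all zeros too; real weights come later. As a quick sanity check of the shape bookkeeping, here is a sketch that collapses the two layers into a single `(784, 10)` projection:

```rust
# extern crate candle;
# use candle::{DType, Device, Result, Tensor};
# fn main() -> Result<()> {
# let device = Device::Cpu;
let digit = Tensor::zeros((1, 784), DType::F32, &device)?
    .matmul(&Tensor::zeros((784, 10), DType::F32, &device)?)?;
// One image in, ten class scores out.
assert_eq!(digit.shape().dims(), &[1, 10]);
# Ok(())
# }
```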

## Using a `Linear` layer.

Now that we have this, we might want to make things a bit more complex, for instance by adding a `bias` and creating the classical `Linear` layer. We can do so as follows:
```rust
# extern crate candle;
# use candle::{DType, Device, Result, Tensor};
struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?;
        x.broadcast_add(&self.bias)
    }
}

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}
```
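The `broadcast_add` matters here: `x` has shape `(batch, out_features)` while the bias is `(out_features,)`, so the bias has to be broadcast across the batch dimension. A small illustration:

```rust
# extern crate candle;
# use candle::{DType, Device, Result, Tensor};
# fn main() -> Result<()> {
# let device = Device::Cpu;
let x = Tensor::zeros((1, 100), DType::F32, &device)?;  // (batch, out)
let bias = Tensor::zeros((100,), DType::F32, &device)?; // (out,)
// The bias row is repeated over the batch dimension.
let y = x.broadcast_add(&bias)?;
assert_eq!(y.shape().dims(), &[1, 100]);
# Ok(())
# }
```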

This changes the model-running code into a new function:
```rust
# extern crate candle;
# use candle::{DType, Device, Result, Tensor};
# struct Linear {
#     weight: Tensor,
#     bias: Tensor,
# }
# impl Linear {
#     fn forward(&self, x: &Tensor) -> Result<Tensor> {
#         let x = x.matmul(&self.weight)?;
#         x.broadcast_add(&self.bias)
#     }
# }
#
# struct Model {
#     first: Linear,
#     second: Linear,
# }
#
# impl Model {
#     fn forward(&self, image: &Tensor) -> Result<Tensor> {
#         let x = self.first.forward(image)?;
#         let x = x.relu()?;
#         self.second.forward(&x)
#     }
# }
fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    // Use Device::Cpu; to use the CPU.
    let device = Device::cuda_if_available(0)?;

    // Creating a dummy model
    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    let first = Linear { weight, bias };
    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
    let bias = Tensor::zeros((10,), DType::F32, &device)?;
    let second = Linear { weight, bias };
    let model = Model { first, second };

    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    // Inference on the model
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Now that it works, this is a great way to create your own layers.
But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).

## Using `candle_nn`.

For instance, [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
This `Linear` is coded with the PyTorch layout in mind, so that existing models can be reused more easily, which means it uses the transpose of the weights and not the weights directly.
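Concretely, PyTorch's `nn.Linear` stores its weight with shape `(out_features, in_features)` and multiplies by the transpose, whereas the hand-rolled `Linear` above stored `(in_features, out_features)`. A sketch of the PyTorch-layout version:

```rust
# extern crate candle;
# use candle::{DType, Device, Result, Tensor};
# fn main() -> Result<()> {
# let device = Device::Cpu;
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
// PyTorch layout: weight is (out, in), so we multiply by its transpose.
let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
let y = x.matmul(&weight.t()?)?;
assert_eq!(y.shape().dims(), &[1, 100]);
# Ok(())
# }
```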
So instead we can simplify our example:

```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
```

And rewrite our example using it:
```rust
# extern crate candle;
# extern crate candle_nn;
use candle::{DType, Device, Result, Tensor};
use candle_nn::Linear;

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    let device = Device::Cpu;

    // This has changed (784, 100) -> (100, 784)!
    let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    let first = Linear::new(weight, Some(bias));
    let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
    let bias = Tensor::zeros((10,), DType::F32, &device)?;
    let second = Linear::new(weight, Some(bias));
    let model = Model { first, second };

    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```
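One detail worth noting: `Linear::new` takes the bias as an `Option<Tensor>`, so a bias-free layer can presumably be built by passing `None`:

```rust
# extern crate candle;
# extern crate candle_nn;
# use candle::{DType, Device, Result, Tensor};
# use candle_nn::Linear;
# fn main() -> Result<()> {
# let device = Device::Cpu;
let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
// No bias term: pass None instead of Some(bias).
let _layer = Linear::new(weight, None);
# Ok(())
# }
```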

Feel free to modify this example to use `Conv2d` to create a classical convnet instead.

Now that we have the dummy code running, we can move on to more advanced topics:

- [For PyTorch users](./cheatsheet.md)
- [Running existing models](../inference/README.md)
- [Training models](../training/README.md)
candle-book/src/guide/installation.md
@@ -1 +1,24 @@
# Installation

Start by creating a new app:

```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/LaurentMazare/candle.git candle
```

At this point, candle will be built **without** CUDA support.
To get CUDA support, use the `cuda` feature:

```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
```

You can check that everything works properly:

```bash
cargo build
```

You can also look into the `mkl` feature, which can provide faster inference on CPU: [Using MKL](../advanced/mkl.md).
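Enabling it presumably mirrors the `cuda` feature above; a sketch, assuming the feature is simply named `mkl`:

```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle --features mkl
```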
candle-core/src/shape.rs
@@ -41,6 +41,12 @@ impl From<usize> for Shape {
     }
 }

+impl From<(usize,)> for Shape {
+    fn from(d1: (usize,)) -> Self {
+        Self(vec![d1.0])
+    }
+}
+
 impl From<(usize, usize)> for Shape {
     fn from(d12: (usize, usize)) -> Self {
         Self(vec![d12.0, d12.1])
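This new impl is what lets a 1-tuple be used as a shape, which the book's bias creation above relies on; a minimal sketch:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // (100,) now converts into Shape(vec![100]) via the new From<(usize,)> impl.
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    assert_eq!(bias.shape().dims(), &[100]);
    Ok(())
}
```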