diff --git a/.github/workflows/book.yml b/.github/workflows/book.yml
index 895a68db..bb4d0494 100644
--- a/.github/workflows/book.yml
+++ b/.github/workflows/book.yml
@@ -24,6 +24,6 @@ jobs:
           curl -sSL $url | tar -xz --directory=bin
           echo "$(pwd)/bin" >> $GITHUB_PATH
       - name: Run tests
-        run: cd candle-book && mdbook test
+        run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
diff --git a/README.md b/README.md
index 5f39d1fc..b6a30c17 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,8 @@ trunk serve --release --public-url /candle-llama2/ --port 8081
 And then browse to
 [http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
 
+<!--- ANCHOR: features --->
+
 ## Features
 
 - Simple syntax, looks and feels like PyTorch.
@@ -60,8 +62,11 @@ And then browse to
 - Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
 
+<!--- ANCHOR_END: features --->
+
 ## How to use ?
 
+<!--- ANCHOR: cheatsheet --->
 Cheatsheet:
 
 | | Using PyTorch | Using Candle |
@@ -76,6 +81,8 @@ Cheatsheet:
 | Saving | `torch.save({"A": A}, "model.bin")` | `tensor.save_safetensors("A", "model.safetensors")?` |
 | Loading | `weights = torch.load("model.bin")` | TODO (see the examples for now) |
 
+<!--- ANCHOR_END: cheatsheet --->
+
 ## Structure
diff --git a/candle-book/src/README.md b/candle-book/src/README.md
index e10b99d0..be352dc1 100644
--- a/candle-book/src/README.md
+++ b/candle-book/src/README.md
@@ -1 +1,6 @@
 # Introduction
+
+{{#include ../../README.md:features}}
+
+
+This book will introduce, step by step, how to use `candle`.
diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md
index 24e2b25a..ddd6e916 100644
--- a/candle-book/src/SUMMARY.md
+++ b/candle-book/src/SUMMARY.md
@@ -6,13 +6,13 @@
 
 - [Installation](guide/installation.md)
 - [Hello World - MNIST](guide/hello_world.md)
-- [PyTorch cheatsheet](guide/hello_world.md)
+- [PyTorch cheatsheet](guide/cheatsheet.md)
 
 # Reference Guide
 
 - [Running a model](inference/README.md)
-    - [Serialization](inference/serialization.md)
     - [Using the hub](inference/hub.md)
+    - [Serialization](inference/serialization.md)
     - [Advanced Cuda usage](inference/cuda/README.md)
         - [Writing a custom kernel](inference/cuda/writing.md)
         - [Porting a custom kernel](inference/cuda/porting.md)
@@ -24,3 +24,4 @@
 - [Training](training/README.md)
     - [MNIST](training/mnist.md)
     - [Fine-tuning](training/finetuning.md)
+- [Using MKL](advanced/mkl.md)
diff --git a/candle-book/src/advanced/mkl.md b/candle-book/src/advanced/mkl.md
new file mode 100644
index 00000000..f4dfa8ae
--- /dev/null
+++ b/candle-book/src/advanced/mkl.md
@@ -0,0 +1 @@
+# Using MKL
diff --git a/candle-book/src/guide/cheatsheet.md b/candle-book/src/guide/cheatsheet.md
new file mode 100644
index 00000000..d0893ee0
--- /dev/null
+++ b/candle-book/src/guide/cheatsheet.md
@@ -0,0 +1,3 @@
+# PyTorch cheatsheet
+
+{{#include ../../../README.md:cheatsheet}}
diff --git a/candle-book/src/guide/hello_world.md b/candle-book/src/guide/hello_world.md
index c370cdd3..b1d24d85 100644
--- a/candle-book/src/guide/hello_world.md
+++ b/candle-book/src/guide/hello_world.md
@@ -1 +1,195 @@
-# PyTorch cheatsheet
+# Hello world!
+
+We will now create the hello world of the ML world: a model capable of classifying digits from the MNIST dataset.
+
+Open `src/main.rs` and fill in this content:
+
+```rust
+# extern crate candle;
+use candle::{DType, Device, Result, Tensor};
+
+struct Model {
+    first: Tensor,
+    second: Tensor,
+}
+
+impl Model {
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+        let x = image.matmul(&self.first)?;
+        let x = x.relu()?;
+        x.matmul(&self.second)
+    }
+}
+
+fn main() -> Result<()> {
+    // Use Device::new_cuda(0)?; to use the GPU.
+    let device = Device::Cpu;
+
+    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
+    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
+    let model = Model { first, second };
+
+    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+
+    let digit = model.forward(&dummy_image)?;
+    println!("Digit {digit:?} digit");
+    Ok(())
+}
+```
+
+Everything should now run with:
+
+```bash
+cargo run --release
+```
+
+## Using a `Linear` layer.
+
+Now that we have this, we might want to make things a bit more complex, for instance by adding a `bias` and creating
+the classical `Linear` layer. We can do so as follows:
+
+```rust
+# extern crate candle;
+# use candle::{DType, Device, Result, Tensor};
+struct Linear {
+    weight: Tensor,
+    bias: Tensor,
+}
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.matmul(&self.weight)?;
+        x.broadcast_add(&self.bias)
+    }
+}
+
+struct Model {
+    first: Linear,
+    second: Linear,
+}
+
+impl Model {
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+        let x = self.first.forward(image)?;
+        let x = x.relu()?;
+        self.second.forward(&x)
+    }
+}
+```
+
+This will change the model-running code into a new `main` function:
+
+```rust
+# extern crate candle;
+# use candle::{DType, Device, Result, Tensor};
+# struct Linear {
+#     weight: Tensor,
+#     bias: Tensor,
+# }
+# impl Linear {
+#     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+#         let x = x.matmul(&self.weight)?;
+#         x.broadcast_add(&self.bias)
+#     }
+# }
+#
+# struct Model {
+#     first: Linear,
+#     second: Linear,
+# }
+#
+# impl Model {
+#     fn forward(&self, image: &Tensor) -> Result<Tensor> {
+#         let x = self.first.forward(image)?;
+#         let x = x.relu()?;
+#         self.second.forward(&x)
+#     }
+# }
+fn main() -> Result<()> {
+    // Use Device::new_cuda(0)?; to use the GPU.
+    // Use Device::Cpu; to use the CPU.
+    let device = Device::cuda_if_available(0)?;
+
+    // Creating a dummy model
+    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
+    let bias = Tensor::zeros((100,), DType::F32, &device)?;
+    let first = Linear { weight, bias };
+    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
+    let bias = Tensor::zeros((10,), DType::F32, &device)?;
+    let second = Linear { weight, bias };
+    let model = Model { first, second };
+
+    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+
+    // Inference on the model
+    let digit = model.forward(&dummy_image)?;
+    println!("Digit {digit:?} digit");
+    Ok(())
+}
+```
+
+Now that it works, this is a great way to create your own layers.
+But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
+
+## Using `candle_nn`.
+
+For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
+This `Linear` is coded with the PyTorch layout in mind, to better reuse existing models out there, so it uses the transpose of the weights and not the weights directly.
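+
+Concretely, with the PyTorch layout the weight is stored as `(out_features, in_features)`, so a forward pass
+roughly amounts to the sketch below (an illustrative helper, not the actual `candle_nn` code):
+
+```rust
+# extern crate candle;
+# use candle::{Result, Tensor};
+# fn forward_with_pytorch_layout(x: &Tensor, weight: &Tensor, bias: &Tensor) -> Result<Tensor> {
+// Illustrative sketch: `weight` has shape (out_features, in_features),
+// so we multiply by its transpose before adding the bias.
+let x = x.matmul(&weight.t()?)?;
+x.broadcast_add(bias)
+# }
+```
+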
+So instead, we can simplify our example:
+
+```bash
+cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
+```
+
+And rewrite our example using it:
+
+```rust
+# extern crate candle;
+# extern crate candle_nn;
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::Linear;
+
+struct Model {
+    first: Linear,
+    second: Linear,
+}
+
+impl Model {
+    fn forward(&self, image: &Tensor) -> Result<Tensor> {
+        let x = self.first.forward(image)?;
+        let x = x.relu()?;
+        self.second.forward(&x)
+    }
+}
+
+fn main() -> Result<()> {
+    // Use Device::new_cuda(0)?; to use the GPU.
+    let device = Device::Cpu;
+
+    // This has changed (784, 100) -> (100, 784)!
+    let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
+    let bias = Tensor::zeros((100,), DType::F32, &device)?;
+    let first = Linear::new(weight, Some(bias));
+    let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
+    let bias = Tensor::zeros((10,), DType::F32, &device)?;
+    let second = Linear::new(weight, Some(bias));
+    let model = Model { first, second };
+
+    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
+
+    let digit = model.forward(&dummy_image)?;
+    println!("Digit {digit:?} digit");
+    Ok(())
+}
+```
+
+Feel free to modify this example to use `Conv2d` to create a classical convnet instead.
+
+
+Now that we have running dummy code, we can get to more advanced topics:
+
+- [For PyTorch users](./guide/cheatsheet.md)
+- [Running existing models](./inference/README.md)
+- [Training models](./training/README.md)
+
+
diff --git a/candle-book/src/guide/installation.md b/candle-book/src/guide/installation.md
index 25267fe2..c909a5df 100644
--- a/candle-book/src/guide/installation.md
+++ b/candle-book/src/guide/installation.md
@@ -1 +1,24 @@
 # Installation
+
+Start by creating a new app:
+
+```bash
+cargo new myapp
+cd myapp
+cargo add --git https://github.com/LaurentMazare/candle.git candle
+```
+
+At this point, candle will be built **without** CUDA support.
+To get CUDA support, use the `cuda` feature:
+```bash
+cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
+```
+
+You can check that everything works properly:
+
+```bash
+cargo build
+```
+
+
+You can also have a look at the `mkl` feature, which may provide faster inference on CPU: see [Using MKL](./advanced/mkl.md).
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index b016ead5..a5e21aad 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -41,6 +41,12 @@ impl From<usize> for Shape {
     }
 }
 
+impl From<(usize,)> for Shape {
+    fn from(d1: (usize,)) -> Self {
+        Self(vec![d1.0])
+    }
+}
+
 impl From<(usize, usize)> for Shape {
     fn from(d12: (usize, usize)) -> Self {
         Self(vec![d12.0, d12.1])
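
The `From<(usize,)>` implementation added above is what lets a 1-tuple be used wherever an `Into<Shape>` is expected, matching the 1-dimensional bias tensors created in the book examples. A minimal usage sketch, assuming the `candle` crate from this repository:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // `(100,)` now converts into a 1-dimensional `Shape`, so a bias vector can be
    // created the same way as the 2-dimensional weight tensors.
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    println!("{:?}", bias.shape());
    Ok(())
}
```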