mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Compare commits: 9 commits, `0.9.0-alph` ... `metal-prec`
Commits (SHA1): 3b24f8f302, 9954981327, 7f0f83a7c1, 76e565c4ab, e4e7b0b2da, b01ebbad8a, 1d1d6d4fe6, 2653002f29, a52b76ae82
Cargo.toml (26 lines changed)
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"

 [workspace.package]
-version = "0.9.0-alpha.2"
+version = "0.9.0-alpha.4"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
-candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.2" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.2" }
-candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.2" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.2" }
-candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.2" }
-candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.2" }
-candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.2" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
+candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
@@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
-ug = "0.3.1"
-ug-cuda = "0.3.1"
-ug-metal = "0.3.1"
+ug = "0.4.0"
+ug-cuda = "0.4.0"
+ug-metal = "0.4.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "1.1.1", default-features = false }
 metal = { version = "0.27.0", features = ["mps"]}
@@ -290,6 +290,8 @@ Cheatsheet:

 ### Why should I use Candle?

 <!--- ANCHOR: goals --->

 Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
 are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
 binaries.
@@ -299,6 +301,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut

 Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).

 <!--- ANCHOR_END: goals --->

 ### Other ML frameworks
candle-book/CONTRIBUTING.md (13 lines, new file)
@@ -0,0 +1,13 @@
# Candle Book

The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.

## Installation

To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).

## Viewing the book

To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).

The book is built automatically in GitHub CI.
@@ -1,6 +1,7 @@
 # Introduction

 {{#include ../../README.md:goals}}

 {{#include ../../README.md:features}}

 This book will introduce step by step how to use `candle`.
@@ -5,7 +5,10 @@
 # User Guide

 - [Installation](guide/installation.md)
-- [Hello World - MNIST](guide/hello_world.md)
+- [Tutorial - MNIST](guide/mnist/intro.md)
+  - [Modeling](guide/mnist/modeling.md)
+  - [Training](guide/mnist/training.md)
+  - [Saving And Loading](guide/mnist/saving_loading.md)
 - [PyTorch cheatsheet](guide/cheatsheet.md)

 # Reference Guide
@@ -1,8 +1,23 @@
 # Installation

-**With Cuda support**:
+## 1. Create a new rust app or library

-1. First, make sure that Cuda is correctly installed.
+```bash
+cargo new myapp
+cd myapp
+```
+
+## 2. Add the correct candle version
+
+### Standard
+
+```bash
+cargo add --git https://github.com/huggingface/candle.git candle-core
+```
+
+### CUDA
+
+First, make sure that Cuda is correctly installed.
 - `nvcc --version` should print information about your Cuda compiler driver.
 - `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPU's compute capability, e.g. something like:
@@ -17,43 +32,36 @@ You can also compile the Cuda kernels for a specific compute cap using the

 If any of the above commands errors out, please make sure to update your Cuda version.

-2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
-
-Start by creating a new cargo:
-
-```bash
-cargo new myapp
-cd myapp
-```
-
-Make sure to add the `candle-core` crate with the cuda feature:
+Add the `candle-core` crate with the cuda feature:

 ```bash
 cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
 ```

+### MKL
+
+You can also use the `mkl` feature, which can speed up inference on CPU.
+
+Add the `candle-core` crate with the mkl feature:
+
+```bash
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
+```
+
+### Metal
+
+Metal is exclusive to macOS.
+
+Add the `candle-core` crate with the metal feature:
+
+```bash
+cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
+```
+
+## 3. Building
+
+Run `cargo build` to make sure everything can be correctly built.
+
+```bash
+cargo build
+```
-
-**Without Cuda support**:
-
-Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
-
-```bash
-cargo new myapp
-cd myapp
-cargo add --git https://github.com/huggingface/candle.git candle-core
-```
-
-Finally, run `cargo build` to make sure everything can be correctly built.
-
-```bash
-cargo build
-```
-
-**With mkl support**
-
-You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
candle-book/src/guide/mnist/intro.md (17 lines, new file)
@@ -0,0 +1,17 @@
# Candle MNIST Tutorial

## Introduction

This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.

Throughout this tutorial, you will learn the basics of:

- Tensor operations and model construction
- Creating and implementing neural network layers
- Parameter initialization
- Training loop implementation
- Saving and loading trained models

## Getting Started

Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.
candle-book/src/guide/mnist/modeling.md (172 lines, new file)
@@ -0,0 +1,172 @@
# Candle MNIST Tutorial

## Modeling

Open `src/main.rs` in your project folder and insert the following code:

```rust
use candle_core::{Device, Result, Tensor};

struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?;
        let x = x.relu()?;
        x.matmul(&self.second)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to utilize GPU acceleration.
    let device = Device::Cpu;

    let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Execute the program with:

```bash
$ cargo run --release

> Digit Tensor[dims 1, 10; f32] digit
```

Since random inputs are provided, expect an incoherent output.
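
The tutorial prints the raw logits tensor. If you nevertheless want to see which class currently scores highest, one option (not part of the tutorial code, shown here as a hypothetical helper) is to take the argmax over the last dimension:

```rust
use candle_core::{D, Result, Tensor};

// Hypothetical helper: pick the highest-scoring class index
// from (1, 10) logits by argmax over the last dimension.
fn predicted_digit(logits: &Tensor) -> Result<u32> {
    let pred = logits.argmax(D::Minus1)?; // shape (1,), dtype u32
    pred.squeeze(0)?.to_scalar::<u32>()
}
```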

## Implementing a `Linear` Layer

To create a more sophisticated layer type, add a `bias` to the weight to construct the standard `Linear` layer.

Replace the entire content of `src/main.rs` with:

```rust
use candle_core::{Device, Result, Tensor};

struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?;
        x.broadcast_add(&self.bias)
    }
}

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; for GPU acceleration.
    // Use Device::Cpu; for CPU computation.
    let device = Device::cuda_if_available(0)?;

    // Initialize model parameters
    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100,), &device)?;
    let first = Linear { weight, bias };
    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (10,), &device)?;
    let second = Linear { weight, bias };
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    // Perform inference
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Execute again with:

```bash
$ cargo run --release

> Digit Tensor[dims 1, 10; f32] digit
```

## Utilizing `candle_nn`

Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).

This `Linear` implementation follows PyTorch conventions for improved compatibility with existing models, utilizing the transpose of the weights rather than the weights directly.
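
Concretely, the weight is stored as `(out_features, in_features)` and the forward pass multiplies by its transpose. A minimal sketch of that convention (an illustration, not the actual candle-nn source):

```rust
use candle_core::{Result, Tensor};

// PyTorch-style convention: weight has shape (out, in),
// so the forward pass multiplies by the transpose.
fn linear_forward(x: &Tensor, weight: &Tensor, bias: &Tensor) -> Result<Tensor> {
    let x = x.matmul(&weight.t()?)?;
    x.broadcast_add(bias)
}
```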

Let's simplify our implementation. First, add `candle-nn` as a dependency:

```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-nn
```

Now, replace the entire content of `src/main.rs` with:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; for GPU acceleration.
    let device = Device::Cpu;

    // Note the dimension change: (784, 100) -> (100, 784)
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100,), &device)?;
    let first = Linear::new(weight, Some(bias));
    let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (10,), &device)?;
    let second = Linear::new(weight, Some(bias));
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Execute the final version:

```bash
$ cargo run --release

> Digit Tensor[dims 1, 10; f32] digit
```
candle-book/src/guide/mnist/saving_loading.md (158 lines, new file)
@@ -0,0 +1,158 @@
# Candle MNIST Tutorial

## Saving and Loading Models

After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.

### Saving Model Parameters

Let's modify our `training_loop` function to include functionality for saving weights:

```rust
fn training_loop(
    m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    // Initialize a VarMap for trainable parameters
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Model::new(vs.clone())?;

    let learning_rate = 0.05;
    let epochs = 10;

    // Initialize stochastic gradient descent optimizer
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    for epoch in 1..epochs {
        // Standard MNIST forward pass
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;

        // Compute Negative Log Likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;

        // Perform backward pass and update weights
        sgd.backward_step(&loss)?;

        // Evaluate model on test set
        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            test_accuracy
        );
    }

    // Save model weights to disk
    varmap.save("model_weights.safetensors")?;
    Ok(())
}
```

```bash
$ cargo run --release

> 1 train loss: 2.40485 test acc: 0.11%
> 2 train loss: 2.34161 test acc: 0.14%
> 3 train loss: 2.28841 test acc: 0.17%
> 4 train loss: 2.24158 test acc: 0.19%
> 5 train loss: 2.19898 test acc: 0.23%
> 6 train loss: 2.15927 test acc: 0.26%
> 7 train loss: 2.12161 test acc: 0.29%
> 8 train loss: 2.08549 test acc: 0.32%
> 9 train loss: 2.05053 test acc: 0.35%
```

### Loading Model Parameters

Now that we have saved our model parameters, we can modify the code to load them. The primary change required is to make the `varmap` variable mutable:

```rust
fn training_loop(
    m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    // Create a mutable VarMap for trainable parameters
    let mut varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Model::new(vs.clone())?;

    // Load pre-trained weights from file
    varmap.load("model_weights.safetensors")?;

    let learning_rate = 0.05;
    let epochs = 10;

    // Initialize stochastic gradient descent optimizer
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    for epoch in 1..epochs {
        // Standard MNIST forward pass
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;

        // Compute Negative Log Likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;

        // Perform backward pass and update weights
        sgd.backward_step(&loss)?;

        // Evaluate model on test set
        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            test_accuracy
        );
    }

    // Save updated weights back to disk
    varmap.save("model_weights.safetensors")?;
    Ok(())
}
```

```bash
$ cargo run --release

> 1 train loss: 2.01645 test acc: 0.38%
> 2 train loss: 1.98300 test acc: 0.41%
> 3 train loss: 1.95008 test acc: 0.44%
> 4 train loss: 1.91754 test acc: 0.47%
> 5 train loss: 1.88534 test acc: 0.50%
> 6 train loss: 1.85349 test acc: 0.53%
> 7 train loss: 1.82198 test acc: 0.56%
> 8 train loss: 1.79077 test acc: 0.59%
> 9 train loss: 1.75989 test acc: 0.61%
```

Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
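
For example, a minimal guard (not part of the tutorial code) could skip loading when no checkpoint has been written yet:

```rust
// Hypothetical guard around the load call: only restore weights
// when a previous run has already saved a checkpoint.
let weights_path = "model_weights.safetensors";
if std::path::Path::new(weights_path).exists() {
    varmap.load(weights_path)?;
}
```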
candle-book/src/guide/mnist/training.md (134 lines, new file)
@@ -0,0 +1,134 @@
# Candle MNIST Tutorial

## Training Implementation

First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` populates a `VarMap`, which is the data structure that stores our trainable parameters.

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder, VarMap};

fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
    let ws = vs.get_with_hints(
        (out_dim, in_dim),
        "weight",
        candle_nn::init::DEFAULT_KAIMING_NORMAL,
    )?;
    let bound = 1. / (in_dim as f64).sqrt();
    let bs = vs.get_with_hints(
        out_dim,
        "bias",
        candle_nn::Init::Uniform {
            lo: -bound,
            up: bound,
        },
    )?;
    Ok(Linear::new(ws, Some(bs)))
}
```

Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer weights as `first.weight` and `first.bias`, and the second layer weights as `second.weight` and `second.bias`.

```rust
impl Model {
    fn new(vs: VarBuilder) -> Result<Self> {
        const IMAGE_DIM: usize = 784;
        const HIDDEN_DIM: usize = 100;
        const LABELS: usize = 10;

        let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
        let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;

        Ok(Self { first, second })
    }

    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}
```

Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:

```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
```

With the dataset available, we can implement our training loop:

```rust
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};

fn training_loop(
    m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    // Initialize a VarMap to store trainable parameters
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = Model::new(vs.clone())?;

    let learning_rate = 0.05;
    let epochs = 10;

    // Initialize a stochastic gradient descent optimizer to update parameters
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    for epoch in 1..epochs {
        // Perform forward pass on MNIST data
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;

        // Compute Negative Log Likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;

        // Perform backward pass and update weights
        sgd.backward_step(&loss)?;

        // Evaluate model on test set
        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            test_accuracy
        );
    }
    Ok(())
}
```

Finally, let's implement our main function:

```rust
pub fn main() -> anyhow::Result<()> {
    let m = candle_datasets::vision::mnist::load()?;
    training_loop(m)
}
```

Let's execute the training process:

```bash
$ cargo run --release

> 1 train loss: 2.35449 test acc: 0.12%
> 2 train loss: 2.30760 test acc: 0.15%
> ...
```
@@ -6,28 +6,18 @@ extern crate intel_mkl_src;

 use anyhow::Result;
 use candle_core::{Device, Tensor};

+// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
-        .to_dtype(candle_core::DType::BF16)?;
-    candle_core::cuda::set_gemm_reduced_precision_f32(false);
-    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
-    let _x1 = x.matmul(&x)?;
-    drop(_x1);
-    let start_time = std::time::Instant::now();
-    let _x1 = x.matmul(&x)?;
-    device.synchronize()?;
-    println!("fp32: {:?}", start_time.elapsed());
-    drop(_x1);
-    candle_core::cuda::set_gemm_reduced_precision_f32(true);
-    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
-    let _x1 = x.matmul(&x)?;
-    drop(_x1);
-    let start_time = std::time::Instant::now();
-    let _x1 = x.matmul(&x)?;
-    device.synchronize()?;
-    println!("tf32: {:?}", start_time.elapsed());
+    let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
+    let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
+    let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
+    drop(_x1);
+    for _ in 0..20 {
+        let start_time = std::time::Instant::now();
+        let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
+        device.synchronize()?;
+        println!("conv1d: {:?}", start_time.elapsed());
+    }
     Ok(())
 }
@@ -55,7 +55,7 @@ impl ParamsConvTranspose1D {
     }
 }

-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
     ImplicitPrecompGemm,
@@ -152,6 +152,19 @@ impl Tensor {
         stride: usize,
         dilation: usize,
         groups: usize,
+    ) -> Result<Self> {
+        self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
+    }
+
+    /// Applies a 1D convolution over the input tensor.
+    pub fn conv1d_with_algo(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        stride: usize,
+        dilation: usize,
+        groups: usize,
+        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
     ) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.dims3()?;
         let (b_size, c_in, l_in) = self.dims3()?;
@@ -175,7 +188,7 @@ impl Tensor {
             padding,
             stride,
             dilation,
-            cudnn_fwd_algo: None,
+            cudnn_fwd_algo,
         };
         if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -280,6 +293,18 @@ impl Tensor {
         stride: usize,
         dilation: usize,
         groups: usize,
+    ) -> Result<Self> {
+        self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
+    }
+
+    pub fn conv2d_with_algo(
+        &self,
+        kernel: &Self,
+        padding: usize,
+        stride: usize,
+        dilation: usize,
+        groups: usize,
+        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
         let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@@ -299,7 +324,7 @@ impl Tensor {
             padding,
             stride,
             dilation,
-            cudnn_fwd_algo: None,
+            cudnn_fwd_algo,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)
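
For illustration, a hypothetical call to the new API (sketch only; the explicit algorithm hint is honored on the cuDNN path, and `None` falls back to the default selection):

```rust
use candle_core::{conv::CudnnFwdAlgo, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // NCHW input and OIHW kernel, as expected by conv2d.
    let xs = Tensor::randn(0f32, 1.0, (1, 3, 16, 16), &dev)?;
    let k = Tensor::randn(0f32, 1.0, (8, 3, 3, 3), &dev)?;
    // Same arguments as conv2d(.., padding, stride, dilation, groups),
    // plus an optional cuDNN forward-algorithm hint.
    let ys = xs.conv2d_with_algo(&k, 1, 1, 1, 1, Some(CudnnFwdAlgo::ImplicitGemm))?;
    println!("{:?}", ys.shape());
    Ok(())
}
```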
@@ -144,6 +144,24 @@ impl CudaDevice {
         self.stream.clone()
     }

+    /// When turned on, all cuda tensors **created after calling this function** will
+    /// not track uses via cuda events.
+    ///
+    /// # Safety
+    ///
+    /// It is up to the user to ensure proper synchronization between multiple streams:
+    /// - Ensure that no tensor is freed before a use on another stream is finished.
+    /// - Ensure that a tensor is not used on another stream before allocation on the
+    ///   allocating stream finishes.
+    /// - Ensure that a tensor is not written to concurrently by multiple streams.
+    pub unsafe fn disable_event_tracking(&self) {
+        self.context.disable_event_tracking()
+    }
+
+    pub fn is_event_tracking(&self) -> bool {
+        self.context.is_event_tracking()
+    }
+
     #[cfg(not(target_arch = "wasm32"))]
     pub fn compile(
         &self,
@@ -3,7 +3,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
 use crate::scalar::TensorOrScalar;
-use crate::shape::{Dim, Dims};
+use crate::shape::{Dim, Dims, ShapeWithOneHole};
 use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};

@@ -452,17 +452,13 @@ impl Tensor {
         Self::from_vec_impl(data, len, device, false)
     }

-    pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let buffer_size = data.len();
-        if buffer_size != shape.elem_count() {
-            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
-        }
+        let shape = shape.into_shape(data.len())?;
         let storage = device.storage_owned(data)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, shape, none, is_variable))
@@ -481,7 +477,7 @@ impl Tensor {
     /// ]);
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+    pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -502,17 +498,12 @@ impl Tensor {
     /// ]);
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+    pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
         array: &[D],
         shape: S,
         device: &Device,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let n: usize = shape.elem_count();
-        let buffer_size: usize = array.len();
-        if buffer_size != n {
-            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
-        }
+        let shape = shape.into_shape(array.len())?;
         let storage = device.storage_from_slice(array)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, shape, none, false))
@@ -2197,7 +2188,7 @@ impl Tensor {
     ///
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+    pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
         let shape = s.into_shape(self.elem_count())?;
         if shape.elem_count() != self.elem_count() {
             return Err(Error::ShapeMismatchBinaryOp {
@@ -46,7 +46,7 @@ impl TextGeneration {
             Sampling::ArgMax
         } else {
             match (top_k, top_p) {
-                (None, None) => Sampling::All { temperature },
+                (None, None) => Sampling::GumbelSoftmax { temperature },
                 (Some(k), None) => Sampling::TopK { k, temperature },
                 (None, Some(p)) => Sampling::TopP { p, temperature },
                 (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
@@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let tokenizer = common_args.tokenizer()?;

     let device = candle_examples::device(common_args.cpu)?;
+    #[cfg(feature = "cuda")]
+    if let candle::Device::Cuda(d) = &device {
+        unsafe {
+            d.disable_event_tracking();
+        }
+    };

     let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
     let is_safetensors = config_path
@@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
         padding,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let conv = if bias {
         conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
@@ -92,6 +92,7 @@ impl ConvBlock {
         stride,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
     let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.9.0-alpha.2"
+version = "0.9.0-alpha.4"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.2" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.9.0-alpha.2"
+version = "0.9.0-alpha.4"
 edition = "2021"

 description = "CUDA kernels for Candle"
@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.9.0-alpha.2"
+version = "0.9.0-alpha.4"
 edition = "2021"

 description = "Metal kernels for Candle"
candle-metal-kernels/build.rs (123 lines, new file)
@@ -0,0 +1,123 @@
use std::path::PathBuf;
use std::process::Command;
use std::{env, str};

const METAL_SOURCES: [&str; 1] = ["reduce"];

enum Platform {
    MacOS,
    IOS,
}

impl Platform {
    fn sdk(&self) -> &str {
        match self {
            Platform::MacOS => "macosx",
            Platform::IOS => "iphoneos",
        }
    }
}

fn compile(platform: Platform) -> Result<(), String> {
    println!("cargo::rerun-if-changed=src/reduce.metal");
    println!("cargo::rerun-if-changed=src/utils.metal");
    println!("cargo::rerun-if-changed=build.rs");

    let current_dir = env::current_dir().expect("Failed to get current directory");
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?);
    let working_directory = out_dir.to_string_lossy().to_string();
    let sources = current_dir.join("src");

    // Compile metal to air
    let mut compile_air_cmd = Command::new("xcrun");
    compile_air_cmd
        .arg("--sdk")
        .arg(platform.sdk())
        .arg("metal")
        .arg(format!("-working-directory={working_directory}"))
        .arg("-Wall")
        .arg("-Wextra")
        .arg("-O3")
        .arg("-c")
        .arg("-w");
    for metal_file in METAL_SOURCES {
        compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
    }
    compile_air_cmd.arg(sources.join("utils.metal"));

    let mut child = compile_air_cmd.spawn().expect("Failed to compile air");

    match child.try_wait() {
        Ok(Some(status)) => {
            if !status.success() {
                panic!(
                    "Compiling metal -> air failed. Exit with status: {}",
                    status
                )
            }
        }
        Ok(None) => {
            let status = child
                .wait()
                .expect("Compiling metal -> air failed while waiting for result");
            if !status.success() {
                panic!(
                    "Compiling metal -> air failed. Exit with status: {}",
                    status
                )
            }
        }
        Err(e) => panic!("Compiling metal -> air failed: {:?}", e),
    }

    // Compile air to metallib
    let lib_name = match platform {
        Platform::MacOS => "candle.metallib",
        Platform::IOS => "candle_ios.metallib",
    };
    let metallib = out_dir.join(lib_name);
    let mut compile_metallib_cmd = Command::new("xcrun");
    compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);

    for metal_file in METAL_SOURCES {
        compile_metallib_cmd.arg(out_dir.join(format!("{metal_file}.air")));
    }
    compile_metallib_cmd.arg(out_dir.join("utils.air"));

    let mut child = compile_metallib_cmd
        .spawn()
        .expect("Failed to compile air -> metallib");

    match child.try_wait() {
        Ok(Some(status)) => {
            if !status.success() {
                panic!(
                    "Compiling air -> metallib failed. Exit with status: {}",
                    status
                )
            }
        }
        Ok(None) => {
            let status = child
                .wait()
                .expect("Compiling air -> metallib failed while waiting for result");
            if !status.success() {
                panic!(
                    "Compiling air -> metallib failed. Exit with status: {}",
                    status
                )
            }
        }
        Err(e) => panic!("Compiling air -> metallib failed: {:?}", e),
    }

    Ok(())
}

fn main() -> Result<(), String> {
    compile(Platform::MacOS)?;
    compile(Platform::IOS)?;

    Ok(())
}
@@ -13,6 +13,11 @@ pub use sort::{call_arg_sort, call_mlx_arg_sort};
 pub use utils::BufferOffset;
 use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};

+#[cfg(target_os = "macos")]
+const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle.metallib"));
+#[cfg(target_os = "ios")]
+const CANDLE: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/candle_ios.metallib"));
+
 const AFFINE: &str = include_str!("affine.metal");
 const BINARY: &str = include_str!("binary.metal");
 const CAST: &str = include_str!("cast.metal");
@@ -23,7 +28,6 @@ const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
 const MLX_SORT: &str = include_str!("mlx_sort.metal");
 const QUANTIZED: &str = include_str!("quantized.metal");
 const RANDOM: &str = include_str!("random.metal");
-const REDUCE: &str = include_str!("reduce.metal");
 const SORT: &str = include_str!("sort.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const UNARY: &str = include_str!("unary.metal");
@@ -54,6 +58,7 @@ impl DType {

 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
+    Candle,
     Affine,
     Binary,
     Cast,
@@ -64,7 +69,6 @@ pub enum Source {
     MlxSort,
     Quantized,
     Random,
-    Reduce,
     Sort,
     Ternary,
     Unary,
@@ -288,11 +292,11 @@ impl Kernels {
             Source::MlxSort => MLX_SORT,
             Source::Quantized => QUANTIZED,
             Source::Random => RANDOM,
-            Source::Reduce => REDUCE,
             Source::Sort => SORT,
             Source::Ternary => TERNARY,
             Source::Unary => UNARY,
             Source::Sdpa => SDPA,
+            Source::Candle => panic!("Invalid lib"),
         }
     }

@@ -307,11 +311,21 @@ impl Kernels {
         if let Some(lib) = libraries.get(&source) {
             Ok(lib.clone())
         } else {
-            let lib = {
-                let source_content = self.get_library_source(source);
-                device
-                    .new_library_with_source(source_content, &CompileOptions::new())
-                    .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+            let lib = match source {
+                Source::Candle => {
+                    let source_data = CANDLE;
+                    device.new_library_with_data(source_data).map_err(|e| {
+                        MetalKernelError::LoadLibraryError(format!(
+                            "Candle metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}"
+                        ))
+                    })?
+                }
+                source => {
+                    let source_content = self.get_library_source(source);
+                    device
+                        .new_library_with_source(source_content, &CompileOptions::new())
+                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+                }
             };
             libraries.insert(source, lib.clone());
             Ok(lib)
@@ -641,7 +655,7 @@ pub fn call_reduce_contiguous(
     let length = shape.iter().product::<usize>();
     let num_dims = shape.len();
     let work_per_threadgroup = length / out_length;
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;

     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@@ -697,7 +711,7 @@ pub fn call_reduce_strided(
     let length: usize = shape.iter().product();
     let num_dims = shape.len();
     let work_per_threadgroup = length / out_length;
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;

     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@@ -752,7 +766,7 @@ pub fn call_last_softmax(
 ) -> Result<(), MetalKernelError> {
     let work_per_threadgroup = elements;

-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -801,7 +815,7 @@ pub fn call_rms_norm(
     alpha_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -862,7 +876,7 @@ pub fn call_layer_norm(
     beta_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -923,7 +937,7 @@ pub fn call_rope_i(
     sin_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -966,7 +980,7 @@ pub fn call_rope_thd(
     sin_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -1010,7 +1024,7 @@ pub fn call_rope(
     sin_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let pipeline = kernels.load_pipeline(device, Source::Candle, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -1,51 +1,8 @@
 #include <metal_stdlib>
 #include <metal_limits>
+#include "utils.metal"
 using namespace metal;

-METAL_FUNC uint nonzero(uint n) {
-  return n == 0 ? 1 : n;
-}
-
-template<uint N>
-constexpr uint nonzero() {
-  return N == 0 ? 1 : N;
-}
-
-template<typename T>
-constexpr ushort granularity() {
-  return nonzero<vec_elements<T>::value>();
-}
-
-METAL_FUNC uint next_p2(uint x) {
-  return 1 << (32 - clz(x - 1));
-}
-
-METAL_FUNC uint prev_p2(uint x) {
-  return 1 << (31 - clz(x));
-}
-
-constant uint MAX_SHARED_MEM = 32767;
-
-template<typename T>
-METAL_FUNC uint max_shared_mem(uint n) {
-  return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
-}
-
-METAL_FUNC uint get_strided_index(
-  uint idx,
-  constant const uint &num_dims,
-  constant const size_t *dims,
-  constant const size_t *strides
-) {
-  uint strided_i = 0;
-  for (uint d = 0; d < num_dims; d++) {
-    uint dim_idx = num_dims - 1 - d;
-    strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
-    idx /= dims[dim_idx];
-  }
-  return strided_i;
-}
-
 struct Divide {
   template<typename T>
   METAL_FUNC T operator()(T a, T b) { return a / b; }
@@ -1,6 +1,6 @@
 //! Convolution Layers.
 use crate::BatchNorm;
-use candle::{Result, Tensor};
+use candle::{conv::CudnnFwdAlgo, Result, Tensor};

 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub struct Conv1dConfig {
@@ -8,6 +8,7 @@ pub struct Conv1dConfig {
     pub stride: usize,
     pub dilation: usize,
     pub groups: usize,
+    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }

 impl Default for Conv1dConfig {
@@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         }
     }
 }
@@ -52,12 +54,13 @@ impl Conv1d {

 impl crate::Module for Conv1d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv1d(
+        let x = x.conv1d_with_algo(
             &self.weight,
             self.config.padding,
             self.config.stride,
             self.config.dilation,
             self.config.groups,
+            self.config.cudnn_fwd_algo,
         )?;
         match &self.bias {
             None => Ok(x),
@@ -147,6 +150,7 @@ pub struct Conv2dConfig {
     pub stride: usize,
     pub dilation: usize,
     pub groups: usize,
+    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }

 impl Default for Conv2dConfig {
@@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
             stride: 1,
             dilation: 1,
             groups: 1,
+            cudnn_fwd_algo: None,
         }
     }
 }
@@ -211,12 +216,13 @@ impl Conv2d {

 impl crate::Module for Conv2d {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.conv2d(
+        let x = x.conv2d_with_algo(
            &self.weight,
            self.config.padding,
            self.config.stride,
            self.config.dilation,
            self.config.groups,
+           self.config.cudnn_fwd_algo,
         )?;
         match &self.bias {
             None => Ok(x),
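
Since both configs now carry the extra field, code that builds them field-by-field needs the new member, while `..Default::default()` keeps working. A small sketch (hypothetical user code):

```rust
use candle_nn::Conv2dConfig;

fn make_configs() -> (Conv2dConfig, Conv2dConfig) {
    // Explicit construction must now mention the new field...
    let explicit = Conv2dConfig {
        padding: 1,
        stride: 1,
        dilation: 1,
        groups: 1,
        cudnn_fwd_algo: None,
    };
    // ...while struct-update syntax picks up the None default.
    let via_default = Conv2dConfig {
        padding: 1,
        ..Default::default()
    };
    (explicit, via_default)
}
```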
@@ -294,6 +294,27 @@ impl RotatingCache {
         Tensor::from_slice(&mask, (size1, size2), device)
     }

+    /// Returns the positions corresponding to all the elements that will be returned
+    /// *after* adding `seq_len` to the cache.
+    pub fn positions(&self, seq_len: usize) -> Vec<usize> {
+        if seq_len <= self.max_seq_len {
+            let upd_offset = (self.offset + seq_len) % self.max_seq_len;
+            let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
+            (0..cache_out_len)
+                .map(|i| {
+                    let pos_cache = self.current_seq_len + seq_len + i - upd_offset;
+                    if i < upd_offset {
+                        pos_cache
+                    } else {
+                        pos_cache - self.max_seq_len
+                    }
+                })
+                .collect()
+        } else {
+            (self.current_seq_len..(self.current_seq_len + seq_len)).collect()
+        }
+    }
+
     /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
     pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
         let mask = if seq_len == 1 {
@@ -362,10 +383,17 @@ impl RotatingKvCache {
         self.k.current_seq_len()
     }

+    /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
     pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
         self.k.attn_mask(seq_len, device)
     }

+    /// Returns the positions corresponding to all the elements that will be returned
+    /// *after* adding `seq_len` to the cache.
+    pub fn positions(&self, seq_len: usize) -> Vec<usize> {
+        self.k.positions(seq_len)
+    }
+
     pub fn reset(&mut self) {
         self.k.reset();
         self.v.reset();
@@ -31,6 +31,7 @@ pub mod ops;
 pub mod optim;
 pub mod rnn;
 pub mod rotary_emb;
+pub mod sampling;
 pub mod sequential;
 pub mod var_builder;
 pub mod var_map;
candle-nn/src/sampling.rs (20 lines, new file)
@@ -0,0 +1,20 @@
use candle::{Result, Tensor};

/// Sample according to the Gumbel-Softmax distribution.
pub fn gumbel_softmax<D: candle::shape::Dim>(
    logits: &Tensor,
    temperature: f64,
    dim: D,
) -> Result<Tensor> {
    if temperature <= 0.0 {
        logits.argmax(dim)
    } else if temperature == 1.0 {
        let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
        let sampled = (logits - minus_g)?.argmax(dim)?;
        Ok(sampled)
    } else {
        let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
        let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
        Ok(sampled)
    }
}
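
For context, a minimal sketch of calling this sampler directly (hypothetical standalone usage; inside the workspace the core crate is imported as `candle`, externally it is `candle_core`):

```rust
use candle_core::{D, Device, Result, Tensor};
use candle_nn::sampling::gumbel_softmax;

fn main() -> Result<()> {
    // Draw one Gumbel-softmax sample from a small 1-D logits vector;
    // argmax over the last dim yields a 0-dim u32 tensor.
    let logits = Tensor::new(&[-1.0f32, 0.0, 0.2, 1.0], &Device::Cpu)?;
    let token = gumbel_softmax(&logits, 1.0, D::Minus1)?.to_vec0::<u32>()?;
    println!("sampled index: {token}");
    Ok(())
}
```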
@@ -39,9 +39,16 @@ fn rotating_kv_cache() -> Result<()> {
     assert_eq!(cache.current_seq_len(), 0);
     let data = cache.current_data()?;
     assert!(data.is_none());
+    assert_eq!(cache.positions(1), &[0]);
+    assert_eq!(cache.positions(2), &[0, 1]);
     let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
     let data = cache.append(&t)?;
     assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
+    assert_eq!(cache.positions(0), &[0, 1, 2]);
+    assert_eq!(cache.positions(1), &[0, 1, 2, 3]);
+    assert_eq!(cache.positions(2), &[0, 1, 2, 3, 4]);
+    assert_eq!(cache.positions(3), &[0, 1, 2, 3, 4, 5]);
+    assert_eq!(cache.positions(4), &[6, 1, 2, 3, 4, 5]);
     let t = Tensor::new(&[4.], &Device::Cpu)?;
     let data = cache.append(&t)?;
     assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
@@ -79,11 +86,17 @@ fn rotating_kv_cache() -> Result<()> {
         mask.to_vec2::<u8>()?,
         &[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
     );
+    assert_eq!(cache.positions(0), &[12, 7, 8, 9, 10, 11]);
+    assert_eq!(cache.positions(2), &[12, 13, 14, 9, 10, 11]);
+    assert_eq!(cache.positions(3), &[12, 13, 14, 15, 10, 11]);
+    assert_eq!(cache.positions(8), &[13, 14, 15, 16, 17, 18, 19, 20]);
     let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
     let data = cache.append(&t)?;
     assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
     assert_eq!(cache.current_seq_len(), 22);
     assert_eq!(cache.offset(), 0);
+    assert_eq!(cache.positions(0), &[16, 17, 18, 19, 20, 21]);
+    assert_eq!(cache.positions(1), &[22, 17, 18, 19, 20, 21]);

     let mask = cache.attn_mask(1, &Device::Cpu)?;
     assert!(mask.is_none());
@@ -1,6 +1,6 @@
 [package]
 name = "candle-onnx"
-version = "0.9.0-alpha.2"
+version = "0.9.0-alpha.4"
 edition = "2021"

 description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"

 [dependencies]
-candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
-candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.2" }
+candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
+candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
 prost = "0.12.1"

 [build-dependencies]
@@ -13,6 +13,8 @@ pub enum Sampling {
     TopK { k: usize, temperature: f64 },
     TopP { p: f64, temperature: f64 },
     TopKThenTopP { k: usize, p: f64, temperature: f64 },
+    // Note that the rng is not used for the Gumbel-Softmax sampling.
+    GumbelSoftmax { temperature: f64 },
 }

 pub struct LogitsProcessor {
@@ -49,6 +51,11 @@ impl LogitsProcessor {
         Ok(next_token)
     }

+    fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
+        let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
+        sampled.to_vec0::<u32>()
+    }
+
     fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
         let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
         let next_token = distr.sample(&mut self.rng) as u32;
@@ -127,6 +134,9 @@ impl LogitsProcessor {

         let next_token = match &self.sampling {
             Sampling::ArgMax => self.sample_argmax(logits)?,
+            Sampling::GumbelSoftmax { temperature } => {
+                self.sample_gumbel_softmax(&logits, *temperature)?
+            }
             Sampling::All { temperature } => {
                 let prs = prs(*temperature)?;
                 self.sample_multinomial(&prs)?
@@ -124,6 +124,7 @@ impl ResidualConvUnit {
         stride: 1,
         dilation: 1,
         groups: 1,
+        cudnn_fwd_algo: None,
     };
     let conv1 = conv2d(
         conf.num_features,
@@ -208,6 +209,7 @@ impl FeatureFusionBlock {
         stride: 1,
         dilation: 1,
         groups: 1,
+        cudnn_fwd_algo: None,
     };
     let output_conv = conv2d(
         conf.num_features,
@@ -258,6 +260,7 @@ impl Scratch {
         stride: 1,
         dilation: 1,
         groups: 1,
+        cudnn_fwd_algo: None,
     };

     let layer1_rn = conv2d_no_bias(
@@ -319,6 +322,7 @@ impl Scratch {
         stride: 1,
         dilation: 1,
         groups: 1,
+        cudnn_fwd_algo: None,
     };
     let output_conv1 = conv2d(
         conf.num_features,
@@ -425,6 +429,7 @@ impl DPTHead {
         stride: 2,
         dilation: 1,
         groups: 1,
+        cudnn_fwd_algo: None,
     },
     vb.pp("resize_layers").pp("3"),
 )?),
@@ -468,6 +468,7 @@ impl EncodecConv1d {
         stride,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     },
     vb.pp("conv"),
 )?,
@@ -267,6 +267,7 @@ impl StreamableConv1d {
         stride,
         dilation,
         groups,
+        cudnn_fwd_algo: None,
     };
     let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
     if k_size < stride {
@@ -68,6 +68,7 @@ impl ResnetBlock2D {
         padding: 1,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
     let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -83,6 +84,7 @@ impl ResnetBlock2D {
         padding: 0,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     Some(conv2d(
         in_channels,
@@ -248,12 +248,14 @@ impl AudioEncoder {
         stride: 1,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let cfg2 = Conv1dConfig {
         padding: 1,
         stride: 2,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
     let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
@@ -244,12 +244,14 @@ impl AudioEncoder {
         stride: 1,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let cfg2 = Conv1dConfig {
         padding: 1,
         stride: 2,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
     let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
@@ -54,3 +54,25 @@ fn sample_with_top_k() -> Result<()> {
     assert_eq!(token, 2);
     Ok(())
 }
+
+#[test]
+fn sample_gumbel() -> Result<()> {
+    let mut logits_process = LogitsProcessor::from_sampling(
+        42,
+        candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
+    );
+    let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
+    let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
+    let mut counts = vec![0f64; 4];
+    let samples = 100000;
+    for _ in 0..samples {
+        let token = logits_process.sample(&logits)?;
+        counts[token as usize] += 1f64 / samples as f64;
+    }
+    for i in 0..4 {
+        if (counts[i] - sm[i]).abs() > 0.05 {
+            panic!("pr mismatch {counts:?} {sm:?}");
+        }
+    }
+    Ok(())
+}
@@ -98,6 +98,7 @@ impl ConvBlock {
         stride,
         groups: 1,
         dilation: 1,
+        cudnn_fwd_algo: None,
     };
     let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
     let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;