Compare commits

...

25 Commits

Author SHA1 Message Date
7f0f83a7c1 Rotating kv cache positions (#2901)
* Retrieve the current positions for rotating KV caches.

* Add the function to the kv cache too.

* More testing.
2025-04-15 23:09:26 +02:00
76e565c4ab Updated candle-book: Introduction, Installation, MNIST guide, and added CONTRIBUTING.md (#2897)
* added CONTRIBUTING.md to candle-book

* added description to candle-book introduction

* Updated formatting and added different features to candle-book installation

* mnist guide first draft candle-book

* updated mnist guide syntax and grammar for candle-book

* changed HelloWorld - Mnist to Tutorial - Mnist in SUMMARY.md

* updated intro to mnist guide in candle-book
2025-04-15 21:41:10 +02:00
e4e7b0b2da Use cudarc 0.16. (#2900)
* Use cudarc 0.16.

* Allow for disabling event tracking.

* Tweaks.

* Bump the ug version.

* And bump the candle version too.
2025-04-15 21:40:18 +02:00
b01ebbad8a Use cudarc 0.15.2. (#2896) 2025-04-14 20:47:52 +02:00
1d1d6d4fe6 Bump the crate version. (#2895) 2025-04-14 15:52:11 +02:00
2653002f29 Gumbel-Softmax sampling. (#2894)
* Gumbel-Softmax sampling.

* Add a sampling test.

* Share the gumbel-softmax bits.
2025-04-14 15:42:42 +02:00
a52b76ae82 Expose the cudnn algo in the conv ops. (#2892)
* Set the algo.

* Expose the cudnn preferred algo for conv ops.
2025-04-14 08:25:32 +02:00
fb660b8d43 Add a cudnn feature to candle-nn/candle-transformers. (#2890) 2025-04-13 17:43:41 +02:00
2f9606b187 Exclude candle-book to avoid some CI failures. (#2889)
* Exclude candle-book to avoid some CI failures.

* Remove the book CIs.
2025-04-13 17:11:41 +02:00
f3a73f80d1 Support for cudnn conv1d. (#2888)
* Support for cudnn conv1d.

* More conv1d work.

* Get the conv1d to work with cudnn.

* Cleanup.
2025-04-13 16:47:37 +02:00
b44d38de0e Add the Orpheus TTS. (#2886)
* Add the Orpheus TTS.

* Add a small readme.

* Token fix.

* Support more voices.

* Clippy fixes.
2025-04-13 12:02:17 +02:00
d9198deb37 Im2col cuda optimization. (#2885) 2025-04-13 10:07:53 +02:00
15ed0b11ce Optimize the batched matmul for the cpu backend. (#2884) 2025-04-12 21:40:40 +02:00
34505fdf3a Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear.

* Also avoid batched matmul in conv1d.

* Also tweak the conv2d.

* Batched tests.

* Also cover conv2d.
2025-04-12 19:53:58 +02:00
d7b7ce16e4 Upgrade ug. (#2882) 2025-04-12 13:19:32 +02:00
19fb6dac1f Bump the crate version. (#2881) 2025-04-11 22:28:21 +02:00
acc5bd335f Cuda cleanup. (#2880)
* Cuda cleanup.

* More fixes.
2025-04-11 21:43:35 +02:00
eb478ece92 Implementing DistilBertForMaskedLM. (#2866)
* Initial commit: model weights working, prediction incorrect

* moved distilbertformaskedlm into distilbert modeling file

* made maskedLM like bert example, still incorrect predictions

* finally not getting NaNs, fixed attention mask

* getting correct output sentences

* get top k predictions

* fixed output formatting slightly

* added default arg for model_id

* lint

* moved masked token example code from distilbertformaskedlm example to distilbert example

* lint

* removed distilbertformaskedlm example

* cleanup

* clippy

* removed embedding normalization from example

* made output and model dependent on args instead of prompt

* lint

* replaced or_ok anyhow error with anyhow context

* changed error message for mask token not found
2025-04-11 13:25:39 +02:00
d339b01726 Fix hardcoded f32 dtype for attention_mask. Use the model dtype for compatibility. (#2872) 2025-04-08 06:12:14 +02:00
2f3bf42bcb Support more snac variants. (#2871) 2025-04-07 08:23:47 +02:00
e3370c6316 Add the SNAC audio tokenizer. (#2869)
* Add the SNAC audio tokenizer.

* More snac.

* Again more snac.

* Add some example code for snac.

* Get the weights to load.

* Add to the snac model.

* Fixes.

* Get round-tripping to work.

* Save/load code files.

* Clippy fix.

* Fmt fix.
2025-04-06 22:15:36 +02:00
338f6a102e Clippy 1.86 fixes for cuda. (#2868) 2025-04-05 15:45:35 +02:00
bc33df77e1 Add the missing voices for CSM. (#2867) 2025-04-05 06:52:36 +02:00
cf9d7bf24c Add the CSM model. (#2862)
* Add the CSM model.

* Add some code to load the model.

* Load the text tokenizer.

* Add frame generation.

* Get the sampling to work.

* Rope fix.

* Autoregressive generation.

* Generate some audio file.

* Use the actual prompt.

* Support multiple turns.

* Add a very barebone readme.

* Move some of the shared bits to the model.
2025-04-04 06:48:03 +02:00
9d31361c4f Fix for clippy 1.86. (#2864)
* Fix for clippy 1.86.

* More clippy fixes.

* More fixes.
2025-04-03 19:38:27 +02:00
77 changed files with 4074 additions and 444 deletions

View File

@ -1,40 +0,0 @@
name: Deploy Rust book
on:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir mdbook
curl -sSL $url | tar -xz --directory=./mdbook
echo `pwd`/mdbook >> $GITHUB_PATH
- name: Deploy GitHub Pages
run: |
# This assumes your book is in the root of your repository.
# Just add a `cd` here if you need to change to another directory.
cd candle-book
mdbook build
git worktree add gh-pages
git config user.name "Deploy from CI"
git config user.email ""
cd gh-pages
# Delete the ref to avoid keeping history.
git update-ref -d refs/heads/gh-pages
rm -rf *
mv ../book/* .
git add .
git commit -m "Deploy $GITHUB_SHA to gh-pages"
git push --force --set-upstream origin gh-pages

View File

@ -1,29 +0,0 @@
name: CI
on:
pull_request:
jobs:
test:
name: Test candle-book
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@master
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir bin
curl -sSL $url | tar -xz --directory=bin
echo "$(pwd)/bin" >> $GITHUB_PATH
- name: Run tests
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/

View File

@ -3,7 +3,6 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
@ -12,6 +11,7 @@ members = [
"tensor-tools",
]
exclude = [
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0-alpha.1"
version = "0.9.0-alpha.4"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.2.0"
ug-cuda = "0.2.0"
ug-metal = "0.2.0"
ug = "0.4.0"
ug-cuda = "0.4.0"
ug-metal = "0.4.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -290,6 +290,8 @@ Cheatsheet:
### Why should I use Candle?
<!--- ANCHOR: goals --->
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
binaries.
@ -299,6 +301,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
<!--- ANCHOR_END: goals --->
### Other ML frameworks

View File

@ -0,0 +1,13 @@
# Candle Book
The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.
## Installation
To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).
## Viewing the book
To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).
The book is built automatically in GitHub CI.
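Putting the two commands together (both taken verbatim from the steps above; run from the repository root):

```bash
# one-time setup
cargo install mdbook
# build and serve the book locally, opening it in a browser
mdbook serve --open candle-book
```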

View File

@ -1,6 +1,7 @@
# Introduction
{{#include ../../README.md:goals}}
{{#include ../../README.md:features}}
This book will introduce, step by step, how to use `candle`.

View File

@ -5,7 +5,10 @@
# User Guide
- [Installation](guide/installation.md)
- [Hello World - MNIST](guide/hello_world.md)
- [Tutorial - MNIST](guide/mnist/intro.md)
- [Modeling](guide/mnist/modeling.md)
- [Training](guide/mnist/training.md)
- [Saving And Loading](guide/mnist/saving_loading.md)
- [PyTorch cheatsheet](guide/cheatsheet.md)
# Reference Guide

View File

@ -1,8 +1,23 @@
# Installation
**With Cuda support**:
## 1. Create a new Rust app or library
1. First, make sure that Cuda is correctly installed.
```bash
cargo new myapp
cd myapp
```
## 2. Add the correct candle version
### Standard
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core
```
### CUDA
First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPU's compute capability, e.g. something
like:
@ -17,43 +32,36 @@ You can also compile the Cuda kernels for a specific compute cap using the
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
Start by creating a new cargo:
```bash
cargo new myapp
cd myapp
```
Make sure to add the `candle-core` crate with the cuda feature:
Add the `candle-core` crate with the cuda feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```
### MKL
You can also enable the `mkl` feature, which can speed up inference on CPU.
Add the `candle-core` crate with the mkl feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
```
### Metal
Metal is exclusive to macOS.
Add the `candle-core` crate with the metal feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
```
## 3. Building
Run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**Without Cuda support**:
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```
Finally, run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**With mkl support**
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)

View File

@ -0,0 +1,17 @@
# Candle MNIST Tutorial
## Introduction
This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.
Throughout this tutorial, you will learn the basics of:
- Tensor operations and model construction
- Creating and implementing neural network layers
- Parameter initialization
- Training loop implementation
- Saving and loading trained models
## Getting Started
Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.

View File

@ -0,0 +1,172 @@
# Candle MNIST Tutorial
## Modeling
Open `src/main.rs` in your project folder and insert the following code:
```rust
use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
second: Tensor,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = image.matmul(&self.first)?;
let x = x.relu()?;
x.matmul(&self.second)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to utilize GPU acceleration.
let device = Device::Cpu;
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Execute the program with:
```bash
$ cargo run --release
> Digit Tensor[dims 1, 10; f32] digit
```
Since random inputs are provided, expect an incoherent output.
## Implementing a `Linear` Layer
To create a more sophisticated layer type, add a `bias` term to the weights, constructing the standard `Linear` layer.
Replace the entire content of `src/main.rs` with:
```rust
use candle_core::{Device, Result, Tensor};
struct Linear {
weight: Tensor,
bias: Tensor,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight)?;
x.broadcast_add(&self.bias)
}
}
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; for GPU acceleration.
// Use Device::Cpu; for CPU computation.
let device = Device::cuda_if_available(0)?;
// Initialize model parameters
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear { weight, bias };
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear { weight, bias };
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
// Perform inference
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Execute again with:
```bash
$ cargo run --release
> Digit Tensor[dims 1, 10; f32] digit
```
## Utilizing `candle_nn`
Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
This `Linear` implementation follows the PyTorch convention for better compatibility with existing models: the weight is stored transposed, with shape `(out_features, in_features)`, rather than in the `(in_features, out_features)` layout used above.
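A minimal sketch of what this convention means in practice (shapes chosen to match the example below; the transpose happens inside `forward`):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // PyTorch-style weight shape: (out_features, in_features) = (100, 784).
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let layer = Linear::new(weight, None); // no bias, for brevity
    // Input is (batch, in_features); forward transposes the weight internally.
    let x = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
    let y = layer.forward(&x)?;
    assert_eq!(y.dims(), &[1, 100]);
    Ok(())
}
```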
Let's simplify our implementation. First, add `candle-nn` as a dependency:
```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-nn
```
Now, replace the entire content of `src/main.rs` with:
```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; for GPU acceleration.
let device = Device::Cpu;
// Note the dimension change: (784, 100) -> (100, 784)
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Execute the final version:
```bash
$ cargo run --release
> Digit Tensor[dims 1, 10; f32] digit
```

View File

@ -0,0 +1,158 @@
# Candle MNIST Tutorial
## Saving and Loading Models
After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.
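In isolation, the round trip looks like this (a minimal sketch reusing only APIs that appear in the training chapter; the file name is arbitrary):

```rust
use candle_core::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Register a parameter through a VarBuilder backed by a VarMap.
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let _w = vs.get_with_hints((10, 10), "weight", candle_nn::init::DEFAULT_KAIMING_NORMAL)?;
    // Persist every registered variable in the safetensors format.
    varmap.save("weights.safetensors")?;
    // Load back into a fresh VarMap with the same names and shapes;
    // note that load requires the VarMap to be mutable.
    let mut varmap2 = VarMap::new();
    let vs2 = VarBuilder::from_varmap(&varmap2, DType::F32, &dev);
    let _w2 = vs2.get_with_hints((10, 10), "weight", candle_nn::init::DEFAULT_KAIMING_NORMAL)?;
    varmap2.load("weights.safetensors")?;
    Ok(())
}
```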
### Saving Model Parameters
Let's modify our `training_loop` function to include functionality for saving weights:
```rust
fn training_loop(
m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// Initialize a VarMap for trainable parameters
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = Model::new(vs.clone())?;
let learning_rate = 0.05;
let epochs = 10;
// Initialize stochastic gradient descent optimizer
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..epochs {
// Standard MNIST forward pass
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
// Compute Negative Log Likelihood loss
let loss = loss::nll(&log_sm, &train_labels)?;
// Perform backward pass and update weights
sgd.backward_step(&loss)?;
// Evaluate model on test set
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
test_accuracy
);
}
// Save model weights to disk
varmap.save("model_weights.safetensors")?;
Ok(())
}
```
```bash
$ cargo run --release
> 1 train loss: 2.40485 test acc: 0.11%
> 2 train loss: 2.34161 test acc: 0.14%
> 3 train loss: 2.28841 test acc: 0.17%
> 4 train loss: 2.24158 test acc: 0.19%
> 5 train loss: 2.19898 test acc: 0.23%
> 6 train loss: 2.15927 test acc: 0.26%
> 7 train loss: 2.12161 test acc: 0.29%
> 8 train loss: 2.08549 test acc: 0.32%
> 9 train loss: 2.05053 test acc: 0.35%
```
### Loading Model Parameters
Now that we have saved our model parameters, we can modify the code to load them. The primary change required is to make the `varmap` variable mutable:
```rust
fn training_loop(
m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// Create a mutable VarMap for trainable parameters
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = Model::new(vs.clone())?;
// Load pre-trained weights from file
varmap.load("model_weights.safetensors")?;
let learning_rate = 0.05;
let epochs = 10;
// Initialize stochastic gradient descent optimizer
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..epochs {
// Standard MNIST forward pass
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
// Compute Negative Log Likelihood loss
let loss = loss::nll(&log_sm, &train_labels)?;
// Perform backward pass and update weights
sgd.backward_step(&loss)?;
// Evaluate model on test set
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
test_accuracy
);
}
// Save updated weights back to disk
varmap.save("model_weights.safetensors")?;
Ok(())
}
```
```bash
$ cargo run --release
> 1 train loss: 2.01645 test acc: 0.38%
> 2 train loss: 1.98300 test acc: 0.41%
> 3 train loss: 1.95008 test acc: 0.44%
> 4 train loss: 1.91754 test acc: 0.47%
> 5 train loss: 1.88534 test acc: 0.50%
> 6 train loss: 1.85349 test acc: 0.53%
> 7 train loss: 1.82198 test acc: 0.56%
> 8 train loss: 1.79077 test acc: 0.59%
> 9 train loss: 1.75989 test acc: 0.61%
```
Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
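One possible shape for such a check, placed just before the `varmap.load` call inside `training_loop` (a hedged sketch; candle does not prescribe this):

```rust
// Only load a checkpoint when one actually exists on disk.
let weights_file = "model_weights.safetensors";
if std::path::Path::new(weights_file).exists() {
    varmap.load(weights_file)?;
} else {
    println!("no checkpoint found at {weights_file}, training from scratch");
}
```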

View File

@ -0,0 +1,134 @@
# Candle MNIST Tutorial
## Training Implementation
First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` is built from a `VarMap`, which is the data structure that stores our trainable parameters.
```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder, VarMap};
fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
let ws = vs.get_with_hints(
(out_dim, in_dim),
"weight",
candle_nn::init::DEFAULT_KAIMING_NORMAL,
)?;
let bound = 1. / (in_dim as f64).sqrt();
let bs = vs.get_with_hints(
out_dim,
"bias",
candle_nn::Init::Uniform {
lo: -bound,
up: bound,
},
)?;
Ok(Linear::new(ws, Some(bs)))
}
```
Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer weights as `first.weight` and `first.bias`, and the second layer weights as `second.weight` and `second.bias`.
```rust
impl Model {
fn new(vs: VarBuilder) -> Result<Self> {
const IMAGE_DIM: usize = 784;
const HIDDEN_DIM: usize = 100;
const LABELS: usize = 10;
let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;
Ok(Self { first, second })
}
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
```
Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:
```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
```
With the dataset available, we can implement our training loop:
```rust
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
fn training_loop(
m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// Initialize a VarMap to store trainable parameters
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = Model::new(vs.clone())?;
let learning_rate = 0.05;
let epochs = 10;
// Initialize a stochastic gradient descent optimizer to update parameters
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..epochs {
// Perform forward pass on MNIST data
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
// Compute Negative Log Likelihood loss
let loss = loss::nll(&log_sm, &train_labels)?;
// Perform backward pass and update weights
sgd.backward_step(&loss)?;
// Evaluate model on test set
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
test_accuracy
);
}
Ok(())
}
```
Finally, let's implement our main function:
```rust
pub fn main() -> anyhow::Result<()> {
let m = candle_datasets::vision::mnist::load()?;
return training_loop(m);
}
```
Let's execute the training process:
```bash
$ cargo run --release
> 1 train loss: 2.35449 test acc: 0.12%
> 2 train loss: 2.30760 test acc: 0.15%
> ...
```

View File

@ -56,3 +56,7 @@ harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]
[[example]]
name = "cuda_basics"
required-features = ["cuda"]

View File

@ -6,28 +6,18 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
drop(_x1);
for _ in 0..20 {
let start_time = std::time::Instant::now();
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
device.synchronize()?;
println!("conv1d: {:?}", start_time.elapsed());
}
Ok(())
}

View File

@ -14,6 +14,7 @@ pub struct ParamsConv1D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv1D {
@ -54,7 +55,7 @@ impl ParamsConvTranspose1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
@ -151,6 +152,19 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d_with_algo(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
@ -174,6 +188,7 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@ -278,6 +293,18 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
}
pub fn conv2d_with_algo(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@ -297,7 +324,7 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo: None,
cudnn_fwd_algo,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
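For context, the new entry point added above can be exercised as below (a hedged sketch: shapes are illustrative, the `candle_core::conv` module path is assumed to be public, and the pinned algorithm only takes effect on the cudnn-enabled CUDA backend):

```rust
use candle_core::{conv::CudnnFwdAlgo, Device, Result, Tensor};

fn run() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let x = Tensor::randn(0f32, 1.0, (1, 64, 1024), &dev)?;
    let k = Tensor::randn(0f32, 1.0, (128, 64, 8), &dev)?;
    // Same as x.conv1d(&k, 0, 4, 1, 1), except the cuDNN forward algorithm
    // is pinned instead of letting pick_algorithm decide.
    let y = x.conv1d_with_algo(&k, 0, 4, 1, 1, Some(CudnnFwdAlgo::ImplicitGemm))?;
    println!("{:?}", y.shape());
    Ok(())
}
```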

View File

@ -1289,6 +1289,15 @@ impl Map2 for MatMul {
} else {
Parallelism::None
};
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
// a_skip and c_skip should be updated but step is always 0 so
// it wouldn't matter.
(1, b * m, n, k)
} else if a_skip == 0 && b_skip == n * k {
(1, m, b * n, k)
} else {
(b, m, n, k)
};
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
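The collapse above relies on a layout fact: when the rhs is shared across the batch (`b_skip == 0`) and the lhs batches are laid out contiguously (`a_skip == m * k`), the b separate `(m, k) x (k, n)` products are exactly the row blocks of a single `(b * m, k) x (k, n)` product. A standalone illustration (plain Rust, not candle code):

```rust
// Naive row-major matmul: (m, k) x (k, n) -> (m, n).
fn matmul(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for l in 0..k {
                c[i * n + j] += a[i * k + l] * b[l * n + j];
            }
        }
    }
    c
}

fn main() {
    let (b, m, n, k) = (3usize, 2, 4, 5);
    let lhs: Vec<f32> = (0..b * m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..k * n).map(|v| v as f32).collect();
    // b separate (m, k) x (k, n) products with a shared rhs...
    let batched: Vec<f32> = (0..b)
        .flat_map(|s| matmul(&lhs[s * m * k..(s + 1) * m * k], &rhs, m, n, k))
        .collect();
    // ...equal one (b * m, k) x (k, n) product over the stacked rows.
    let collapsed = matmul(&lhs, &rhs, b * m, n, k);
    assert_eq!(batched, collapsed);
}
```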

View File

@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
}
Ok(())
}
pub(crate) fn launch_conv1d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv1D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, 0],
/* stride */ [params.stride as i32, 1],
/* dilation */ [params.dilation as i32, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
// > to 1.
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.l_in as i32,
1,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
};
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_size as i32,
1,
],
)?;
let l_out = params.l_out() as i32;
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, l_out, 1],
)?;
let conv1d = ConvForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv1d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv1d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv1d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}

View File

@ -46,11 +46,61 @@ impl std::fmt::Debug for CudaDevice {
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaStream>;
impl CudaDevice {
#[allow(clippy::missing_safety_doc)]
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc::<T>(len).w()
}
fn deref(&self) -> &Self::Target {
&self.stream
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc_zeros::<T>(len).w()
}
pub fn memcpy_htod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_htod(src, dst).w()
}
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
&self,
src: &Src,
) -> Result<Vec<T>> {
self.stream.memcpy_dtov(src).w()
}
pub fn memcpy_dtod<
T,
Src: cudarc::driver::DevicePtr<T>,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_dtod(src, dst).w()
}
pub fn memcpy_stod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
>(
&self,
src: &Src,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.memcpy_stod(src).w()
}
}
@ -94,6 +144,24 @@ impl CudaDevice {
self.stream.clone()
}
/// When turned on, all cuda tensors **created after calling this function** will
/// not track uses via cuda events.
///
/// # Safety
///
/// It is up to the user to ensure proper synchronization between multiple streams:
/// - Ensure that no tensor is freed before a use on another stream is finished.
/// - Ensure that a tensor is not used on another stream before allocation on the
/// allocating stream finishes.
/// - Ensure that a tensor is not written to concurrently by multiple streams.
pub unsafe fn disable_event_tracking(&self) {
self.context.disable_event_tracking()
}
pub fn is_event_tracking(&self) -> bool {
self.context.is_event_tracking()
}
#[cfg(not(target_arch = "wasm32"))]
pub fn compile(
&self,
@ -126,7 +194,7 @@ impl CudaDevice {
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
@ -138,7 +206,7 @@ impl CudaDevice {
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
@ -150,7 +218,7 @@ impl CudaDevice {
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
@ -162,7 +230,7 @@ impl CudaDevice {
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
@ -174,7 +242,7 @@ impl CudaDevice {
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
@ -186,7 +254,7 @@ impl CudaDevice {
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
@ -198,7 +266,7 @@ impl CudaDevice {
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
@ -325,31 +393,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
let data = self.alloc_zeros::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
let data = self.alloc_zeros::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
let data = self.alloc_zeros::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
@ -373,12 +441,12 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
@ -417,7 +485,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -425,7 +493,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -444,31 +512,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count).w()?;
let data = self.alloc::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count).w()?;
let data = self.alloc::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count).w()?;
let data = self.alloc::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count).w()?;
let data = self.alloc::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count).w()?;
let data = self.alloc::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count).w()?;
let data = self.alloc::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count).w()?;
let data = self.alloc::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
@ -481,31 +549,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F64(data)
}
};
@ -518,31 +586,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F64(data)
}
};
@ -555,31 +623,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(&storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F64(data)
}
};

View File

@ -39,7 +39,7 @@ impl SlicePtrOrNull<usize> {
let ds = if l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat()).w()?)
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat())?)
};
Ok(ds)
}
@ -89,7 +89,7 @@ impl Map1 for Affine {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), &kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let out = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -120,7 +120,7 @@ impl Map1 for Elu {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let out = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -134,6 +134,7 @@ impl Map1 for Elu {
}
}
#[allow(unused)]
struct Im2Col1D {
l_k: usize,
stride: usize,
@ -142,6 +143,7 @@ struct Im2Col1D {
}
impl Im2Col1D {
#[allow(unused)]
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
@ -157,15 +159,15 @@ impl Map1 for Im2Col1D {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
let threads = dims[0] * l_out * dims[1];
let cfg = LaunchConfig::for_num_elems(threads as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, threads);
barg!(builder, l_out);
barg!(builder, self.l_k);
barg!(builder, self.stride);
@ -210,11 +212,11 @@ impl Map1 for Im2Col {
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let dst = unsafe { dev.alloc::<T>(dst_el)? };
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, h_out);
@ -249,7 +251,7 @@ impl Map1 for Powf {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let out = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -302,9 +304,7 @@ impl Map1Any for FastReduce<'_> {
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
let ds = dev
.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())
.w()?;
let ds = dev.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())?;
let src = &src.slice(layout.start_offset()..);
let (name, check_empty, return_index) = match self.1 {
ReduceOp::Sum => ("fast_sum", false, false),
@ -319,7 +319,7 @@ impl Map1Any for FastReduce<'_> {
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
if return_index {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<u32>(dst_el)? };
let mut builder = func.builder();
barg!(builder, src_el);
barg!(builder, el_to_sum_per_block);
@ -332,7 +332,7 @@ impl Map1Any for FastReduce<'_> {
Ok(S::U32(out))
} else {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let mut builder = func.builder();
barg!(builder, src_el);
barg!(builder, el_to_sum_per_block);
@ -362,7 +362,7 @@ impl<U: UnaryOpT> Map1 for U {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let mut out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let mut out = unsafe { dev.alloc::<T>(el_count)? };
let mut builder = func.builder();
barg!(builder, el_count);
barg!(builder, dims.len());
@ -403,7 +403,7 @@ impl Map1 for IndexSelect<'_> {
};
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat()).w()?;
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat())?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
@ -416,7 +416,7 @@ impl Map1 for IndexSelect<'_> {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, ids_dims.len());
@ -471,7 +471,7 @@ impl Map1 for Gather<'_> {
let ids_dim_sz = ids_l.dims()[dim];
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let out = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, ids);
@ -608,7 +608,7 @@ impl Map2 for Conv1D<'_> {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else if dims.len() == 2 {
@ -616,7 +616,7 @@ impl Map2 for Conv1D<'_> {
} else {
crate::bail!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.memcpy_stod(&ds).w()?;
let ds = dev.memcpy_stod(&ds)?;
let mut builder = func.builder();
barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
builder.arg(&ds);
@ -651,7 +651,7 @@ impl Map2 for Conv2D<'_> {
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
let ds = if dims.len() == 4 {
@ -659,7 +659,7 @@ impl Map2 for Conv2D<'_> {
} else {
crate::bail!("unexpected input shape for conv2d {dims:?}")
};
let ds = dev.memcpy_stod(&ds).w()?;
let ds = dev.memcpy_stod(&ds)?;
let mut builder = func.builder();
barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
builder.arg(&ds);
@ -687,7 +687,7 @@ impl Map1 for Col2Im1D {
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let mut im = unsafe { dev.alloc::<T>(dst_el)? };
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
@ -722,7 +722,7 @@ impl Map2 for ConvTranspose1D<'_> {
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
let ds = if dims.len() == 3 {
@ -730,7 +730,7 @@ impl Map2 for ConvTranspose1D<'_> {
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.memcpy_stod(&ds).w()?;
let ds = dev.memcpy_stod(&ds)?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, l_out);
@ -770,7 +770,7 @@ impl Map2 for ConvTranspose2D<'_> {
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
let ds = if dims.len() == 4 {
@ -778,7 +778,7 @@ impl Map2 for ConvTranspose2D<'_> {
} else {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
};
let ds = dev.memcpy_stod(&ds).w()?;
let ds = dev.memcpy_stod(&ds)?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, out_w);
@ -837,8 +837,8 @@ impl Map1 for Pool2D {
};
let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = dev.memcpy_stod(&ds).w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let ds = dev.memcpy_stod(&ds)?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, self.w_k);
@ -876,8 +876,8 @@ impl Map1 for UpsampleNearest2D {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = dev.memcpy_stod(&ds).w()?;
let out = unsafe { dev.alloc::<T>(dst_el)? };
let ds = dev.memcpy_stod(&ds)?;
let scale_w = dims[2] as f64 / out_w as f64;
let scale_h = dims[3] as f64 / out_h as f64;
let mut builder = func.builder();
@ -930,13 +930,12 @@ impl Map2 for WhereCond<'_> {
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
.w()?;
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let out = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -967,16 +966,13 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let out = unsafe { dev.alloc::<T>(elem_count)? };
let mut builder = func.builder();
barg!(builder, elem_count);
barg!(builder, dims.len());
@ -1007,10 +1003,7 @@ impl Map2Any for Cmp {
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
@ -1024,7 +1017,7 @@ impl Map2Any for Cmp {
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
let out = unsafe { dev.alloc::<u8>(elem_count)? };
let mut builder = func.builder();
barg!(builder, elem_count);
barg!(builder, dims.len());
@ -1208,7 +1201,6 @@ fn gemm_config<T>(
mnk: (m, n, k),
})?,
};
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
@ -1269,7 +1261,7 @@ impl BackendStorage for CudaStorage {
let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
let slice = match dtype {
DType::U8 => {
let out = unsafe { dev.alloc::<u8>(el) }.w()?;
let out = unsafe { dev.alloc::<u8>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1280,7 +1272,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::U8(out)
}
DType::U32 => {
let out = unsafe { dev.alloc::<u32>(el) }.w()?;
let out = unsafe { dev.alloc::<u32>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1291,7 +1283,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::U32(out)
}
DType::I64 => {
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
let out = unsafe { dev.alloc::<i64>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1302,7 +1294,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::I64(out)
}
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let out = unsafe { dev.alloc::<bf16>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1313,7 +1305,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::BF16(out)
}
DType::F16 => {
let out = unsafe { dev.alloc::<f16>(el) }.w()?;
let out = unsafe { dev.alloc::<f16>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1324,7 +1316,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::F16(out)
}
DType::F32 => {
let out = unsafe { dev.alloc::<f32>(el) }.w()?;
let out = unsafe { dev.alloc::<f32>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1335,7 +1327,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::F32(out)
}
DType::F64 => {
let out = unsafe { dev.alloc::<f64>(el) }.w()?;
let out = unsafe { dev.alloc::<f64>(el)? };
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1445,6 +1437,7 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
#[cfg(not(feature = "cudnn"))]
fn conv1d(
&self,
l: &Layout,
@ -1473,12 +1466,11 @@ impl BackendStorage for CudaStorage {
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let col_l = Layout::contiguous((b * m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
@ -1486,10 +1478,9 @@ impl BackendStorage for CudaStorage {
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@ -1497,6 +1488,72 @@ impl BackendStorage for CudaStorage {
Ok(res_t)
}
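
The im2col conv1d/conv2d paths above now express the convolution as one flat `(b*m, k) x (k, n)` matmul against the shared kernel rather than a batched matmul with the kernel broadcast over `b`. A standalone sketch of why the two forms agree when every batch shares the same kernel (plain Rust with a naive matmul; candle's actual kernels differ):

```rust
// Naive row-major matmul: (m, k) x (k, n) -> (m, n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = acc;
        }
    }
    out
}

fn main() {
    let (b, m, k, n) = (2usize, 3, 4, 5);
    let input: Vec<f32> = (0..b * m * k).map(|v| v as f32).collect();
    let kernel: Vec<f32> = (0..k * n).map(|v| (v as f32).sin()).collect();
    // Batched view: one (m, k) x (k, n) matmul per batch, kernel shared.
    let batched: Vec<f32> = (0..b)
        .flat_map(|bi| matmul(&input[bi * m * k..(bi + 1) * m * k], &kernel, m, k, n))
        .collect();
    // Collapsed view: a single (b*m, k) x (k, n) matmul over the same data.
    let collapsed = matmul(&input, &kernel, b * m, k, n);
    assert_eq!(batched, collapsed);
    println!("batched and collapsed matmuls agree");
}
```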
#[cfg(feature = "cudnn")]
fn conv1d(
&self,
inp_l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
let device = self.device().clone();
if !kernel_l.is_contiguous() {
let slice = Conv1D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let slice = match (&self.slice, &kernel.slice) {
(S::U8(inp), S::U8(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
crate::cudnn::launch_conv1d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
(S::BF16(inp), S::BF16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
crate::cudnn::launch_conv1d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
(S::F16(inp), S::F16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
crate::cudnn::launch_conv1d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
crate::cudnn::launch_conv1d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
(S::F64(inp), S::F64(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
crate::cudnn::launch_conv1d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?,
};
Ok(Self { slice, device })
}
fn conv_transpose1d(
&self,
l: &Layout,
@ -1587,12 +1644,11 @@ impl BackendStorage for CudaStorage {
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let col_l = Layout::contiguous((b * m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
@ -1600,10 +1656,9 @@ impl BackendStorage for CudaStorage {
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
@ -1632,7 +1687,7 @@ impl BackendStorage for CudaStorage {
(S::U8(inp), S::U8(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
@ -1640,7 +1695,7 @@ impl BackendStorage for CudaStorage {
(S::BF16(inp), S::BF16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
@ -1651,7 +1706,7 @@ impl BackendStorage for CudaStorage {
(S::F16(inp), S::F16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
@ -1659,7 +1714,7 @@ impl BackendStorage for CudaStorage {
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
@ -1667,7 +1722,7 @@ impl BackendStorage for CudaStorage {
(S::F64(inp), S::F64(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
@ -1783,7 +1838,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
let mut out = unsafe { dev.alloc::<bf16>(elem_count)? };
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::BF16(out)
@ -1792,7 +1847,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let mut out = unsafe { dev.alloc::<f16>(elem_count)? };
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F16(out)
@ -1801,7 +1856,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let mut out = unsafe { dev.alloc::<f32>(elem_count)? };
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F32(out)
@ -1810,7 +1865,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
let mut out = unsafe { dev.alloc::<f64>(elem_count)? };
unsafe {
self.device
.blas
@ -1883,7 +1938,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1899,7 +1954,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1915,7 +1970,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1931,7 +1986,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1947,7 +2002,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1963,7 +2018,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1979,7 +2034,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst).w()?
dev.memcpy_dtod(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
let mut builder = func.builder();


@ -816,7 +816,7 @@ impl PthTensors {
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,


@ -73,7 +73,7 @@ fn dequantize_f32(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let nb = (elem_count + 255) / 256;
let nb = elem_count.div_ceil(256);
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
@ -99,7 +99,7 @@ fn dequantize_f32(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -133,7 +133,7 @@ fn dequantize_f16(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let nb = (elem_count + 255) / 256;
let nb = elem_count.div_ceil(256);
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
@ -159,7 +159,7 @@ fn dequantize_f16(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -216,7 +216,7 @@ fn dequantize_mul_mat_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let dst = unsafe { dev.alloc::<f32>(nrows)? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
@ -256,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
@ -274,12 +274,12 @@ fn mul_mat_vec_via_q8_1(
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
2..=4 => ((nrows as u32).div_ceil(2), 4),
5..=8 => ((nrows as u32).div_ceil(2), 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
@ -329,7 +329,7 @@ fn mul_mat_via_q8_1(
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
@ -346,7 +346,7 @@ fn mul_mat_via_q8_1(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
@ -378,7 +378,7 @@ impl QCudaStorage {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -425,8 +425,7 @@ impl QCudaStorage {
let buffer = self
.device
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
.w()?;
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -457,9 +456,7 @@ impl QCudaStorage {
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.memcpy_dtov(data).w()?
}
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
@ -469,10 +466,9 @@ impl QCudaStorage {
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
self.device
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
@ -606,10 +602,8 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
.w()?;
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -631,9 +625,9 @@ mod test {
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -643,7 +637,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
@ -656,7 +650,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..))?;
assert_eq!(vs.len(), 1);
// For n = 255, n*(n+1)*(2n+1)/6 = 5559680 (sum of the squares 0..=n).
// Q8 means 1/256 precision.
@ -671,7 +665,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..))?;
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
@ -682,7 +676,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.memcpy_stod(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -696,7 +690,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..))?;
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@ -723,7 +717,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.memcpy_stod(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -737,7 +731,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
Ok(())
}
}
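
The `(elem_count + 255) / 256` to `elem_count.div_ceil(256)` rewrites earlier in this file are behavior-preserving for these sizes; a quick standalone check, assuming Rust 1.73+ where integer `div_ceil` is stable:

```rust
fn main() {
    // The manual rounding-up idiom and div_ceil agree on in-range inputs
    // (the manual form can overflow for values within 255 of usize::MAX,
    // while div_ceil cannot).
    for elem_count in [0usize, 1, 255, 256, 257, 4096, 100_003] {
        assert_eq!((elem_count + 255) / 256, elem_count.div_ceil(256));
    }
    println!("(x + 255) / 256 == x.div_ceil(256) for all tested x");
}
```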


@ -76,7 +76,7 @@ mod cuda {
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
} else {


@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv2d(&w, 0, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;


@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn"]
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
@ -69,6 +69,7 @@ metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@ -107,6 +108,10 @@ required-features = ["candle-datasets"]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "snac"
required-features = ["snac"]
[[example]]
name = "encodec"
required-features = ["encodec"]


@ -0,0 +1,14 @@
# Conversational Speech Model (CSM)
CSM is a speech generation model from Sesame,
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
It can generate conversational speech between two different speakers.
Speaker turns are delimited by the `|` character in the prompt.
```bash
cargo run --example csm --features cuda -r -- \
--voices candle-examples/examples/csm/voices.safetensors \
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
```
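In the example driver, the `|`-separated turns alternate between speakers 0 and 1 (`turn_idx % 2` in the code below), and each decoded segment is loudness-normalized 24kHz PCM that gets concatenated into the output wav.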


@ -0,0 +1,243 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::csm::{Config, Model};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1b")]
Csm1b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The prompt to be used for the generation; use a | to separate the speakers.
#[arg(long, default_value = "Hey how are you doing today?")]
prompt: String,
/// The voices to be used, in safetensors format.
#[arg(long)]
voices: String,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.7)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "1b")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
config: Option<String>,
#[arg(long)]
weights: Option<String>,
/// The mimi model weight file, in safetensors format.
#[arg(long)]
mimi_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let name = match args.which {
Which::Csm1b => "sesame/csm-1b",
};
name.to_string()
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let filenames = match args.weights {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("meta-llama/Llama-3.2-1B".to_string())
.get("tokenizer.json")?,
};
let mimi_filename = match args.mimi_weights {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let device = candle_examples::device(args.cpu)?;
let (mut model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
(model, device)
};
let mut mimi_model = {
use candle_transformers::models::mimi;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
let config = mimi::Config::v0_1(Some(32));
mimi::Model::new(config, vb)?
};
let cb = config.audio_num_codebooks;
println!("loaded the model in {:?}", start.elapsed());
let voices = candle::safetensors::load(args.voices, &device)?;
let mut lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
None,
);
let tokens = voices
.get("tokens")
.expect("no tokens in prompt")
.to_dtype(DType::U32)?;
let mask = voices.get("mask").expect("no mask in prompt").clone();
let mut pos = 0;
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let mut all_pcms = vec![];
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
println!("{prompt:?}");
let speaker_idx = turn_idx % 2;
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
let mut generated_tokens = vec![];
loop {
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let is_done = frame.iter().all(|&x| x == 0);
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
print!("\rframe {pos}");
if is_done {
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
break;
}
generated_tokens.push(tokens.clone());
}
println!();
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
let pcm = mimi_model.decode(&generated_tokens)?;
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
all_pcms.push(pcm);
}
let pcm = Tensor::cat(&all_pcms, 0)?;
let pcm = pcm.to_vec1::<f32>()?;
println!("writing output file {}", args.out_file);
let mut output = std::fs::File::create(args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

Binary file not shown.


@ -68,7 +68,7 @@ impl CustomOp1 for LayerNorm {
Some((o1, o2)) => slice.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
let func =
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let cfg = LaunchConfig {


@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
@ -20,3 +20,25 @@ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> Tensor[[1, 7, 768], f32]
```
## Masked Token
DistilBert is used to compute the top K choices for a masked token.
```bash
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
> Input: The capital of France is [MASK].
> Predictions for [MASK] at position 6:
> 1: marseille (probability: 12.14%)
> 2: paris (probability: 10.84%)
> 3: toulouse (probability: 8.57%)
> 4: lyon (probability: 7.61%)
> 5: montpellier (probability: 5.18%)
> 6: bordeaux (probability: 4.88%)
> 7: nantes (probability: 4.82%)
> 8: lille (probability: 4.07%)
> 9: strasbourg (probability: 3.12%)
> 10: cannes (probability: 3.04%)
```
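For each `[MASK]` position, the example softmaxes the logits over the vocabulary and prints the `--top-k` (default 5) highest-probability tokens; see `process_masked_output` in the example code below.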


@ -3,15 +3,48 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use candle_transformers::models::distilbert::{
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
};
use anyhow::{Error as E, Result};
use anyhow::{Context, Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;
use tokenizers::Tokenizer;
enum ModelType {
Masked(DistilBertForMaskedLM),
UnMasked(DistilBertModel),
}
impl ModelType {
fn device(&self) -> &Device {
match self {
ModelType::Masked(model) => &model.bert.device,
ModelType::UnMasked(model) => &model.device,
}
}
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
match self {
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "distilbert")]
DistilBert,
#[value(name = "distilbertformaskedlm")]
DistilbertForMaskedLM,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -23,10 +56,14 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "distilbert")]
model: Which,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
/// Revision or branch
#[arg(long)]
revision: Option<String>,
@ -42,94 +79,246 @@ struct Args {
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// Number of top predictions to show for each mask
#[arg(long, default_value = "5")]
top_k: usize,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let (model_id, revision) = self.resolve_model_and_revision();
let (config_path, tokenizer_path, weights_path) =
self.download_model_files(&model_id, &revision)?;
let config = std::fs::read_to_string(config_path)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
let vb = self.load_variables(&weights_path, &device)?;
let model = self.create_model(&config, vb)?;
Ok((model, tokenizer))
}
fn resolve_model_and_revision(&self) -> (String, String) {
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
match (self.model_id.clone(), self.revision.clone()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(Some(model_id), None) => (model_id, default_revision),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
}
}
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
fn download_model_files(
&self,
model_id: &str,
revision: &str,
) -> Result<(PathBuf, PathBuf, PathBuf)> {
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
let api = Api::new()?;
let api = api.repo(repo);
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
api.get("model.safetensors")?
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
Ok((config, tokenizer, weights))
}
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
if self.use_pth {
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
} else {
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
}
}
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
match self.model {
Which::DistilbertForMaskedLM => {
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
}
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
}
}
}
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
fn main() -> Result<()> {
let args = Args::parse();
let _guard = setup_tracing(&args);
let (model, tokenizer) = args.build_model_and_tokenizer()?;
let device = model.device();
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
let output = model.forward(&token_ids, &mask)?;
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
Ok(())
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
fn setup_tracing(args: &Args) -> Option<impl Drop> {
if args.tracing {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
}
}
let tokenizer = tokenizer
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
let mut binding = tokenizer.clone();
let tokenizer_configured = binding
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
let tokens = tokenizer_configured
.encode(args.prompt.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
let mask = match args.model {
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
Which::DistilBert => attention_mask(tokens.len(), device)?,
};
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
Ok((token_ids, mask))
}
fn process_output(
model: &ModelType,
output: &Tensor,
token_ids: &Tensor,
tokenizer: &Tokenizer,
args: &Args,
) -> Result<()> {
match model {
ModelType::UnMasked(_) => {
println!("embeddings");
println!("{output}");
}
ModelType::Masked(_) => {
process_masked_output(output, token_ids, tokenizer, args)?;
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
fn process_masked_output(
output: &Tensor,
token_ids: &Tensor,
tokenizer: &Tokenizer,
args: &Args,
) -> Result<()> {
let input_ids_vec = token_ids.to_vec2::<u32>()?;
let mask_token_id = tokenizer
.token_to_id("[MASK]")
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
println!("\nInput: {}", args.prompt);
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
if token_id == mask_token_id {
println!("Predictions for [MASK] at position {}:", token_idx);
let pos_logits = output.get(0)?.get(token_idx)?;
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
let values = top_values.to_vec1::<f32>()?;
let indices = top_indices.to_vec1::<u32>()?;
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
println!(
" {}: {:15} (probability: {:.2}%)",
i + 1,
token,
prob * 100.0
);
}
}
}
Ok(())
}
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
let n = tensor.dims().iter().product::<usize>();
let k = std::cmp::min(k, n);
let values = tensor.to_vec1::<f32>()?;
let mut value_indices: Vec<(f32, usize)> = values
.into_iter()
.enumerate()
.map(|(idx, val)| (val, idx))
.collect();
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
let top_k_indices: Vec<u32> = value_indices
.iter()
.take(k)
.map(|(_, idx)| *idx as u32)
.collect();
let device = tensor.device();
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
Ok((top_values, top_indices))
}
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Ok(Tensor::from_slice(&mask, (size, size), device)?)
}
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
let seq_len = tokens.get_attention_mask().to_vec().len();
let mask_token_id = tokenizer
.token_to_id("[MASK]")
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
let ids = tokens.get_ids();
for _ in 0..seq_len {
for id in ids.iter() {
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
attention_mask_vec.push(mask_value);
}
}
let shape = (1, 1, seq_len, seq_len);
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
Ok(mask)
}


@ -46,7 +46,7 @@ impl TextGeneration {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(None, None) => Sampling::GumbelSoftmax { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
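
The default sampling here moves from `Sampling::All` to `Sampling::GumbelSoftmax`. The underlying Gumbel-max trick: adding independent Gumbel(0, 1) noise to temperature-scaled logits and taking the argmax draws exact samples from the softmax distribution. A minimal dependency-free sketch (toy PRNG for illustration only; candle's implementation differs):

```rust
// Toy PRNG returning uniform f64 in [0, 1); illustration only.
fn lcg(state: &mut u64) -> f64 {
    *state = state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    ((*state >> 11) as f64) / ((1u64 << 53) as f64)
}

// Gumbel-max sampling: argmax(logits / t + g) with g ~ Gumbel(0, 1)
// is distributed as Categorical(softmax(logits / t)).
fn gumbel_argmax(logits: &[f64], temperature: f64, state: &mut u64) -> usize {
    let mut best = (0usize, f64::NEG_INFINITY);
    for (i, &l) in logits.iter().enumerate() {
        let u = lcg(state).max(1e-12); // avoid ln(0)
        let g = -(-u.ln()).ln(); // Gumbel(0, 1) noise
        let v = l / temperature + g;
        if v > best.1 {
            best = (i, v);
        }
    }
    best.0
}

fn main() {
    let logits = [2.0, 1.0, 0.1];
    let mut state = 42u64;
    let mut counts = [0usize; 3];
    for _ in 0..10_000 {
        counts[gumbel_argmax(&logits, 1.0, &mut state)] += 1;
    }
    // Counts should roughly follow softmax(logits) ~ [0.66, 0.24, 0.10].
    println!("{counts:?}");
}
```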


@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let tokenizer = common_args.tokenizer()?;
let device = candle_examples::device(common_args.cpu)?;
#[cfg(feature = "cuda")]
if let candle::Device::Cuda(d) = &device {
unsafe {
d.disable_event_tracking();
}
};
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path


@ -21,7 +21,7 @@ impl Config {
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
self.d_model.div_ceil(16)
}
fn d_conv(&self) -> usize {


@ -0,0 +1,14 @@
# Orpheus
Orpheus is a 3B text-to-speech model based on Llama.
- Weights on HuggingFace
[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
```bash
cargo run --example orpheus --features cuda -r
```
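The example accepts `--prompt` and `--voice` (one of tara, leah, jess, leo, dan, mia, zac, zoe), and writes the generated 24kHz audio to `--out-file` (`out.wav` by default).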


@ -0,0 +1,329 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
use tokenizers::Tokenizer;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
const STOP_TOKEN_ID: u32 = 128258;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the tokens for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.6)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "3b-0.1-ft")]
which: Which,
#[arg(long, default_value = "tara")]
voice: Voice,
#[arg(long)]
use_flash_attn: bool,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Voice {
#[value(name = "tara")]
Tara,
#[value(name = "leah")]
Leah,
#[value(name = "jess")]
Jess,
#[value(name = "leo")]
Leo,
#[value(name = "dan")]
Dan,
#[value(name = "mia")]
Mia,
#[value(name = "zac")]
Zac,
#[value(name = "zoe")]
Zoe,
}
impl Voice {
fn as_str(&self) -> &'static str {
match self {
Voice::Tara => "tara",
Voice::Leah => "leah",
Voice::Jess => "jess",
Voice::Leo => "leo",
Voice::Dan => "dan",
Voice::Mia => "mia",
Voice::Zac => "zac",
Voice::Zoe => "zoe",
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3b-0.1-ft")]
ThreeB0_1Ft,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let prompt = args.prompt.clone();
let mut model = Model::load(args)?;
model.run(&prompt)?;
Ok(())
}
struct Model {
model: Llama,
tokenizer: Tokenizer,
logits_processor: candle_transformers::generation::LogitsProcessor,
cache: Cache,
device: Device,
verbose_prompt: bool,
snac: SnacModel,
out_file: String,
voice: Voice,
}
fn load_snac(device: &Device) -> Result<SnacModel> {
let api = hf_hub::api::sync::Api::new()?;
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
let config = m.get("config.json")?;
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let m = api.model("lmz/candle-snac".to_string());
let model = m.get("snac_24khz.safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
let model = SnacModel::new(&config, vb)?;
Ok(model)
}
impl Model {
fn load(args: Args) -> Result<Self> {
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::ThreeB0_1Ft => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let config = config.into_config(args.use_flash_attn);
let model = Llama::load(vb, &config)?;
let logits_processor = {
use candle_transformers::generation::{LogitsProcessor, Sampling};
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k.as_ref(), args.top_p.as_ref()) {
(None, None) => Sampling::All { temperature },
(Some(&k), None) => Sampling::TopK { k, temperature },
(None, Some(&p)) => Sampling::TopP { p, temperature },
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
println!("loaded the model in {:?}", start.elapsed());
let cache = Cache::new(true, dtype, &config, &device)?;
let snac = load_snac(&device)?;
Ok(Self {
model,
tokenizer,
logits_processor,
cache,
device,
verbose_prompt: args.verbose_prompt,
snac,
voice: args.voice,
out_file: args.out_file,
})
}
fn run(&mut self, prompt: &str) -> Result<()> {
println!("running the model on '{}'", prompt);
let device = &self.device;
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
let mut tokens = [
&[128259],
tokens.get_ids(),
&[128009, 128260, 128261, 128257],
]
.concat();
if self.verbose_prompt {
println!("{:?}", tokens);
}
let mut cache = self.cache.clone();
println!("starting the inference loop");
let mut index_pos = 0;
let mut audio_tokens = vec![];
for index in 0..2000 {
let (context_size, context_index) = if index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = self.logits_processor.sample(&logits)?;
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
match tok.strip_prefix("<custom_token_") {
Some(tok) => match tok.strip_suffix('>') {
Some(tok) => {
let tok = tok.parse::<u32>()?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
audio_tokens.push(tok);
}
None => {
println!("{index}: unexpected custom token {next_token} {tok}");
}
},
None => {
println!("{index}: unexpected token {next_token} {tok}");
}
}
}
if next_token == STOP_TOKEN_ID {
println!("reached stop token");
break;
}
tokens.push(next_token);
}
println!("generated {} audio tokens", audio_tokens.len());
let mut codes0 = vec![];
let mut codes1 = vec![];
let mut codes2 = vec![];
for audio_tokens in audio_tokens.chunks_exact(7) {
codes0.push(audio_tokens[0]);
for i in [1, 4] {
codes1.push(audio_tokens[i]);
}
for i in [2, 3, 5, 6] {
codes2.push(audio_tokens[i]);
}
}
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
println!("decoded to pcm {pcm:?}");
let mut output = std::fs::File::create(&self.out_file)?;
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
Ok(())
}
}


@ -0,0 +1,275 @@
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler =
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}
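A minimal usage sketch tying the two helpers above together (editor example; `load_pcm` is a hypothetical caller, not part of this file):

// Decode any supported audio file and convert it to `target_sr` if needed.
fn load_pcm(path: &str, target_sr: u32) -> Result<Vec<f32>> {
    let (pcm, sr) = pcm_decode(path)?;
    if sr == target_sr {
        Ok(pcm)
    } else {
        resample(&pcm, sr, target_sr)
    }
}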

View File

@ -0,0 +1,197 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::snac::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "24khz")]
S24khz,
#[value(name = "32khz")]
S32khz,
#[value(name = "44khz")]
S44khz,
}
impl Which {
fn sample_rate(&self) -> u32 {
match self {
Which::S24khz => 24000,
Which::S32khz => 32000,
Which::S44khz => 44000,
}
}
fn config_repo(&self) -> &'static str {
match self {
Which::S24khz => "hubertsiuzdak/snac_24khz",
Which::S32khz => "hubertsiuzdak/snac_32khz",
Which::S44khz => "hubertsiuzdak/snac_44khz",
}
}
fn model_file(&self) -> &'static str {
match self {
Which::S24khz => "snac_24khz.safetensors",
Which::S32khz => "snac_32khz.safetensors",
Which::S44khz => "snac_44khz.safetensors",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some snac tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
out_file: String,
/// The model size to use.
#[arg(long, default_value = "24khz")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
/// The config file, in safetensor format.
#[arg(long)]
config: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model_sample_rate = args.which.sample_rate();
let config = match args.config {
Some(c) => std::path::PathBuf::from(c),
None => Api::new()?
.model(args.which.config_repo().to_string())
.get("config.json")?,
};
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("lmz/candle-snac".to_string())
.get(args.which.model_file())?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = Model::new(&config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let num_codebooks = model.num_codebooks();
(0..num_codebooks)
.map(|i| {
codes
.get(&format!("codes-{i}"))
.expect("no codes in input file")
.clone()
})
.collect::<Vec<_>>()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != model_sample_rate {
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
} else {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
for codes in codes.iter() {
println!("codes shape: {:?}", codes.shape());
}
match args.action {
Action::AudioToCode => {
let mut tensors = std::collections::HashMap::new();
for (i, codes) in codes.iter().enumerate() {
tensors.insert(format!("codes-{i}"), codes.clone());
}
candle::safetensors::save(&tensors, "codes.safetensors")?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let codes = codes.iter().collect::<Vec<_>>();
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
}
}
}
Ok(())
}

View File

@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
padding,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?

View File

@ -92,6 +92,7 @@ impl ConvBlock {
stride,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.1"
version = "0.9.0-alpha.4"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,14 +11,17 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] }
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }
[features]
default = []
cudnn = ["candle/cudnn"]

View File

@ -2,7 +2,6 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
@ -142,10 +141,8 @@ impl FlashAttn {
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let dst = unsafe { dev.alloc::<T>(elem_count)? };
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@ -607,8 +604,8 @@ impl FlashAttnVarLen {
let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
let is_bf16 = if is_bf16 { 1 } else { 0 };

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.1"
version = "0.9.0-alpha.4"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -53,7 +53,7 @@ __device__ void conv1d(
template <typename T>
__device__ void im2col1d(
const size_t dst_numel,
const size_t numel,
const size_t l_out,
const size_t l_k,
const size_t stride,
@ -63,10 +63,10 @@ __device__ void im2col1d(
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (dst_i >= dst_numel) {
if (thread_i >= numel) {
return;
}
const size_t *src_dims = info;
@ -74,26 +74,26 @@ __device__ void im2col1d(
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s1 = c_in;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = dst_i;
size_t tmp_dst_i = thread_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
const size_t c_idx = tmp_dst_i;
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
size_t dst_i = thread_i * l_k + l_k_idx;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
}
}
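The rewritten kernel assigns one thread per (batch, output position, channel) triple and loops over the kernel taps, so each thread writes `l_k` contiguous destination slots; `numel` is therefore `b_size * l_out * c_in` rather than the full destination element count. A plain-Rust restatement of the same indexing (editor sketch, assuming row-major shapes src (b, c_in, l_in) and dst (b, l_out, c_in, l_k)):

fn im2col1d_ref(
    src: &[f32], b: usize, c_in: usize, l_in: usize,
    l_out: usize, l_k: usize, stride: usize, padding: usize, dilation: usize,
) -> Vec<f32> {
    let mut dst = vec![0f32; b * l_out * c_in * l_k];
    for thread_i in 0..b * l_out * c_in {
        // Same decomposition as the kernel: thread_i -> (b_idx, l_idx, c_idx).
        let b_idx = thread_i / (l_out * c_in);
        let l_idx = (thread_i / c_in) % l_out;
        let c_idx = thread_i % c_in;
        for l_k_idx in 0..l_k {
            let src_l = l_idx * stride + l_k_idx * dilation;
            let dst_i = thread_i * l_k + l_k_idx;
            if src_l >= padding && src_l < l_in + padding {
                dst[dst_i] = src[(b_idx * c_in + c_idx) * l_in + (src_l - padding)];
            } // else: stays zero, matching the padded write in the kernel
        }
    }
    dst
}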

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.1"
version = "0.9.0-alpha.4"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -33,6 +33,7 @@ criterion = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
cudnn = ["candle/cudnn"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]

View File

@ -1,6 +1,6 @@
//! Convolution Layers.
use crate::BatchNorm;
use candle::{Result, Tensor};
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv1dConfig {
@ -8,6 +8,7 @@ pub struct Conv1dConfig {
pub stride: usize,
pub dilation: usize,
pub groups: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl Default for Conv1dConfig {
@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
}
}
}
@ -52,12 +54,13 @@ impl Conv1d {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(
let x = x.conv1d_with_algo(
&self.weight,
self.config.padding,
self.config.stride,
self.config.dilation,
self.config.groups,
self.config.cudnn_fwd_algo,
)?;
match &self.bias {
None => Ok(x),
@ -147,6 +150,7 @@ pub struct Conv2dConfig {
pub stride: usize,
pub dilation: usize,
pub groups: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl Default for Conv2dConfig {
@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
}
}
}
@ -211,12 +216,13 @@ impl Conv2d {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(
let x = x.conv2d_with_algo(
&self.weight,
self.config.padding,
self.config.stride,
self.config.dilation,
self.config.groups,
self.config.cudnn_fwd_algo,
)?;
match &self.bias {
None => Ok(x),

View File

@ -294,6 +294,27 @@ impl RotatingCache {
Tensor::from_slice(&mask, (size1, size2), device)
}
/// Returns the positions corresponding to all the elements that will be returned
/// *after* adding `seq_len` to the cache.
pub fn positions(&self, seq_len: usize) -> Vec<usize> {
if seq_len <= self.max_seq_len {
let upd_offset = (self.offset + seq_len) % self.max_seq_len;
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
(0..cache_out_len)
.map(|i| {
let pos_cache = self.current_seq_len + seq_len + i - upd_offset;
if i < upd_offset {
pos_cache
} else {
pos_cache - self.max_seq_len
}
})
.collect()
} else {
(self.current_seq_len..(self.current_seq_len + seq_len)).collect()
}
}
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
let mask = if seq_len == 1 {
@ -362,10 +383,17 @@ impl RotatingKvCache {
self.k.current_seq_len()
}
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
self.k.attn_mask(seq_len, device)
}
/// Returns the positions corresponding to all the elements that will be returned
/// *after* adding `seq_len` to the cache.
pub fn positions(&self, seq_len: usize) -> Vec<usize> {
self.k.positions(seq_len)
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
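To make the `positions` arithmetic concrete, a worked example (editor note, numbers chosen for illustration): with `max_seq_len = 4` and 6 tokens already appended, so `offset = 6 % 4 = 2`, querying `positions(1)` gives:

// upd_offset    = (2 + 1) % 4   = 3
// cache_out_len = min(6 + 1, 4) = 4
// i = 0, 1, 2 -> pos = 6 + 1 + i - 3 = 4, 5, 6   (i < upd_offset)
// i = 3       -> pos = 7 - 4         = 3         (wrapped slot)
// positions(1) == [4, 5, 6, 3]
// Cross-check: token t lives in slot t % 4, so slots 0..2 hold tokens
// 4, 5, 6 while slot 3 still holds the not-yet-overwritten token 3.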

View File

@ -31,6 +31,7 @@ pub mod ops;
pub mod optim;
pub mod rnn;
pub mod rotary_emb;
pub mod sampling;
pub mod sequential;
pub mod var_builder;
pub mod var_map;

View File

@ -41,12 +41,36 @@ impl Linear {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match *x.dims() {
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
// When possible, we avoid using a broadcasted matmul as it is much slower
// than the standard matmul for the cuda and cpu backends.
let x = match *x.dims() {
[b1, b2, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((b1 * b2 * m, k))?
.matmul(&w)?
.reshape((b1, b2, m, ()))?
} else {
let w = self.weight.broadcast_left((b1, b2))?.t()?;
x.matmul(&w)?
}
}
[bsize, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((bsize * m, k))?
.matmul(&w)?
.reshape((bsize, m, ()))?
} else {
let w = self.weight.broadcast_left(bsize)?.t()?;
x.matmul(&w)?
}
}
_ => {
let w = self.weight.t()?;
x.matmul(&w)?
}
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
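The fast path relies on a shape identity: for a contiguous input, the broadcast batched product equals a single 2-D GEMM over the flattened batch dims. A quick check (editor sketch, arbitrary sizes):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1., (2, 3, 4), &dev)?; // (b, m, k)
    let w = Tensor::randn(0f32, 1., (5, 4), &dev)?;    // (n, k), as in Linear
    let y1 = x.broadcast_matmul(&w.t()?)?;             // batched: (b, m, n)
    let y2 = x.reshape((6, 4))?.matmul(&w.t()?)?.reshape((2, 3, 5))?;
    let diff = (y1 - y2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-4); // identical up to float rounding
    Ok(())
}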

View File

@ -7,7 +7,7 @@ use candle::{Result, Tensor};
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain log probabilities.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
/// of categories.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {

View File

@ -112,7 +112,7 @@ impl candle::CustomOp1 for Sigmoid {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let out = unsafe { dev.alloc::<T>(el_count)? };
let mut builder = func.builder();
candle::builder_arg!(builder, el_count, dims.len());
@ -373,7 +373,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
@ -561,7 +561,7 @@ impl candle::CustomOp2 for RmsNorm {
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
@ -800,7 +800,7 @@ impl candle::CustomOp3 for LayerNorm {
let func =
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);

View File

@ -119,7 +119,7 @@ impl candle::CustomOp3 for RotaryEmbI {
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
@ -369,7 +369,7 @@ impl candle::CustomOp3 for RotaryEmb {
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
@ -620,7 +620,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);

candle-nn/src/sampling.rs (new file, 20 lines)
View File

@ -0,0 +1,20 @@
use candle::{Result, Tensor};
/// Sample according to the Gumbel-Softmax distribution.
pub fn gumbel_softmax<D: candle::shape::Dim>(
logits: &Tensor,
temperature: f64,
dim: D,
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
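This implements the Gumbel-max trick: if u ~ U(0, 1) then g = -log(-log u) follows a Gumbel distribution, and argmax_i (logits_i + g_i) is an exact sample from softmax(logits). Since argmax is invariant to positive scaling, argmax(logits + t * g) = argmax(logits / t + g), which is exactly sampling at temperature t. The code stores `minus_g = log(-log u) = -g`, hence the subtraction in the t = 1 branch. A quick empirical check (editor sketch, not a test from this PR):

fn main() -> candle::Result<()> {
    use candle::{Device, Tensor};
    let logits = Tensor::new(&[0f32, 1., 2.], &Device::Cpu)?;
    let mut counts = [0usize; 3];
    for _ in 0..10_000 {
        let idx = candle_nn::sampling::gumbel_softmax(&logits, 1.0, 0)?
            .to_vec0::<u32>()? as usize;
        counts[idx] += 1;
    }
    // Expect roughly softmax([0, 1, 2]) ~ [0.09, 0.245, 0.665] of the draws.
    println!("{counts:?}");
    Ok(())
}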

View File

@ -39,9 +39,16 @@ fn rotating_kv_cache() -> Result<()> {
assert_eq!(cache.current_seq_len(), 0);
let data = cache.current_data()?;
assert!(data.is_none());
assert_eq!(cache.positions(1), &[0]);
assert_eq!(cache.positions(2), &[0, 1]);
let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
assert_eq!(cache.positions(0), &[0, 1, 2]);
assert_eq!(cache.positions(1), &[0, 1, 2, 3]);
assert_eq!(cache.positions(2), &[0, 1, 2, 3, 4]);
assert_eq!(cache.positions(3), &[0, 1, 2, 3, 4, 5]);
assert_eq!(cache.positions(4), &[6, 1, 2, 3, 4, 5]);
let t = Tensor::new(&[4.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
@ -79,11 +86,17 @@ fn rotating_kv_cache() -> Result<()> {
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
);
assert_eq!(cache.positions(0), &[12, 7, 8, 9, 10, 11]);
assert_eq!(cache.positions(2), &[12, 13, 14, 9, 10, 11]);
assert_eq!(cache.positions(3), &[12, 13, 14, 15, 10, 11]);
assert_eq!(cache.positions(8), &[13, 14, 15, 16, 17, 18, 19, 20]);
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 22);
assert_eq!(cache.offset(), 0);
assert_eq!(cache.positions(0), &[16, 17, 18, 19, 20, 21]);
assert_eq!(cache.positions(1), &[22, 17, 18, 19, 20, 21]);
let mask = cache.attn_mask(1, &Device::Cpu)?;
assert!(mask.is_none());

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0-alpha.1"
version = "0.9.0-alpha.4"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
prost = "0.12.1"
[build-dependencies]

View File

@ -29,6 +29,7 @@ tracing = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
cudnn = ["candle/cudnn", "candle-nn/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
metal = ["candle/metal", "candle-nn/metal"]

View File

@ -13,6 +13,8 @@ pub enum Sampling {
TopK { k: usize, temperature: f64 },
TopP { p: f64, temperature: f64 },
TopKThenTopP { k: usize, p: f64, temperature: f64 },
// Note that the rng is not used for the Gumbel-Softmax sampling.
GumbelSoftmax { temperature: f64 },
}
pub struct LogitsProcessor {
@ -49,6 +51,11 @@ impl LogitsProcessor {
Ok(next_token)
}
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
sampled.to_vec0::<u32>()
}
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
let next_token = distr.sample(&mut self.rng) as u32;
@ -127,6 +134,9 @@ impl LogitsProcessor {
let next_token = match &self.sampling {
Sampling::ArgMax => self.sample_argmax(logits)?,
Sampling::GumbelSoftmax { temperature } => {
self.sample_gumbel_softmax(&logits, *temperature)?
}
Sampling::All { temperature } => {
let prs = prs(*temperature)?;
self.sample_multinomial(&prs)?
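Hypothetical usage of the new variant (editor sketch; mirrors how the other sampling modes are driven):

let mut lp = LogitsProcessor::from_sampling(
    42, // seed; unused for Gumbel-Softmax, which draws noise via rand_like
    Sampling::GumbelSoftmax { temperature: 0.8 },
);
let next_token = lp.sample(&logits)?;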

View File

@ -504,8 +504,9 @@ impl BertModel {
Some(attention_mask) => attention_mask.clone(),
None => input_ids.ones_like()?,
};
let dtype = embedding_output.dtype();
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
Ok(sequence_output)
}
@ -519,8 +520,11 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
&Tensor::try_from(f32::MIN)?
.to_device(attention_mask.device())?
.to_dtype(dtype)?,
)
}
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
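The dtype threading matters for f16/bf16 checkpoints: candle's binary ops require matching dtypes, so a hardcoded f32 bias tensor fails against f16 hidden states, and converting `f32::MIN` into the model dtype saturates toward that dtype's most negative value (or -inf), which is exactly what the pre-softmax mask needs (editor note). Recap of the fixed construction:

// (1 - mask) * min_value, everything in the model dtype:
let min_value = Tensor::try_from(f32::MIN)?
    .to_device(attention_mask.device())?
    .to_dtype(dtype)?;
let bias = (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(&min_value)?;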

View File

@ -514,8 +514,9 @@ impl ChineseClipTextTransformer {
Some(attention_mask) => attention_mask.clone(),
None => input_ids.ones_like()?,
};
let dtype = embedding_output.dtype();
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
let encoder_output = encoder_outputs.i((.., 0, ..))?;
let pooled_output = match &self.pooler {
@ -535,6 +536,9 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
&Tensor::try_from(f32::MIN)?
.to_device(attention_mask.device())?
.to_dtype(dtype)?,
)
}

View File

@ -0,0 +1,533 @@
//! Implementation of the Conversational Speech Model (CSM) from Sesame
//!
//! See: CSM (Conversational Speech Model)
//!
/// CSM (Conversational Speech Model) is a speech generation model from Sesame that generates RVQ
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
/// smaller audio decoder that produces Mimi audio codes.
///
use crate::generation::LogitsProcessor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
use std::sync::Arc;
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flavor {
#[serde(rename = "llama-1B")]
Llama1B,
#[serde(rename = "llama-100M")]
Llama100M,
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub audio_num_codebooks: usize,
pub audio_vocab_size: usize,
pub backbone_flavor: Flavor,
pub decoder_flavor: Flavor,
pub text_vocab_size: usize,
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct LlamaConfig {
vocab_size: usize,
num_layers: usize,
num_heads: usize,
num_kv_heads: usize,
embed_dim: usize,
max_seq_len: usize,
intermediate_dim: usize,
norm_eps: f64,
rope_base: f32,
scale_factor: usize,
}
impl LlamaConfig {
pub fn from_flavor(flavor: Flavor) -> Self {
match flavor {
Flavor::Llama1B => Self {
vocab_size: 128256,
num_layers: 16,
num_heads: 32,
num_kv_heads: 8,
embed_dim: 2048,
max_seq_len: 2048,
intermediate_dim: 8192,
norm_eps: 1e-5,
rope_base: 500_000.,
scale_factor: 32,
},
Flavor::Llama100M => Self {
vocab_size: 128256,
num_layers: 4,
num_heads: 8,
num_kv_heads: 2,
embed_dim: 1024,
max_seq_len: 2048,
intermediate_dim: 8192,
norm_eps: 1e-5,
rope_base: 500_000.,
scale_factor: 32,
},
}
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
let head_dim = cfg.embed_dim / cfg.num_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
.collect()
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
let low_freq_factor = 1.0;
let high_freq_factor = 4.0;
let original_max_position_embeddings = 8192;
let scale_factor = cfg.scale_factor as f32;
let theta = {
let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
calculate_default_inv_freq(cfg)
.into_iter()
.map(|freq| {
let wavelen = 2. * std::f32::consts::PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / scale_factor
} else {
let smooth = (original_max_position_embeddings as f32 / wavelen
- low_freq_factor)
/ (high_freq_factor - low_freq_factor);
(1. - smooth) * freq / scale_factor + smooth * freq
}
})
.collect::<Vec<_>>()
};
let theta = Tensor::new(theta, dev)?;
let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self { cos, sin })
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
let weight = vb.get((hidden_size,), "scale")?;
Ok(RmsNorm::new(weight, eps))
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
num_heads: usize,
head_dim: usize,
num_kv_heads: usize,
num_kv_groups: usize,
}
impl Attention {
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
let head_dim = cfg.embed_dim / cfg.num_heads;
let kv_dim = cfg.num_kv_heads * head_dim;
let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
rotary_emb,
kv_cache: None,
num_heads: cfg.num_heads,
num_kv_heads: cfg.num_kv_heads,
num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
head_dim,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.num_heads * self.head_dim))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct Mlp {
w1: Linear,
w2: Linear,
w3: Linear,
}
impl Mlp {
fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
let w1 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w1"))?;
let w2 = linear_b(cfg.intermediate_dim, cfg.embed_dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w3"))?;
Ok(Self { w1, w2, w3 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.w1)?.silu()?;
let rhs = xs.apply(&self.w3)?;
(lhs * rhs)?.apply(&self.w2)
}
}
#[derive(Debug, Clone)]
struct Layer {
mlp_norm: RmsNorm,
sa_norm: RmsNorm,
attn: Attention,
mlp: Mlp,
}
impl Layer {
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
let mlp_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("mlp_norm"))?;
let sa_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("sa_norm"))?;
let attn = Attention::new(cfg, rotary_emb, vb.pp("attn"))?;
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
Ok(Self {
mlp_norm,
sa_norm,
attn,
mlp,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.sa_norm.forward(xs)?;
let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct LlamaModel {
layers: Vec<Layer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl LlamaModel {
pub fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
let mut layers = Vec::with_capacity(cfg.num_layers);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.num_layers {
let layer = Layer::new(cfg, rotary_emb.clone(), vb_l.pp(layer_idx))?;
layers.push(layer);
}
let norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("norm"))?;
Ok(Self {
layers,
norm,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
fn prepare_decoder_attention_mask(
&self,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len, _embed_dim) = xs.dims3()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = xs.clone();
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
}
let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
Ok(ys)
}
}
#[derive(Debug, Clone)]
pub struct Model {
backbone: LlamaModel,
decoder: LlamaModel,
codebook0_head: Linear,
audio_embeddings: Embedding,
text_embeddings: Embedding,
projection: Linear,
audio_head: Tensor,
config: Config,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let backbone_cfg = LlamaConfig::from_flavor(cfg.backbone_flavor);
let backbone = LlamaModel::new(&backbone_cfg, vb.pp("backbone"))?;
let decoder_cfg = LlamaConfig::from_flavor(cfg.decoder_flavor);
let decoder = LlamaModel::new(&decoder_cfg, vb.pp("decoder"))?;
let backbone_dim = backbone_cfg.embed_dim;
let decoder_dim = decoder_cfg.embed_dim;
let audio_embeddings = embedding(
cfg.audio_vocab_size * cfg.audio_num_codebooks,
backbone_dim,
vb.pp("audio_embeddings"),
)?;
let text_embeddings =
embedding(cfg.text_vocab_size, backbone_dim, vb.pp("text_embeddings"))?;
let projection = linear_b(backbone_dim, decoder_dim, false, vb.pp("projection"))?;
let codebook0_head = linear_b(
backbone_dim,
cfg.audio_vocab_size,
false,
vb.pp("codebook0_head"),
)?;
let audio_head = vb.get(
(
cfg.audio_num_codebooks - 1,
decoder_dim,
cfg.audio_vocab_size,
),
"audio_head",
)?;
Ok(Self {
backbone,
decoder,
codebook0_head,
audio_embeddings,
text_embeddings,
projection,
audio_head,
config: cfg.clone(),
})
}
pub fn clear_kv_cache(&mut self) {
self.backbone.clear_kv_cache();
self.decoder.clear_kv_cache();
}
pub fn generate_frame(
&mut self,
tokens: &Tensor,
tokens_mask: &Tensor,
input_pos: usize,
lp: &mut LogitsProcessor,
) -> Result<Vec<u32>> {
let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
let text_embeds = self.text_embeddings.forward(&text_tokens)?;
let arange = (Tensor::arange(
0u32,
self.config.audio_num_codebooks as u32,
&self.decoder.device,
)? * self.config.audio_vocab_size as f64)?;
let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
b_sz,
seq_len,
self.config.audio_num_codebooks,
(),
))?;
let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
let embeds = embeds.broadcast_mul(
&tokens_mask
.to_dtype(self.backbone.dtype)?
.unsqueeze(D::Minus1)?,
)?;
let embeds = embeds.sum(2)?;
let h = self.backbone.forward(&embeds, input_pos)?;
let c0_logits = h.apply(&self.codebook0_head)?;
let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
let mut all_samples = vec![c0_sample];
let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
self.decoder.clear_kv_cache();
let mut decoder_pos = 0;
for i in 1..self.config.audio_num_codebooks {
let proj_h = curr_h.apply(&self.projection)?;
let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
decoder_pos += curr_h.dim(1)?;
let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
all_samples.push(ci_sample);
let ci_sample = Tensor::from_slice(
&[ci_sample + (i * self.config.audio_vocab_size) as u32],
(1, 1),
&self.decoder.device,
)?;
let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
curr_h = ci_embed
}
Ok(all_samples)
}
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
let cb = self.config.audio_num_codebooks;
let device = &self.backbone.device;
let mut mask = vec![1u8; cb];
mask.push(0);
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
frame.push(0);
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
Ok((tokens, mask))
}
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
let cb = self.config.audio_num_codebooks;
let device = &self.backbone.device;
let mut tokens = vec![];
let mut mask = vec![];
for &v in ids.iter() {
let mut token = vec![0; cb];
token.push(v);
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
tokens.push(token);
let mut m = vec![0u8; cb];
m.push(1);
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
mask.push(m);
}
let tokens = Tensor::cat(&tokens, 1)?;
let mask = Tensor::cat(&mask, 1)?;
Ok((tokens, mask))
}
}
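A hypothetical driver loop for `generate_frame` (editor sketch; the variable names and the all-zero stop condition are assumptions, not taken from this diff):

// `model`, `lp` (a LogitsProcessor) and `text_ids` are assumed to exist.
let (mut tokens, mut mask) = model.text_tokens_and_mask(&text_ids)?;
let mut pos = 0;
let mut frames: Vec<Vec<u32>> = vec![];
loop {
    let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
    pos += tokens.dim(1)?;
    if frame.iter().all(|&c| c == 0) {
        break; // assumed end-of-audio marker
    }
    (tokens, mask) = model.audio_tokens_and_mask(frame.clone())?;
    frames.push(frame);
}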

View File

@ -104,7 +104,7 @@ impl EncoderBlock {
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
let cfg1 = Conv1dConfig {
stride,
padding: (stride + 1) / 2,
padding: stride.div_ceil(2),
..Default::default()
};
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
@ -196,7 +196,7 @@ impl DecoderBlock {
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
let cfg = ConvTranspose1dConfig {
stride,
padding: (stride + 1) / 2,
padding: stride.div_ceil(2),
..Default::default()
};
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
@ -330,6 +330,7 @@ impl ResidualVectorQuantizer {
Ok(Self { quantizers })
}
#[allow(clippy::wrong_self_convention)]
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
let mut sum = None;
for (idx, quantizer) in self.quantizers.iter().enumerate() {

View File

@ -124,6 +124,7 @@ impl ResidualConvUnit {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let conv1 = conv2d(
conf.num_features,
@ -208,6 +209,7 @@ impl FeatureFusionBlock {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let output_conv = conv2d(
conf.num_features,
@ -258,6 +260,7 @@ impl Scratch {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let layer1_rn = conv2d_no_bias(
@ -319,6 +322,7 @@ impl Scratch {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let output_conv1 = conv2d(
conf.num_features,
@ -425,6 +429,7 @@ impl DPTHead {
stride: 2,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
},
vb.pp("resize_layers").pp("3"),
)?),

View File

@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
pub enum HiddenAct {
Gelu,
Relu,
}
@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
pub enum PositionEmbeddingType {
#[default]
Absolute,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
dim: usize,
pub vocab_size: usize,
pub dim: usize,
n_layers: usize,
n_heads: usize,
hidden_dim: usize,
activation: HiddenAct,
max_position_embeddings: usize,
initializer_range: f64,
pad_token_id: usize,
pub pad_token_id: usize,
#[serde(default)]
position_embedding_type: PositionEmbeddingType,
#[serde(default)]
@ -345,3 +345,107 @@ impl DistilBertModel {
Ok(sequence_output)
}
}
struct DistilBertPredictionHeadTransform {
dense: Linear,
activation: HiddenActLayer,
layer_norm: LayerNorm,
}
impl DistilBertPredictionHeadTransform {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
let activation = HiddenActLayer::new(config.activation);
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
Ok(Self {
dense,
activation,
layer_norm,
})
}
}
impl Module for DistilBertPredictionHeadTransform {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self
.activation
.forward(&self.dense.forward(hidden_states)?)?;
self.layer_norm.forward(&hidden_states)
}
}
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
pub struct DistilBertLMPredictionHead {
transform: DistilBertPredictionHeadTransform,
decoder: Linear,
}
impl DistilBertLMPredictionHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
// distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a separate vocab_projector bias
let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vocab_projector_weight_vb.get_with_hints(
(config.vocab_size, config.dim),
"weight",
init_ws,
)?;
let bound = 1. / (config.dim as f64).sqrt();
let init_bs = candle_nn::Init::Uniform {
lo: -bound,
up: bound,
};
let vocab_projector_bias_vb = vb.pp("vocab_projector");
let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
let decoder = Linear::from_weights(ws, Some(bs));
Ok(Self { transform, decoder })
}
}
impl Module for DistilBertLMPredictionHead {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
self.decoder
.forward(&self.transform.forward(hidden_states)?)
}
}
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
pub struct DistilBertOnlyMLMHead {
predictions: DistilBertLMPredictionHead,
}
impl DistilBertOnlyMLMHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
Ok(Self { predictions })
}
}
impl Module for DistilBertOnlyMLMHead {
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
self.predictions.forward(sequence_output)
}
}
pub struct DistilBertForMaskedLM {
pub bert: DistilBertModel,
cls: DistilBertOnlyMLMHead,
}
impl DistilBertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
Ok(Self { bert, cls })
}
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let sequence_output = self.bert.forward(input_ids, attention_mask)?;
self.cls.forward(&sequence_output)
}
}
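Hypothetical usage (editor sketch; `vb`, `config`, `input_ids`, `attention_mask` and `mask_index` are assumed to be prepared as in the distilbert example):

let model = DistilBertForMaskedLM::load(vb, &config)?;
let logits = model.forward(&input_ids, &attention_mask)?; // (batch, seq, vocab)
let mask_logits = logits.i((0, mask_index))?;             // logits at the [MASK] slot
let predicted_id = mask_logits.argmax(candle::D::Minus1)?.to_vec0::<u32>()?;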

View File

@ -141,6 +141,20 @@ pub fn conv1d_weight_norm(
Ok(Conv1d::new(weight, Some(bias), config))
}
pub fn conv1d_weight_norm_no_bias(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
Ok(Conv1d::new(weight, None, config))
}
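Both weight-norm helpers reconstruct PyTorch's parametrization w = g * v / ||v||, where the norm is an L2 norm per output channel taken over the remaining dims; `weight_g` has shape (out_c, 1, 1) so the multiply and divide broadcast per channel (editor note):

// ||v[o]||_2 over dims (1, 2), then w[o] = g[o] * v[o] / ||v[o]||_2
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;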
pub fn conv_transpose1d_weight_norm(
in_c: usize,
out_c: usize,
@ -454,6 +468,7 @@ impl EncodecConv1d {
stride,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
},
vb.pp("conv"),
)?,

View File

@ -6,8 +6,8 @@ pub fn get_noise(
width: usize,
device: &Device,
) -> Result<Tensor> {
let height = (height + 15) / 16 * 2;
let width = (width + 15) / 16 * 2;
let height = height.div_ceil(16) * 2;
let width = width.div_ceil(16) * 2;
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
}
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
let (b, _h_w, c_ph_pw) = xs.dims3()?;
let height = (height + 15) / 16;
let width = (width + 15) / 16;
let height = height.div_ceil(16);
let width = width.div_ceil(16);
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
.reshape((b, c_ph_pw / 4, height * 2, width * 2))

View File

@ -27,7 +27,7 @@ impl Config {
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
self.d_model.div_ceil(16)
}
fn d_inner(&self) -> usize {

View File

@ -716,7 +716,7 @@ pub mod transformer {
None => {
let hidden_dim = self.dim * 4;
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
(n_hidden + 255) / 256 * 256
n_hidden.div_ceil(256) * 256
}
}
}
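This hunk, like the flux and mamba ones above, is a clippy `manual_div_ceil` rewrite and is value-identical for unsigned integers, since (x + k - 1) / k == x.div_ceil(k). A quick check (editor sketch):

for x in 0usize..10_000 {
    assert_eq!((x + 15) / 16, x.div_ceil(16));
    assert_eq!((x + 255) / 256, x.div_ceil(256));
}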

View File

@ -267,6 +267,7 @@ impl StreamableConv1d {
stride,
dilation,
groups,
cudnn_fwd_algo: None,
};
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
if k_size < stride {

View File

@ -27,6 +27,7 @@ pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod csm;
pub mod dac;
pub mod debertav2;
pub mod deepseek2;
@ -103,6 +104,7 @@ pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod siglip;
pub mod snac;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod starcoder2;

View File

@ -0,0 +1,814 @@
#![allow(unused)]
//! Implementation of the Multi-Scale Neural Audio Codec (SNAC)
//!
//! See: [SNAC](https://github.com/hubertsiuzdak/snac)
//!
/// Multi-Scale Neural Audio Codec (SNAC) compresses audio into discrete codes at a low bitrate.
/// For more information, read the paper: https://arxiv.org/abs/2410.14411
///
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{
linear_b, Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, LayerNorm, Linear,
VarBuilder,
};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub sampling_rate: usize,
pub encoder_dim: usize,
pub encoder_rates: Vec<usize>,
pub decoder_dim: usize,
pub decoder_rates: Vec<usize>,
pub attn_window_size: Option<usize>,
pub codebook_size: usize,
pub codebook_dim: usize,
pub vq_strides: Vec<usize>,
pub noise: bool,
pub depthwise: bool,
}
// Equivalent to torch.repeat_interleave
pub fn repeat_interleave<D: candle::shape::Dim>(
img: &Tensor,
repeats: usize,
dim: D,
) -> Result<Tensor> {
if repeats == 1 {
return Ok(img.clone());
}
let dim = dim.to_index(img.shape(), "chunk")?;
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
img.broadcast_as(dims)?.flatten(dim, dim + 1)
}
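A worked example of the unsqueeze/broadcast/flatten trick (editor sketch):

use candle::{Device, Tensor};
let t = Tensor::new(&[1u32, 2], &Device::Cpu)?;
// [1, 2] -> [[1], [2]] -> [[1, 1], [2, 2]] -> [1, 1, 2, 2]
let r = repeat_interleave(&t, 2, 0)?;
assert_eq!(r.to_vec1::<u32>()?, [1, 1, 2, 2]);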
pub fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
let weight_v = {
let name = "parametrizations.weight.original1";
match vb.get((out_c, in_c, kernel_size), name) {
Ok(v) => v,
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
}
};
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
pub fn conv1d_weight_norm_no_bias(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
let weight_v = {
let name = "parametrizations.weight.original1";
match vb.get((out_c, in_c, kernel_size), name) {
Ok(v) => v,
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
}
};
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
Ok(Conv1d::new(weight, None, config))
}
pub fn conv_transpose1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
bias: bool,
config: candle_nn::ConvTranspose1dConfig,
vb: VarBuilder,
) -> Result<ConvTranspose1d> {
let weight_g = vb.get((in_c, 1, 1), "parametrizations.weight.original0")?;
let weight_v = vb.get(
(in_c, out_c, kernel_size),
"parametrizations.weight.original1",
)?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
Ok(ConvTranspose1d::new(weight, bias, config))
}
// https://github.com/hubertsiuzdak/snac/blob/main/snac/attention.py
#[allow(unused)]
#[derive(Debug, Clone)]
struct SinusoidalEmbeddings {
inv_freq: Tensor,
scale: Tensor,
scale_base: f32,
use_xpos: bool,
}
impl SinusoidalEmbeddings {
fn new(dim: usize, scale_base: f32, use_xpos: bool, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10_000f32.powf(i as f32 / dim as f32))
.collect();
let len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, len, dev)?.to_dtype(DType::F32)?;
let scale: Vec<_> = (0..dim)
.step_by(2)
.map(|i| (i as f32 + 0.4 * dim as f32) / (1.4 * dim as f32))
.collect();
let scale = Tensor::from_vec(scale, len, dev)?.to_dtype(DType::F32)?;
Ok(Self {
inv_freq,
scale,
scale_base,
use_xpos,
})
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
struct LocalMHA {
norm: LayerNorm,
to_qkv: Linear,
to_out: Linear,
num_heads: usize,
head_dim: usize,
rel_pos: Option<SinusoidalEmbeddings>,
}
impl LocalMHA {
fn new(
dim: usize,
window_size: usize,
dim_head: usize,
use_rotary_pos_emb: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
let to_qkv = linear_b(dim, dim * 3, false, vb.pp("to_qkv"))?;
let to_out = linear_b(dim, dim, false, vb.pp("to_out"))?;
let rel_pos = if use_rotary_pos_emb {
let rel_pos =
SinusoidalEmbeddings::new(dim_head, window_size as f32 / 2.0, false, vb.device())?;
Some(rel_pos)
} else {
None
};
Ok(Self {
norm,
to_qkv,
to_out,
rel_pos,
num_heads: dim / dim_head,
head_dim: dim_head,
})
}
}
impl Module for LocalMHA {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, c, t) = xs.dims3()?;
let residual = xs.clone();
let xs = xs.transpose(1, 2)?.apply(&self.norm)?;
let qkv = xs.apply(&self.to_qkv)?;
let q = qkv.narrow(D::Minus1, 0, c)?;
let k = qkv.narrow(D::Minus1, c, c)?;
let v = qkv.narrow(D::Minus1, 2 * c, c)?;
let q = q
.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let (q, k) = match self.rel_pos {
Some(_) => todo!(),
None => (q, k),
};
let out = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
// Non-causal attention
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&v)?
};
let out = out
.transpose(1, 2)?
.reshape((b, t, self.num_heads * self.head_dim))?
.apply(&self.to_out)?;
out.transpose(1, 2)? + residual
}
}
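// Shapes through the block, assuming dim = num_heads * head_dim: the fused
// to_qkv projection emits (b, t, 3 * dim), which `narrow` splits back into
// q, k, v of (b, t, dim) each before the (b, heads, t, head_dim) reshape.
// Attention here is full (non-causal) over the padded segment with the usual
// 1 / sqrt(head_dim) scaling; the rotary-embedding branch is left as todo!().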
#[derive(Debug, Clone)]
struct Snake1d {
alpha: Tensor,
}
impl Snake1d {
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
let alpha = vb.get((1, channels, 1), "alpha")?;
Ok(Self { alpha })
}
}
impl Module for Snake1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs_shape = xs.shape();
let xs = xs.flatten_from(2)?;
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
let sin = (&sin * &sin)?;
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
}
}
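// Snake activation: snake(x) = x + sin^2(alpha * x) / alpha, with a learned
// per-channel alpha (the 1e-9 guards against division by zero). As alpha -> 0
// this tends to the identity; larger alpha adds a higher-frequency periodic
// component, which suits audio waveforms.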
#[derive(Debug, Clone)]
struct ResidualUnit {
snake1: Snake1d,
conv1: Conv1d,
snake2: Snake1d,
conv2: Conv1d,
}
impl ResidualUnit {
fn new(
dim: usize,
dilation: usize,
kernel: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let pad = ((kernel - 1) * dilation) / 2;
let vb = vb.pp("block");
let snake1 = Snake1d::new(dim, vb.pp(0))?;
let cfg1 = Conv1dConfig {
dilation,
padding: pad,
groups,
..Default::default()
};
let conv1 = conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
let snake2 = Snake1d::new(dim, vb.pp(2))?;
let conv2 = conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
Ok(Self {
snake1,
conv1,
snake2,
conv2,
})
}
}
impl Module for ResidualUnit {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = xs
.apply(&self.snake1)?
.apply(&self.conv1)?
.apply(&self.snake2)?
.apply(&self.conv2)?;
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
if pad > 0 {
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
} else {
ys + xs
}
}
}
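// With the ((kernel - 1) * dilation) / 2 padding chosen in `new`, the conv
// stack is length-preserving and `pad` is 0, so the skip connection is a
// plain add. The narrow branch center-crops `xs` only if a configuration
// ever makes `ys` shorter, mirroring the reference implementation.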
#[derive(Debug, Clone)]
struct NoiseBlock {
linear: Conv1d,
}
impl NoiseBlock {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let linear = conv1d_weight_norm_no_bias(dim, dim, 1, Default::default(), vb.pp("linear"))?;
Ok(Self { linear })
}
}
impl Module for NoiseBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, _c, t) = xs.dims3()?;
let noise = Tensor::randn(0f32, 1f32, (b, 1, t), xs.device())?;
let h = xs.apply(&self.linear)?;
let n = noise.broadcast_mul(&h)?;
let xs = (xs + n)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct DecoderBlock {
snake1: Snake1d,
conv_tr1: ConvTranspose1d,
noise: Option<NoiseBlock>,
res1: ResidualUnit,
res2: ResidualUnit,
res3: ResidualUnit,
}
impl DecoderBlock {
fn new(
in_dim: usize,
out_dim: usize,
stride: usize,
noise: bool,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
let cfg = ConvTranspose1dConfig {
stride,
padding: stride.div_ceil(2),
output_padding: stride % 2,
..Default::default()
};
let conv_tr1 =
conv_transpose1d_weight_norm(in_dim, out_dim, 2 * stride, true, cfg, vb.pp(1))?;
let (n, noise) = if noise {
let noise = NoiseBlock::new(out_dim, vb.pp(2))?;
(1, Some(noise))
} else {
(0, None)
};
let res1 = ResidualUnit::new(out_dim, 1, 7, groups, vb.pp(2 + n))?;
let res2 = ResidualUnit::new(out_dim, 3, 7, groups, vb.pp(3 + n))?;
let res3 = ResidualUnit::new(out_dim, 9, 7, groups, vb.pp(4 + n))?;
Ok(Self {
snake1,
conv_tr1,
noise,
res1,
res2,
res3,
})
}
}
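// Upsampling check for the transposed conv above:
//   L_out = (L_in - 1) * stride - 2 * padding + kernel + output_padding.
// With kernel = 2 * stride, padding = ceil(stride / 2) and
// output_padding = stride % 2, this simplifies to L_out = L_in * stride for
// even and odd strides alike, i.e. an exact x-stride upsampling.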
impl Module for DecoderBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.snake1)?
.apply(&self.conv_tr1)?
.apply(&self.noise.as_ref())?
.apply(&self.res1)?
.apply(&self.res2)?
.apply(&self.res3)
}
}
#[derive(Debug, Clone)]
struct EncoderBlock {
res1: ResidualUnit,
res2: ResidualUnit,
res3: ResidualUnit,
snake1: Snake1d,
conv1: Conv1d,
}
impl EncoderBlock {
fn new(
out_dim: usize,
in_dim: Option<usize>,
stride: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let in_dim = in_dim.unwrap_or(out_dim / 2);
let res1 = ResidualUnit::new(in_dim, 1, 7, groups, vb.pp(0))?;
let res2 = ResidualUnit::new(in_dim, 3, 7, groups, vb.pp(1))?;
let res3 = ResidualUnit::new(in_dim, 9, 7, groups, vb.pp(2))?;
let snake1 = Snake1d::new(in_dim, vb.pp(3))?;
let cfg1 = Conv1dConfig {
stride,
padding: stride.div_ceil(2),
..Default::default()
};
let conv1 = conv1d_weight_norm(in_dim, out_dim, 2 * stride, cfg1, vb.pp(4))?;
Ok(Self {
res1,
res2,
res3,
snake1,
conv1,
})
}
}
impl candle::Module for EncoderBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.res1)?
.apply(&self.res2)?
.apply(&self.res3)?
.apply(&self.snake1)?
.apply(&self.conv1)
}
}
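// Mirror image of the decoder arithmetic: a stride-s conv with kernel 2s and
// padding ceil(s / 2) maps a length-L input (L divisible by s, s >= 2) to
// exactly L / s frames, so the product of encoder_rates becomes the model's
// hop length.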
#[derive(Debug, Clone)]
pub struct Encoder {
conv1: Conv1d,
blocks: Vec<EncoderBlock>,
local_mha: Option<LocalMHA>,
conv2: Conv1d,
}
impl candle::Module for Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.conv1)?;
for block in self.blocks.iter() {
xs = xs.apply(block)?
}
xs.apply(&self.conv2)
}
}
impl Encoder {
fn new(
mut d_model: usize,
strides: &[usize],
depthwise: bool,
attn_window_size: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let mut idx = 0;
let cfg1 = Conv1dConfig {
padding: 3,
..Default::default()
};
let conv1 = conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(idx))?;
idx += 1;
let mut blocks = Vec::with_capacity(strides.len());
for &stride in strides.iter() {
d_model *= 2;
let groups = if depthwise { d_model / 2 } else { 1 };
let block = EncoderBlock::new(d_model, None, stride, groups, vb.pp(idx))?;
idx += 1;
blocks.push(block)
}
let local_mha = match attn_window_size {
Some(w) => {
let mha = LocalMHA::new(d_model, w, 64, true, vb.pp(idx))?;
idx += 1;
Some(mha)
}
None => None,
};
let groups = if depthwise { d_model } else { 1 };
let cfg2 = Conv1dConfig {
padding: 3,
groups,
..Default::default()
};
let conv2 = conv1d_weight_norm(d_model, d_model, 7, cfg2, vb.pp(idx))?;
idx += 1;
Ok(Self {
conv1,
blocks,
local_mha,
conv2,
})
}
}
#[derive(Debug, Clone)]
enum ConvInit {
Depthwise(Conv1d, Conv1d),
Standard(Conv1d),
}
#[derive(Debug, Clone)]
pub struct Decoder {
conv1: ConvInit,
local_mha: Option<LocalMHA>,
blocks: Vec<DecoderBlock>,
snake1: Snake1d,
conv2: Conv1d,
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
in_c: usize,
mut channels: usize,
rates: &[usize],
noise: bool,
depthwise: bool,
attn_window_size: Option<usize>,
d_out: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("model");
let mut idx = 0;
let pad3 = Conv1dConfig {
padding: 3,
..Default::default()
};
let conv1 = if depthwise {
let cfg1 = Conv1dConfig {
padding: 3,
groups: in_c,
..Default::default()
};
let conv1 = conv1d_weight_norm(in_c, in_c, 7, cfg1, vb.pp(idx))?;
idx += 1;
let conv2 = conv1d_weight_norm(in_c, channels, 1, Default::default(), vb.pp(idx))?;
idx += 1;
ConvInit::Depthwise(conv1, conv2)
} else {
let conv1 = conv1d_weight_norm(in_c, channels, 7, pad3, vb.pp(idx))?;
idx += 1;
ConvInit::Standard(conv1)
};
let mut blocks = Vec::with_capacity(rates.len());
let local_mha = match attn_window_size {
Some(w) => {
let mha = LocalMHA::new(channels, w, 64, true, vb.pp(idx))?;
idx += 1;
Some(mha)
}
None => None,
};
for stride in rates.iter() {
let groups = if depthwise { channels / 2 } else { 1 };
let block =
DecoderBlock::new(channels, channels / 2, *stride, noise, groups, vb.pp(idx))?;
idx += 1;
channels /= 2;
blocks.push(block)
}
let snake1 = Snake1d::new(channels, vb.pp(idx))?;
idx += 1;
let conv2 = conv1d_weight_norm(channels, d_out, 7, pad3, vb.pp(idx))?;
idx += 1;
Ok(Self {
conv1,
local_mha,
blocks,
snake1,
conv2,
})
}
}
impl candle::Module for Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = match &self.conv1 {
ConvInit::Standard(c) => xs.apply(c)?,
ConvInit::Depthwise(c1, c2) => xs.apply(c1)?.apply(c2)?,
};
for block in self.blocks.iter() {
xs = xs.apply(block)?
}
xs.apply(&self.snake1)?.apply(&self.conv2)
}
}
fn normalize(v: &Tensor) -> Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
// https://github.com/hubertsiuzdak/snac/blob/main/snac/vq.py
#[allow(unused)]
#[derive(Clone, Debug)]
struct VectorQuantizer {
in_proj: Conv1d,
out_proj: Conv1d,
codebook: candle_nn::Embedding,
stride: usize,
}
impl VectorQuantizer {
fn new(
in_dim: usize,
cb_size: usize,
cb_dim: usize,
stride: usize,
vb: VarBuilder,
) -> Result<Self> {
let in_proj = conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
let out_proj =
conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
Ok(Self {
in_proj,
out_proj,
codebook,
stride,
})
}
fn decode_latents(&self, latents: &Tensor) -> Result<(Tensor, Tensor)> {
let (b, d, t) = latents.dims3()?;
let encodings = latents.transpose(1, 2)?.reshape((b * t, d))?;
let encodings = normalize(&encodings)?;
let codebook = normalize(self.codebook.embeddings())?;
let dist = (encodings
.sqr()?
.sum_keepdim(1)?
.broadcast_sub(&encodings.matmul(&codebook.t()?)?)?
* 2.0)?
.broadcast_add(&codebook.sqr()?.sum_keepdim(1)?.t()?)?;
let indices = dist.argmin(1)?.reshape((b, ()))?;
let z_q = self.decode_code(&indices)?;
Ok((z_q, indices))
}
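    // Because both `encodings` and the codebook rows are L2-normalised, the
    // distance above is 2 * ||e||^2 - 2 * e.c + ||c||^2 = 3 - 2 * cos(e, c)
    // per entry, so the argmin picks the most cosine-similar code; the extra
    // ||e||^2 term is constant per row and does not affect the argmin.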
fn encode(&self, z: &Tensor) -> Result<(Tensor, Tensor)> {
let z = if self.stride > 1 {
let (b, c, t) = z.dims3()?;
z.reshape((b, c, 1, t))?
.avg_pool2d((1, self.stride))?
.squeeze(2)?
} else {
z.clone()
};
let z_e = z.apply(&self.in_proj)?;
let (z_q, indices) = self.decode_latents(&z_e)?;
let z_q = z_q.apply(&self.out_proj)?;
let z_q = if self.stride > 1 {
repeat_interleave(&z_q, self.stride, D::Minus1)?
} else {
z_q
};
Ok((z_q, indices))
}
fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
embed_id.apply(&self.codebook)
}
fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
self.embed_code(embed_id)?.transpose(1, 2)
}
}
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
quantizers: Vec<VectorQuantizer>,
}
impl ResidualVectorQuantizer {
fn new(
input_dim: usize,
cb_size: usize,
cb_dim: usize,
vq_strides: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let vb = &vb.pp("quantizers");
let quantizers = vq_strides
.iter()
.enumerate()
.map(|(i, stride)| VectorQuantizer::new(input_dim, cb_size, cb_dim, *stride, vb.pp(i)))
.collect::<Result<Vec<_>>>()?;
Ok(Self { quantizers })
}
fn encode(&self, z: &Tensor) -> Result<(Tensor, Vec<Tensor>)> {
let mut residual = z.clone();
let mut z_q = z.zeros_like()?;
let mut codes = Vec::with_capacity(self.quantizers.len());
for quantizer in self.quantizers.iter() {
let (z_q_i, indices_i) = quantizer.encode(&residual)?;
z_q = (z_q + &z_q_i)?;
residual = (residual - &z_q_i)?;
codes.push(indices_i)
}
Ok((z_q, codes))
}
#[allow(clippy::wrong_self_convention)]
fn from_codes(&self, codes: &[&Tensor]) -> Result<Tensor> {
let mut sum = None;
for (quantizer, codes) in self.quantizers.iter().zip(codes.iter()) {
let z_p_i = quantizer.decode_code(codes)?;
let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
let z_q_i = repeat_interleave(&z_q_i, quantizer.stride, D::Minus1)?;
let s = match sum {
None => z_q_i,
Some(s) => (s + z_q_i)?,
};
sum = Some(s)
}
match sum {
Some(s) => Ok(s),
None => candle::bail!("empty codebooks"),
}
}
}
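// Residual VQ in brief: encode() quantises z in stages, each quantizer seeing
// the residual left by the previous ones, with the reconstructions summed;
// from_codes() replays just the decode half from stored indices. The
// per-quantizer strides make the scheme multi-scale: larger strides yield
// coarser, lower-rate code streams.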
fn gcd(mut a: usize, mut b: usize) -> usize {
while b != 0 {
let t = b;
b = a % b;
a = t;
}
a
}
fn lcm(a: usize, b: usize) -> usize {
a / gcd(a, b) * b
}
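// Used by `preprocess` below to round the input length up to a common
// multiple of the coarsest VQ stride and the attention window, e.g.
// lcm(4, 32) = 32.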
// https://github.com/hubertsiuzdak/snac/blob/main/snac/snac.py
#[derive(Debug, Clone)]
pub struct Model {
pub encoder: Encoder,
pub quantizer: ResidualVectorQuantizer,
pub decoder: Decoder,
pub hop_length: usize,
pub config: Config,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = Encoder::new(
cfg.encoder_dim,
&cfg.encoder_rates,
cfg.depthwise,
cfg.attn_window_size,
vb.pp("encoder"),
)?;
let latent_dim = cfg.encoder_dim * 2usize.pow(cfg.encoder_rates.len() as u32);
let quantizer = ResidualVectorQuantizer::new(
latent_dim,
cfg.codebook_size,
cfg.codebook_dim,
&cfg.vq_strides,
vb.pp("quantizer"),
)?;
let decoder = Decoder::new(
latent_dim,
cfg.decoder_dim,
&cfg.decoder_rates,
cfg.noise,
cfg.depthwise,
cfg.attn_window_size,
/* d_out */ 1,
vb.pp("decoder"),
)?;
let hop_length = cfg.encoder_rates.iter().product::<usize>();
Ok(Self {
encoder,
decoder,
quantizer,
config: cfg.clone(),
hop_length,
})
}
fn preprocess(&self, audio_data: &Tensor) -> Result<Tensor> {
let len = audio_data.dim(D::Minus1)?;
let lcm = lcm(
self.config.vq_strides[0],
self.config.attn_window_size.unwrap_or(1),
);
let pad_to = self.hop_length * lcm;
let right_pad = len.div_ceil(pad_to) * pad_to - len;
let audio_data = audio_data.pad_with_zeros(D::Minus1, 0, right_pad)?;
Ok(audio_data)
}
pub fn encode(&self, audio_data: &Tensor) -> Result<Vec<Tensor>> {
let audio_data = self.preprocess(audio_data)?;
let z = self.encoder.forward(&audio_data)?;
let (_, codes) = self.quantizer.encode(&z)?;
Ok(codes)
}
pub fn decode(&self, audio_codes: &[&Tensor]) -> Result<Tensor> {
let audio_values = self.quantizer.from_codes(audio_codes)?;
audio_values.apply(&self.decoder)
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn num_codebooks(&self) -> usize {
self.quantizer.quantizers.len()
}
}
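// A minimal end-to-end sketch (hypothetical: the config/weight loading and
// file names below are assumptions, not part of this file):
//
//     let cfg: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
//     let vb = unsafe {
//         VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &Device::Cpu)?
//     };
//     let model = Model::new(&cfg, vb)?;
//     // audio: (batch, 1, samples) at cfg.sampling_rate
//     let codes = model.encode(&audio)?; // one (batch, frames_i) tensor per codebook
//     let code_refs: Vec<&Tensor> = codes.iter().collect();
//     let reconstructed = model.decode(&code_refs)?; // (batch, 1, padded samples)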

View File

@@ -68,6 +68,7 @@ impl ResnetBlock2D {
padding: 1,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
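// `cudnn_fwd_algo: None` preserves the previous behaviour (cuDNN picks the
// algorithm); the new field lets callers pin a specific forward algorithm on
// CUDA builds.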
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -83,6 +84,7 @@ impl ResnetBlock2D {
padding: 0,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
Some(conv2d(
in_channels,

View File

@@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_<T: Float>(
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded.extend(std::iter::repeat_n(zero, to_add));
samples_padded
};

View File

@@ -248,12 +248,14 @@ impl AudioEncoder {
stride: 1,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

View File

@@ -244,12 +244,14 @@ impl AudioEncoder {
stride: 1,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

View File

@@ -54,3 +54,25 @@ fn sample_with_top_k() -> Result<()> {
assert_eq!(token, 2);
Ok(())
}
#[test]
fn sample_gumbel() -> Result<()> {
let mut logits_process = LogitsProcessor::from_sampling(
42,
candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
);
let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
let mut counts = vec![0f64; 4];
let samples = 100000;
for _ in 0..samples {
let token = logits_process.sample(&logits)?;
counts[token as usize] += 1f64 / samples as f64;
}
for i in 0..4 {
if (counts[i] - sm[i]).abs() > 0.05 {
panic!("pr mismatch {counts:?} {sm:?}");
}
}
Ok(())
}
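// Why the counts above converge to softmax(logits): GumbelSoftmax sampling
// relies on the Gumbel-max trick, argmax_i(logits_i + g_i) with
// g_i ~ Gumbel(0, 1), which draws exact categorical samples (adding the noise
// to raw logits works because log-softmax differs from them only by a
// constant; for temperatures other than 1, the logits are scaled first).
// A standalone sketch of the trick (assumed helper, not the library API):
#[allow(dead_code)]
fn gumbel_argmax_sketch(logits: &Tensor) -> Result<u32> {
    // g = -log(-log(u)) with u ~ U(0, 1) is Gumbel(0, 1) noise.
    let u = Tensor::rand(0f32, 1f32, logits.shape(), logits.device())?;
    let g = u.to_dtype(logits.dtype())?.log()?.neg()?.log()?.neg()?;
    (logits + g)?.argmax(0)?.to_scalar::<u32>()
}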

View File

@@ -177,7 +177,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded.extend(std::iter::repeat_n(zero, to_add));
samples_padded
};

View File

@@ -98,6 +98,7 @@ impl ConvBlock {
stride,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;