mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Compare commits
65 Commits
Author | SHA1 | Date | |
---|---|---|---|
17313a4226 | |||
0224a749f0 | |||
cd7b877d6b | |||
5aed817f1b | |||
1a183c988a | |||
cac51fe16a | |||
61ddb9535e | |||
9a62c91643 | |||
92106c8762 | |||
9ce4fe6194 | |||
450a49ed1a | |||
6bd61727bc | |||
485ddf2996 | |||
36508a2c93 | |||
3d05f5cf3d | |||
637473cb5e | |||
e27b4700ad | |||
1fdfb58de5 | |||
cd96fa80da | |||
8a19bb7df2 | |||
38fc86621c | |||
5029ac52bb | |||
de23d34a28 | |||
d4bac37a61 | |||
e98754fc5a | |||
e3db30021f | |||
6e0646c208 | |||
fbaf0b0e32 | |||
a2e925462c | |||
3827685524 | |||
3aeb9575c7 | |||
6ff0a6999c | |||
82def7ae38 | |||
99bd69f383 | |||
a4c56a958e | |||
b2904a830b | |||
21055b5697 | |||
9dbaf958dc | |||
ce5f8dd129 | |||
9954981327 | |||
7f0f83a7c1 | |||
76e565c4ab | |||
e4e7b0b2da | |||
b01ebbad8a | |||
1d1d6d4fe6 | |||
2653002f29 | |||
a52b76ae82 | |||
fb660b8d43 | |||
2f9606b187 | |||
f3a73f80d1 | |||
b44d38de0e | |||
d9198deb37 | |||
15ed0b11ce | |||
34505fdf3a | |||
d7b7ce16e4 | |||
19fb6dac1f | |||
acc5bd335f | |||
eb478ece92 | |||
d339b01726 | |||
2f3bf42bcb | |||
e3370c6316 | |||
338f6a102e | |||
bc33df77e1 | |||
cf9d7bf24c | |||
9d31361c4f |
40
.github/workflows/book-cd.yml
vendored
40
.github/workflows/book-cd.yml
vendored
@ -1,40 +0,0 @@
|
||||
name: Deploy Rust book
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write # To push a branch
|
||||
pull-requests: write # To create a PR from that branch
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install latest mdbook
|
||||
run: |
|
||||
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
|
||||
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
|
||||
mkdir mdbook
|
||||
curl -sSL $url | tar -xz --directory=./mdbook
|
||||
echo `pwd`/mdbook >> $GITHUB_PATH
|
||||
- name: Deploy GitHub Pages
|
||||
run: |
|
||||
# This assumes your book is in the root of your repository.
|
||||
# Just add a `cd` here if you need to change to another directory.
|
||||
cd candle-book
|
||||
mdbook build
|
||||
git worktree add gh-pages
|
||||
git config user.name "Deploy from CI"
|
||||
git config user.email ""
|
||||
cd gh-pages
|
||||
# Delete the ref to avoid keeping history.
|
||||
git update-ref -d refs/heads/gh-pages
|
||||
rm -rf *
|
||||
mv ../book/* .
|
||||
git add .
|
||||
git commit -m "Deploy $GITHUB_SHA to gh-pages"
|
||||
git push --force --set-upstream origin gh-pages
|
29
.github/workflows/book.yml
vendored
29
.github/workflows/book.yml
vendored
@ -1,29 +0,0 @@
|
||||
name: CI
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Test candle-book
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write # To push a branch
|
||||
pull-requests: write # To create a PR from that branch
|
||||
steps:
|
||||
- uses: actions/checkout@master
|
||||
- name: Install Rust
|
||||
run: |
|
||||
rustup set profile minimal
|
||||
rustup toolchain install stable
|
||||
rustup default stable
|
||||
- name: Install latest mdbook
|
||||
run: |
|
||||
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
|
||||
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
|
||||
mkdir bin
|
||||
curl -sSL $url | tar -xz --directory=bin
|
||||
echo "$(pwd)/bin" >> $GITHUB_PATH
|
||||
- name: Run tests
|
||||
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
|
||||
|
||||
|
28
Cargo.toml
28
Cargo.toml
@ -3,7 +3,6 @@ members = [
|
||||
"candle-core",
|
||||
"candle-datasets",
|
||||
"candle-examples",
|
||||
"candle-book",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
"candle-transformers",
|
||||
@ -12,6 +11,7 @@ members = [
|
||||
"tensor-tools",
|
||||
]
|
||||
exclude = [
|
||||
"candle-book",
|
||||
"candle-flash-attn",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.9.0-alpha.1"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.4.1"
|
||||
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
ug = "0.2.0"
|
||||
ug-cuda = "0.2.0"
|
||||
ug-metal = "0.2.0"
|
||||
ug = "0.4.0"
|
||||
ug-cuda = "0.4.0"
|
||||
ug-metal = "0.4.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "1.1.1", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
@ -290,6 +290,8 @@ Cheatsheet:
|
||||
|
||||
### Why should I use Candle?
|
||||
|
||||
<!--- ANCHOR: goals --->
|
||||
|
||||
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
|
||||
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
|
||||
binaries.
|
||||
@ -299,6 +301,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut
|
||||
|
||||
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
|
||||
|
||||
<!--- ANCHOR_END: goals --->
|
||||
|
||||
### Other ML frameworks
|
||||
|
||||
|
13
candle-book/CONTRIBUTING.md
Normal file
13
candle-book/CONTRIBUTING.md
Normal file
@ -0,0 +1,13 @@
|
||||
# Candle Book
|
||||
|
||||
The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.
|
||||
|
||||
## Installation
|
||||
|
||||
To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).
|
||||
|
||||
## Viewing the book
|
||||
|
||||
To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).
|
||||
|
||||
The book is built automatically in github CI.
|
@ -1,6 +1,7 @@
|
||||
# Introduction
|
||||
|
||||
{{#include ../../README.md:goals}}
|
||||
|
||||
{{#include ../../README.md:features}}
|
||||
|
||||
|
||||
This book will introduce step by step how to use `candle`.
|
||||
This book will introduce step by step how to use `candle`.
|
@ -5,7 +5,10 @@
|
||||
# User Guide
|
||||
|
||||
- [Installation](guide/installation.md)
|
||||
- [Hello World - MNIST](guide/hello_world.md)
|
||||
- [Tutorial - MNIST](guide/mnist/intro.md)
|
||||
- [Modeling](guide/mnist/modeling.md)
|
||||
- [Training](guide/mnist/training.md)
|
||||
- [Saving And Loading](guide/mnist/saving_loading.md)
|
||||
- [PyTorch cheatsheet](guide/cheatsheet.md)
|
||||
|
||||
# Reference Guide
|
||||
@ -13,6 +16,7 @@
|
||||
- [Running a model](inference/inference.md)
|
||||
- [Using the hub](inference/hub.md)
|
||||
- [Error management](error_manage.md)
|
||||
- [Tracing](tracing.md)
|
||||
- [Training](training/training.md)
|
||||
- [Simplified](training/simplified.md)
|
||||
- [MNIST](training/mnist.md)
|
||||
|
@ -1,8 +1,23 @@
|
||||
# Installation
|
||||
|
||||
**With Cuda support**:
|
||||
## 1. Create a new rust app or library
|
||||
|
||||
1. First, make sure that Cuda is correctly installed.
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
```
|
||||
|
||||
## 2. Add the correct candle version
|
||||
|
||||
### Standard
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
### CUDA
|
||||
|
||||
First, make sure that Cuda is correctly installed.
|
||||
- `nvcc --version` should print information about your Cuda compiler driver.
|
||||
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
|
||||
like:
|
||||
@ -17,43 +32,36 @@ You can also compile the Cuda kernels for a specific compute cap using the
|
||||
|
||||
If any of the above commands errors out, please make sure to update your Cuda version.
|
||||
|
||||
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
|
||||
|
||||
Start by creating a new cargo:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
```
|
||||
|
||||
Make sure to add the `candle-core` crate with the cuda feature:
|
||||
Add the `candle-core` crate with the cuda feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
|
||||
```
|
||||
|
||||
### MKL
|
||||
|
||||
You can also see the `mkl` feature which can get faster inference on CPU.
|
||||
|
||||
Add the `candle-core` crate with the mkl feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
|
||||
```
|
||||
|
||||
### Metal
|
||||
|
||||
Metal is exclusive to MacOS.
|
||||
|
||||
Add the `candle-core` crate with the metal feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
|
||||
```
|
||||
|
||||
## 3. Building
|
||||
|
||||
Run `cargo build` to make sure everything can be correctly built.
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**Without Cuda support**:
|
||||
|
||||
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
Finally, run `cargo build` to make sure everything can be correctly built.
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**With mkl support**
|
||||
|
||||
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
|
||||
|
17
candle-book/src/guide/mnist/intro.md
Normal file
17
candle-book/src/guide/mnist/intro.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Introduction
|
||||
|
||||
This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.
|
||||
|
||||
Throughout this tutorial, you will learn the basics of:
|
||||
|
||||
- Tensor operations and model construction
|
||||
- Creating and implementing neural network layers
|
||||
- Parameter initialization
|
||||
- Training loop implementation
|
||||
- Saving and loading trained models
|
||||
|
||||
## Getting Started
|
||||
|
||||
Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.
|
172
candle-book/src/guide/mnist/modeling.md
Normal file
172
candle-book/src/guide/mnist/modeling.md
Normal file
@ -0,0 +1,172 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Modeling
|
||||
|
||||
Open `src/main.rs` in your project folder and insert the following code:
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
|
||||
struct Model {
|
||||
first: Tensor,
|
||||
second: Tensor,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = image.matmul(&self.first)?;
|
||||
let x = x.relu()?;
|
||||
x.matmul(&self.second)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; to utilize GPU acceleration.
|
||||
let device = Device::Cpu;
|
||||
|
||||
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
|
||||
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Execute the program with:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> Digit Tensor[dims 1, 10; f32] digit
|
||||
```
|
||||
|
||||
Since random inputs are provided, expect an incoherent output.
|
||||
|
||||
## Implementing a `Linear` Layer
|
||||
|
||||
To create a more sophisticated layer type, add a `bias` to the weight to construct the standard `Linear` layer.
|
||||
|
||||
Replace the entire content of `src/main.rs` with:
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.matmul(&self.weight)?;
|
||||
x.broadcast_add(&self.bias)
|
||||
}
|
||||
}
|
||||
|
||||
struct Model {
|
||||
first: Linear,
|
||||
second: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = self.first.forward(image)?;
|
||||
let x = x.relu()?;
|
||||
self.second.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; for GPU acceleration.
|
||||
// Use Device::Cpu; for CPU computation.
|
||||
let device = Device::cuda_if_available(0)?;
|
||||
|
||||
// Initialize model parameters
|
||||
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
|
||||
let first = Linear { weight, bias };
|
||||
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
|
||||
let second = Linear { weight, bias };
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
|
||||
// Perform inference
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Execute again with:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> Digit Tensor[dims 1, 10; f32] digit
|
||||
```
|
||||
|
||||
## Utilizing `candle_nn`
|
||||
|
||||
Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
|
||||
|
||||
This `Linear` implementation follows PyTorch conventions for improved compatibility with existing models, utilizing the transpose of weights rather than direct weights.
|
||||
|
||||
Let's simplify our implementation. First, add `candle-nn` as a dependency:
|
||||
|
||||
```bash
|
||||
$ cargo add --git https://github.com/huggingface/candle.git candle-nn
|
||||
```
|
||||
|
||||
Now, replace the entire content of `src/main.rs` with:
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_nn::{Linear, Module};
|
||||
|
||||
struct Model {
|
||||
first: Linear,
|
||||
second: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = self.first.forward(image)?;
|
||||
let x = x.relu()?;
|
||||
self.second.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; for GPU acceleration.
|
||||
let device = Device::Cpu;
|
||||
|
||||
// Note the dimension change: (784, 100) -> (100, 784)
|
||||
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
|
||||
let first = Linear::new(weight, Some(bias));
|
||||
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
|
||||
let second = Linear::new(weight, Some(bias));
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Execute the final version:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> Digit Tensor[dims 1, 10; f32] digit
|
||||
```
|
158
candle-book/src/guide/mnist/saving_loading.md
Normal file
158
candle-book/src/guide/mnist/saving_loading.md
Normal file
@ -0,0 +1,158 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Saving and Loading Models
|
||||
|
||||
After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
|
||||
### Saving Model Parameters
|
||||
|
||||
Let's modify our `training_loop` function to include functionality for saving weights:
|
||||
|
||||
```rust
|
||||
fn training_loop(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
// Initialize a VarMap for trainable parameters
|
||||
let varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = Model::new(vs.clone())?;
|
||||
|
||||
let learning_rate = 0.05;
|
||||
let epochs = 10;
|
||||
|
||||
// Initialize stochastic gradient descent optimizer
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
for epoch in 1..epochs {
|
||||
// Standard MNIST forward pass
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
|
||||
// Compute Negative Log Likelihood loss
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
|
||||
// Perform backward pass and update weights
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
// Evaluate model on test set
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
test_accuracy
|
||||
);
|
||||
}
|
||||
|
||||
// Save model weights to disk
|
||||
varmap.save("model_weights.safetensors")?;
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> 1 train loss: 2.40485 test acc: 0.11%
|
||||
> 2 train loss: 2.34161 test acc: 0.14%
|
||||
> 3 train loss: 2.28841 test acc: 0.17%
|
||||
> 4 train loss: 2.24158 test acc: 0.19%
|
||||
> 5 train loss: 2.19898 test acc: 0.23%
|
||||
> 6 train loss: 2.15927 test acc: 0.26%
|
||||
> 7 train loss: 2.12161 test acc: 0.29%
|
||||
> 8 train loss: 2.08549 test acc: 0.32%
|
||||
> 9 train loss: 2.05053 test acc: 0.35%
|
||||
```
|
||||
|
||||
### Loading Model Parameters
|
||||
|
||||
Now that we have saved our model parameters, we can modify the code to load them. The primary change required is to make the `varmap` variable mutable:
|
||||
|
||||
```rust
|
||||
fn training_loop(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
// Create a mutable VarMap for trainable parameters
|
||||
let mut varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = Model::new(vs.clone())?;
|
||||
|
||||
// Load pre-trained weights from file
|
||||
varmap.load("model_weights.safetensors")?;
|
||||
|
||||
let learning_rate = 0.05;
|
||||
let epochs = 10;
|
||||
|
||||
// Initialize stochastic gradient descent optimizer
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
for epoch in 1..epochs {
|
||||
// Standard MNIST forward pass
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
|
||||
// Compute Negative Log Likelihood loss
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
|
||||
// Perform backward pass and update weights
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
// Evaluate model on test set
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
test_accuracy
|
||||
);
|
||||
}
|
||||
|
||||
// Save updated weights back to disk
|
||||
varmap.save("model_weights.safetensors")?;
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> 1 train loss: 2.01645 test acc: 0.38%
|
||||
> 2 train loss: 1.98300 test acc: 0.41%
|
||||
> 3 train loss: 1.95008 test acc: 0.44%
|
||||
> 4 train loss: 1.91754 test acc: 0.47%
|
||||
> 5 train loss: 1.88534 test acc: 0.50%
|
||||
> 6 train loss: 1.85349 test acc: 0.53%
|
||||
> 7 train loss: 1.82198 test acc: 0.56%
|
||||
> 8 train loss: 1.79077 test acc: 0.59%
|
||||
> 9 train loss: 1.75989 test acc: 0.61%
|
||||
```
|
||||
|
||||
Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
|
134
candle-book/src/guide/mnist/training.md
Normal file
134
candle-book/src/guide/mnist/training.md
Normal file
@ -0,0 +1,134 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Training Implementation
|
||||
|
||||
First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` constructs a `VarMap`, which is the data structure that stores our trainable parameters.
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_nn::{Linear, Module, VarBuilder, VarMap};
|
||||
|
||||
fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
|
||||
let ws = vs.get_with_hints(
|
||||
(out_dim, in_dim),
|
||||
"weight",
|
||||
candle_nn::init::DEFAULT_KAIMING_NORMAL,
|
||||
)?;
|
||||
let bound = 1. / (in_dim as f64).sqrt();
|
||||
let bs = vs.get_with_hints(
|
||||
out_dim,
|
||||
"bias",
|
||||
candle_nn::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
},
|
||||
)?;
|
||||
Ok(Linear::new(ws, Some(bs)))
|
||||
}
|
||||
```
|
||||
|
||||
Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer weights as `first.weight` and `first.bias`, and the second layer weights as `second.weight` and `second.bias`.
|
||||
|
||||
```rust
|
||||
impl Model {
|
||||
fn new(vs: VarBuilder) -> Result<Self> {
|
||||
const IMAGE_DIM: usize = 784;
|
||||
const HIDDEN_DIM: usize = 100;
|
||||
const LABELS: usize = 10;
|
||||
|
||||
let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
|
||||
let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;
|
||||
|
||||
Ok(Self { first, second })
|
||||
}
|
||||
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = self.first.forward(image)?;
|
||||
let x = x.relu()?;
|
||||
self.second.forward(&x)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:
|
||||
|
||||
```bash
|
||||
$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
|
||||
```
|
||||
|
||||
With the dataset available, we can implement our training loop:
|
||||
|
||||
```rust
|
||||
use candle_core::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
|
||||
fn training_loop(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
// Initialize a VarMap to store trainable parameters
|
||||
let varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = Model::new(vs.clone())?;
|
||||
|
||||
let learning_rate = 0.05;
|
||||
let epochs = 10;
|
||||
|
||||
// Initialize a stochastic gradient descent optimizer to update parameters
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
for epoch in 1..epochs {
|
||||
// Perform forward pass on MNIST data
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
|
||||
// Compute Negative Log Likelihood loss
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
|
||||
// Perform backward pass and update weights
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
// Evaluate model on test set
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
test_accuracy
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Finally, let's implement our main function:
|
||||
|
||||
```rust
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let m = candle_datasets::vision::mnist::load()?;
|
||||
return training_loop(m);
|
||||
}
|
||||
```
|
||||
|
||||
Let's execute the training process:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> 1 train loss: 2.35449 test acc: 0.12%
|
||||
> 2 train loss: 2.30760 test acc: 0.15%
|
||||
> ...
|
||||
```
|
68
candle-book/src/tracing.md
Normal file
68
candle-book/src/tracing.md
Normal file
@ -0,0 +1,68 @@
|
||||
# Tracing
|
||||
|
||||
Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
|
||||
|
||||
> Profiling on GPUs is trickier due to asynchronous execution, see the [GPU section](#gpu).
|
||||
|
||||
## Overview
|
||||
|
||||
Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
|
||||
|
||||
To try it out, run an example in `candle-examples` with the `--tracing` flag.
|
||||
This generates a trace file, typically named `trace-<timestamp>.json`.
|
||||
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
|
||||
|
||||
## Adding Tracing
|
||||
|
||||
Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
|
||||
|
||||
To add custom tracing in your code, you can define a span like this:
|
||||
|
||||
```rust
|
||||
let span = tracing::span!(tracing::Level::TRACE, name);
|
||||
```
|
||||
|
||||
Then, to record the span during execution, create a guard:
|
||||
|
||||
```rust
|
||||
let _enter = span.enter();
|
||||
```
|
||||
|
||||
This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
|
||||
|
||||
## Recording and Saving a Trace
|
||||
|
||||
To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
|
||||
|
||||
The snippet below sets up a Chrome compatible recorder that logs all tracing activity between creation and drop of the guard:
|
||||
|
||||
```rust
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let _guard = {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
guard
|
||||
};
|
||||
```
|
||||
|
||||
## GPU
|
||||
|
||||
When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
|
||||
|
||||
### CUDA
|
||||
|
||||
For CUDA-specific profiling, you have two options:
|
||||
|
||||
1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1` which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
|
||||
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`) which are designed specifically for profiling asynchronous CUDA executions.
|
||||
|
||||
We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.
|
||||
|
||||
#### Performance Profiling with NVIDIA Nsight Systems
|
||||
|
||||
1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
|
||||
- Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
|
||||
1. Open the generated `.nsys-rep` report file in Nsight Systems GUI
|
||||
- File > Open
|
@ -56,3 +56,7 @@ harness = false
|
||||
[[example]]
|
||||
name = "metal_basics"
|
||||
required-features = ["metal"]
|
||||
|
||||
[[example]]
|
||||
name = "cuda_basics"
|
||||
required-features = ["cuda"]
|
||||
|
@ -4,11 +4,12 @@ use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::copy::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::unary::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::unary::benches
|
||||
);
|
||||
|
38
candle-core/benches/benchmarks/copy.rs
Normal file
38
candle-core/benches/benchmarks/copy.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{Device, Tensor, WithDType};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
|
||||
let batch_size = 128;
|
||||
let in_seq_len = 1;
|
||||
let kv_seq_len = 1024;
|
||||
|
||||
let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
|
||||
let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(size_in_bytes as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let attn_masks = vec![attn_mask.clone(); iters as usize];
|
||||
let start = Instant::now();
|
||||
for attn_mask in attn_masks.into_iter() {
|
||||
let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
|
||||
black_box(tensor);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,5 +1,6 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod conv_transpose2d;
|
||||
pub(crate) mod copy;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod qmatmul;
|
||||
pub(crate) mod random;
|
||||
|
@ -6,28 +6,18 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("fp32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("tf32: {:?}", start_time.elapsed());
|
||||
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
|
||||
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
|
||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||
drop(_x1);
|
||||
for _ in 0..20 {
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||
device.synchronize()?;
|
||||
println!("conv1d: {:?}", start_time.elapsed());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn scatter_add(
|
||||
&self,
|
||||
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self>;
|
||||
) -> Result<()>;
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<()>;
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn index_add(
|
||||
&self,
|
||||
@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
|
||||
_src_offset: usize,
|
||||
_dst_offset: usize,
|
||||
) -> Result<()>;
|
||||
|
||||
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
|
||||
}
|
||||
|
||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
|
||||
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
/// # Safety
|
||||
/// This function is unsafe as it doesn't initialize the underlying data store.
|
||||
/// The caller should ensure that the data is properly initialized as early as possible
|
||||
|
@ -53,6 +53,7 @@ impl Tensor {
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::IndexAdd(t1, t2, t3, _)
|
||||
| Op::Scatter(t1, t2, t3, _)
|
||||
| Op::ScatterAdd(t1, t2, t3, _)
|
||||
| Op::CustomOp3(t1, t2, t3, _)
|
||||
| Op::WhereCond(t1, t2, t3) => {
|
||||
@ -419,7 +420,7 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
||||
}
|
||||
Op::ScatterAdd(init, indexes, src, dim) => {
|
||||
Op::Scatter(init, indexes, src, dim) => {
|
||||
let init_sum_grad = grads.or_insert(init)?;
|
||||
*init_sum_grad = init_sum_grad.add(&grad)?;
|
||||
|
||||
@ -427,6 +428,16 @@ impl Tensor {
|
||||
let src_sum_grad = grads.or_insert(src)?;
|
||||
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
||||
}
|
||||
Op::ScatterAdd(init, indexes, src, dim) => {
|
||||
let init_sum_grad = grads.or_insert(init)?;
|
||||
let mask = init.ones_like()?;
|
||||
let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
|
||||
*init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
|
||||
|
||||
let src_grad = grad.gather(indexes, *dim)?;
|
||||
let src_sum_grad = grads.or_insert(src)?;
|
||||
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
||||
}
|
||||
Op::IndexAdd(init, indexes, src, dim) => {
|
||||
let init_sum_grad = grads.or_insert(init)?;
|
||||
*init_sum_grad = init_sum_grad.add(&grad)?;
|
||||
|
@ -14,6 +14,7 @@ pub struct ParamsConv1D {
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl ParamsConv1D {
|
||||
@ -54,7 +55,7 @@ impl ParamsConvTranspose1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum CudnnFwdAlgo {
|
||||
ImplicitGemm,
|
||||
ImplicitPrecompGemm,
|
||||
@ -151,6 +152,19 @@ impl Tensor {
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d_with_algo(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
@ -174,6 +188,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv1d_single_group(kernel, ¶ms)
|
||||
@ -278,6 +293,18 @@ impl Tensor {
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||
}
|
||||
|
||||
pub fn conv2d_with_algo(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
@ -297,7 +324,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
cudnn_fwd_algo,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv2d_single_group(kernel, ¶ms)
|
||||
|
@ -7,7 +7,7 @@ use rayon::prelude::*;
|
||||
|
||||
mod utils;
|
||||
pub use utils::{
|
||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
|
||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
|
||||
};
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
@ -483,17 +483,22 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
|
||||
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
||||
for right_i in 0..dst_right_len {
|
||||
let dst_idx = start_dst_idx + right_i;
|
||||
let index = ids[dst_idx].as_usize();
|
||||
if index >= src_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: src_dim_len,
|
||||
op: "gather",
|
||||
let index = ids[dst_idx];
|
||||
if index == I::max_value() {
|
||||
dst[dst_idx] = T::zero();
|
||||
} else {
|
||||
let index = index.as_usize();
|
||||
if index >= src_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: src_dim_len,
|
||||
op: "gather",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
.bt())?
|
||||
let src_idx = start_src_idx + index * src_right_len + right_i;
|
||||
dst[dst_idx] = src[src_idx]
|
||||
}
|
||||
let src_idx = start_src_idx + index * src_right_len + right_i;
|
||||
dst[dst_idx] = src[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -535,45 +540,89 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
||||
let start_src_idx = left_i * right_len * src_dim;
|
||||
let start_dst_idx = left_i * right_len * n_ids;
|
||||
for i in 0..n_ids {
|
||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
|
||||
if index >= src_dim {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: src_dim,
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let start_src_idx = start_src_idx + index * right_len;
|
||||
let start_dst_idx = start_dst_idx + i * right_len;
|
||||
dst[start_dst_idx..start_dst_idx + right_len]
|
||||
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
|
||||
if index == I::max_value() {
|
||||
dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
|
||||
} else {
|
||||
let index = index.as_usize();
|
||||
if index >= src_dim {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: src_dim,
|
||||
op: "index-select",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let start_src_idx = start_src_idx + index * right_len;
|
||||
dst[start_dst_idx..start_dst_idx + right_len]
|
||||
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct ScatterAdd<'a, I: IntDType> {
|
||||
trait ElemUpdate {
|
||||
fn f<T: WithDType>(dst: &mut T, src: T);
|
||||
}
|
||||
|
||||
struct Set;
|
||||
struct Add;
|
||||
|
||||
impl ElemUpdate for Set {
|
||||
fn f<T: WithDType>(dst: &mut T, src: T) {
|
||||
*dst = src
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemUpdate for Add {
|
||||
fn f<T: WithDType>(dst: &mut T, src: T) {
|
||||
*dst += src
|
||||
}
|
||||
}
|
||||
|
||||
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
_phantom: std::marker::PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
||||
const OP: &'static str = "scatter-add";
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
|
||||
fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
|
||||
Self {
|
||||
ids,
|
||||
ids_l,
|
||||
dim,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
|
||||
const OP: &'static str = "scatter";
|
||||
fn f<T: WithDType>(
|
||||
&self,
|
||||
dst: &mut [T],
|
||||
dst_l: &Layout,
|
||||
src: &[T],
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
Some((o1, o2)) => &mut dst[o1..o2],
|
||||
};
|
||||
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
let dst_dims = l1.dims();
|
||||
let dst_dims = dst_l.dims();
|
||||
let dst_dim_len = dst_dims[dim];
|
||||
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||
|
||||
@ -592,7 +641,11 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
||||
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
||||
for right_i in 0..dst_right_len {
|
||||
let ids_idx = start_ids_idx + right_i;
|
||||
let index = ids[ids_idx].as_usize();
|
||||
let index = ids[ids_idx];
|
||||
if index == I::max_value() {
|
||||
continue;
|
||||
}
|
||||
let index = index.as_usize();
|
||||
if index >= dst_dim_len {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
@ -602,12 +655,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
||||
.bt())?
|
||||
}
|
||||
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
|
||||
dst[dst_idx] += src[ids_idx]
|
||||
M::f(&mut dst[dst_idx], src[ids_idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dst)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@ -635,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
||||
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
||||
if dim == 0 {
|
||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||
if *dst_idx == I::max_value() {
|
||||
continue;
|
||||
}
|
||||
let dst_idx = dst_idx.as_usize();
|
||||
if dst_idx >= max_idx {
|
||||
Err(Error::InvalidIndex {
|
||||
@ -653,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
||||
}
|
||||
} else {
|
||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||
if *dst_idx == I::max_value() {
|
||||
continue;
|
||||
}
|
||||
let dst_idx = dst_idx.as_usize();
|
||||
if dst_idx >= max_idx {
|
||||
Err(Error::InvalidIndex {
|
||||
@ -1289,6 +1348,15 @@ impl Map2 for MatMul {
|
||||
} else {
|
||||
Parallelism::None
|
||||
};
|
||||
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
|
||||
// a_skip and c_skip should be updated but step is always 0 so
|
||||
// it wouldn't matter.
|
||||
(1, b * m, n, k)
|
||||
} else if a_skip == 0 && b_skip == n * k {
|
||||
(1, m, b * n, k)
|
||||
} else {
|
||||
(b, m, n, k)
|
||||
};
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
@ -2372,19 +2440,36 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
match ids {
|
||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
match ids {
|
||||
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
|
||||
}
|
||||
}
|
||||
@ -2445,6 +2530,48 @@ impl BackendStorage for CpuStorage {
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||
use crate::scalar::Scalar;
|
||||
fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
|
||||
match l.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
src[start_offset..start_offset + len].fill(s)
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len: 1,
|
||||
} => {
|
||||
for src_index in block_start_index {
|
||||
src[src_index] = s
|
||||
}
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
for src_index in block_start_index {
|
||||
src[src_index..src_index + block_len].fill(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
match (self, s) {
|
||||
(Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
|
||||
(Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
|
||||
(Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
|
||||
(Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
|
||||
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
|
||||
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
|
||||
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
|
||||
(st, s) => crate::bail!(
|
||||
"const_set dtype mismatch, expected {:?} but got {:?}",
|
||||
st.dtype(),
|
||||
s
|
||||
),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for CpuDevice {
|
||||
@ -2619,20 +2746,6 @@ impl BackendDevice for CpuDevice {
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
|
||||
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
|
||||
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
|
||||
};
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
|
@ -58,6 +58,30 @@ pub trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
|
||||
|
||||
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: v1.dtype(),
|
||||
rhs: v2.dtype(),
|
||||
op: Self::OP,
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2U8 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||
|
@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn launch_conv1d<
|
||||
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||
Y: cudarc::cudnn::CudnnDataType,
|
||||
>(
|
||||
src: &CudaView<T>,
|
||||
src_l: &crate::Layout,
|
||||
filter: &CudaView<T>,
|
||||
dst: &mut CudaSlice<T>,
|
||||
params: &crate::conv::ParamsConv1D,
|
||||
dev: &crate::cuda_backend::CudaDevice,
|
||||
) -> crate::Result<()> {
|
||||
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||
|
||||
let device_id = dev.id();
|
||||
let cudnn = CUDNN.with(|cudnn| {
|
||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_stream());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
c
|
||||
})?;
|
||||
let conv = cudnn.create_conv2d::<Y>(
|
||||
/* pad */ [params.padding as i32, 0],
|
||||
/* stride */ [params.stride as i32, 1],
|
||||
/* dilation */ [params.dilation as i32, 1],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
|
||||
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
|
||||
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
|
||||
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
|
||||
// > to 1.
|
||||
let x_shape = [
|
||||
params.b_size as i32,
|
||||
params.c_in as i32,
|
||||
params.l_in as i32,
|
||||
1,
|
||||
];
|
||||
// Note that `src` already starts at the proper offset.
|
||||
let x = if src_l.is_contiguous() {
|
||||
cudnn.create_4d_tensor::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
x_shape,
|
||||
)?
|
||||
} else {
|
||||
let s = src_l.stride();
|
||||
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
|
||||
};
|
||||
let w = cudnn.create_4d_filter::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[
|
||||
params.c_out as i32,
|
||||
params.c_in as i32,
|
||||
params.k_size as i32,
|
||||
1,
|
||||
],
|
||||
)?;
|
||||
let l_out = params.l_out() as i32;
|
||||
let y = cudnn.create_4d_tensor::<T>(
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, l_out, 1],
|
||||
)?;
|
||||
let conv1d = ConvForward {
|
||||
conv: &conv,
|
||||
x: &x,
|
||||
w: &w,
|
||||
y: &y,
|
||||
};
|
||||
let alg = match params.cudnn_fwd_algo {
|
||||
None => conv1d.pick_algorithm()?,
|
||||
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||
}
|
||||
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv1d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv1d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
Some(&mut workspace),
|
||||
(T::one(), T::zero()),
|
||||
src,
|
||||
filter,
|
||||
dst,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||
use cudarc::driver::CudaFunction;
|
||||
use half::{bf16, f16};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -46,11 +46,61 @@ impl std::fmt::Debug for CudaDevice {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaDevice {
|
||||
type Target = Arc<cudarc::driver::CudaStream>;
|
||||
impl CudaDevice {
|
||||
#[allow(clippy::missing_safety_doc)]
|
||||
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
|
||||
&self,
|
||||
len: usize,
|
||||
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||
self.stream.alloc::<T>(len).w()
|
||||
}
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stream
|
||||
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
|
||||
&self,
|
||||
len: usize,
|
||||
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||
self.stream.alloc_zeros::<T>(len).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_htod<
|
||||
T: cudarc::driver::DeviceRepr,
|
||||
Src: cudarc::driver::HostSlice<T> + ?Sized,
|
||||
Dst: cudarc::driver::DevicePtrMut<T>,
|
||||
>(
|
||||
&self,
|
||||
src: &Src,
|
||||
dst: &mut Dst,
|
||||
) -> Result<()> {
|
||||
self.stream.memcpy_htod(src, dst).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
|
||||
&self,
|
||||
src: &Src,
|
||||
) -> Result<Vec<T>> {
|
||||
self.stream.memcpy_dtov(src).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_dtod<
|
||||
T,
|
||||
Src: cudarc::driver::DevicePtr<T>,
|
||||
Dst: cudarc::driver::DevicePtrMut<T>,
|
||||
>(
|
||||
&self,
|
||||
src: &Src,
|
||||
dst: &mut Dst,
|
||||
) -> Result<()> {
|
||||
self.stream.memcpy_dtod(src, dst).w()
|
||||
}
|
||||
|
||||
pub fn memcpy_stod<
|
||||
T: cudarc::driver::DeviceRepr,
|
||||
Src: cudarc::driver::HostSlice<T> + ?Sized,
|
||||
>(
|
||||
&self,
|
||||
src: &Src,
|
||||
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||
self.stream.memcpy_stod(src).w()
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,6 +144,24 @@ impl CudaDevice {
|
||||
self.stream.clone()
|
||||
}
|
||||
|
||||
/// When turned on, all cuda tensors **created after calling this function** will
|
||||
/// not track uses via cuda events.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// It is up to the user to ensure proper synchronization between multiple streams:
|
||||
/// - Ensure that no tensor is freed before a use on another stream is finished.
|
||||
/// - Ensure that a tensor is not used on another stream before allocation on the
|
||||
/// allocating stream finishes.
|
||||
/// - Ensure that a tensor is not written two concurrently by multiple streams.
|
||||
pub unsafe fn disable_event_tracking(&self) {
|
||||
self.context.disable_event_tracking()
|
||||
}
|
||||
|
||||
pub fn is_event_tracking(&self) -> bool {
|
||||
self.context.is_event_tracking()
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn compile(
|
||||
&self,
|
||||
@ -120,100 +188,6 @@ impl CudaDevice {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u8;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as i64;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = bf16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = f16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as f32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_custom_func(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
@ -325,31 +299,31 @@ impl BackendDevice for CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<u8>(elem_count)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<i64>(elem_count)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<bf16>(elem_count)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<f16>(elem_count)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
||||
let data = self.alloc_zeros::<f64>(elem_count)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -373,12 +347,12 @@ impl BackendDevice for CudaDevice {
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
@ -417,7 +391,7 @@ impl BackendDevice for CudaDevice {
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
|
||||
curand
|
||||
.0
|
||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||
@ -425,7 +399,7 @@ impl BackendDevice for CudaDevice {
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
|
||||
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
@ -436,39 +410,35 @@ impl BackendDevice for CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
self.const_impl(1., shape, dtype)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc::<u8>(elem_count).w()?;
|
||||
let data = self.alloc::<u8>(elem_count)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc::<u32>(elem_count).w()?;
|
||||
let data = self.alloc::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc::<i64>(elem_count).w()?;
|
||||
let data = self.alloc::<i64>(elem_count)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc::<bf16>(elem_count).w()?;
|
||||
let data = self.alloc::<bf16>(elem_count)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc::<f16>(elem_count).w()?;
|
||||
let data = self.alloc::<f16>(elem_count)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc::<f32>(elem_count).w()?;
|
||||
let data = self.alloc::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc::<f64>(elem_count).w()?;
|
||||
let data = self.alloc::<f64>(elem_count)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -481,31 +451,31 @@ impl BackendDevice for CudaDevice {
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let slice = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorageRef::U32(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorageRef::I64(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorageRef::BF16(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorageRef::F16(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorageRef::F32(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorageRef::F64(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -518,31 +488,31 @@ impl BackendDevice for CudaDevice {
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
let data = self.memcpy_stod(storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
@ -555,31 +525,31 @@ impl BackendDevice for CudaDevice {
|
||||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
let data = self.memcpy_stod(&storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
|
@ -2,7 +2,7 @@
|
||||
//!
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
@ -34,12 +34,27 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::scalar::Scalar {
|
||||
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
|
||||
use crate::scalar::Scalar;
|
||||
match self {
|
||||
Scalar::U8(v) => builder.arg(v),
|
||||
Scalar::U32(v) => builder.arg(v),
|
||||
Scalar::I64(v) => builder.arg(v),
|
||||
Scalar::F32(v) => builder.arg(v),
|
||||
Scalar::F64(v) => builder.arg(v),
|
||||
Scalar::F16(v) => builder.arg(v),
|
||||
Scalar::BF16(v) => builder.arg(v),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl SlicePtrOrNull<usize> {
|
||||
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||
let ds = if l.is_contiguous() {
|
||||
SlicePtrOrNull::Null
|
||||
} else {
|
||||
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat()).w()?)
|
||||
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat())?)
|
||||
};
|
||||
Ok(ds)
|
||||
}
|
||||
@ -89,7 +104,7 @@ impl Map1 for Affine {
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), &kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -120,7 +135,7 @@ impl Map1 for Elu {
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -134,6 +149,7 @@ impl Map1 for Elu {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
struct Im2Col1D {
|
||||
l_k: usize,
|
||||
stride: usize,
|
||||
@ -142,6 +158,7 @@ struct Im2Col1D {
|
||||
}
|
||||
|
||||
impl Im2Col1D {
|
||||
#[allow(unused)]
|
||||
fn l_out(&self, l: usize) -> usize {
|
||||
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
|
||||
}
|
||||
@ -157,15 +174,15 @@ impl Map1 for Im2Col1D {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let l_out = self.l_out(dims[2]);
|
||||
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
|
||||
let threads = dims[0] * l_out * dims[1];
|
||||
let cfg = LaunchConfig::for_num_elems(threads as u32);
|
||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, dst_el);
|
||||
barg!(builder, threads);
|
||||
barg!(builder, l_out);
|
||||
barg!(builder, self.l_k);
|
||||
barg!(builder, self.stride);
|
||||
@ -210,11 +227,11 @@ impl Map1 for Im2Col {
|
||||
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
|
||||
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
|
||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, dst_el);
|
||||
barg!(builder, h_out);
|
||||
@ -249,7 +266,7 @@ impl Map1 for Powf {
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -302,9 +319,7 @@ impl Map1Any for FastReduce<'_> {
|
||||
block_dim: (block_dim as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let ds = dev
|
||||
.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())
|
||||
.w()?;
|
||||
let ds = dev.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let (name, check_empty, return_index) = match self.1 {
|
||||
ReduceOp::Sum => ("fast_sum", false, false),
|
||||
@ -319,7 +334,7 @@ impl Map1Any for FastReduce<'_> {
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
|
||||
if return_index {
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<u32>(dst_el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, src_el);
|
||||
barg!(builder, el_to_sum_per_block);
|
||||
@ -332,7 +347,7 @@ impl Map1Any for FastReduce<'_> {
|
||||
Ok(S::U32(out))
|
||||
} else {
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, src_el);
|
||||
barg!(builder, el_to_sum_per_block);
|
||||
@ -362,7 +377,7 @@ impl<U: UnaryOpT> Map1 for U {
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let mut out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
let mut out = unsafe { dev.alloc::<T>(el_count)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el_count);
|
||||
barg!(builder, dims.len());
|
||||
@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
|
||||
CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
|
||||
CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index_select ids should be u8 or u32",
|
||||
msg: "index_select ids should be u8, u32, or i64",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
@ -403,7 +418,7 @@ impl Map1 for IndexSelect<'_> {
|
||||
};
|
||||
let ids_shape = ids_l.shape();
|
||||
let ids_dims = ids_shape.dims();
|
||||
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat()).w()?;
|
||||
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat())?;
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
|
||||
@ -416,7 +431,7 @@ impl Map1 for IndexSelect<'_> {
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, dst_el);
|
||||
barg!(builder, ids_dims.len());
|
||||
@ -471,7 +486,7 @@ impl Map1 for Gather<'_> {
|
||||
let ids_dim_sz = ids_l.dims()[dim];
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, ids);
|
||||
@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let ids_dim_sz = ids_l.dims()[0];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
barg!(builder, ids);
|
||||
barg!(builder, ids_dim_sz);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||
impl Map2InPlace for Scatter<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
let ids = &self.0;
|
||||
let ids_l = &self.1;
|
||||
let dim = self.2;
|
||||
let (ids_o1, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let (name, (ids, _guard)) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
|
||||
CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
|
||||
CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "scatter ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -608,7 +683,7 @@ impl Map2 for Conv1D<'_> {
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let ds = if dims.len() == 3 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else if dims.len() == 2 {
|
||||
@ -616,7 +691,7 @@ impl Map2 for Conv1D<'_> {
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.memcpy_stod(&ds).w()?;
|
||||
let ds = dev.memcpy_stod(&ds)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
|
||||
builder.arg(&ds);
|
||||
@ -651,7 +726,7 @@ impl Map2 for Conv2D<'_> {
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
|
||||
let ds = if dims.len() == 4 {
|
||||
@ -659,7 +734,7 @@ impl Map2 for Conv2D<'_> {
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv2d {dims:?}")
|
||||
};
|
||||
let ds = dev.memcpy_stod(&ds).w()?;
|
||||
let ds = dev.memcpy_stod(&ds)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
|
||||
builder.arg(&ds);
|
||||
@ -687,7 +762,7 @@ impl Map1 for Col2Im1D {
|
||||
let stride = self.stride;
|
||||
let l_out = (l_in - 1) * stride + k_size;
|
||||
let dst_el = b_size * c_out * l_out;
|
||||
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let mut im = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
|
||||
@ -722,7 +797,7 @@ impl Map2 for ConvTranspose1D<'_> {
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
|
||||
let ds = if dims.len() == 3 {
|
||||
@ -730,7 +805,7 @@ impl Map2 for ConvTranspose1D<'_> {
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
||||
};
|
||||
let ds = dev.memcpy_stod(&ds).w()?;
|
||||
let ds = dev.memcpy_stod(&ds)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, l_out);
|
||||
@ -770,7 +845,7 @@ impl Map2 for ConvTranspose2D<'_> {
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
|
||||
let ds = if dims.len() == 4 {
|
||||
@ -778,7 +853,7 @@ impl Map2 for ConvTranspose2D<'_> {
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
|
||||
};
|
||||
let ds = dev.memcpy_stod(&ds).w()?;
|
||||
let ds = dev.memcpy_stod(&ds)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, out_w);
|
||||
@ -837,8 +912,8 @@ impl Map1 for Pool2D {
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let ds = dev.memcpy_stod(&ds).w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let ds = dev.memcpy_stod(&ds)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, self.w_k);
|
||||
@ -876,8 +951,8 @@ impl Map1 for UpsampleNearest2D {
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let ds = dev.memcpy_stod(&ds).w()?;
|
||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let ds = dev.memcpy_stod(&ds)?;
|
||||
let scale_w = dims[2] as f64 / out_w as f64;
|
||||
let scale_h = dims[3] as f64 / out_h as f64;
|
||||
let mut builder = func.builder();
|
||||
@ -930,13 +1005,12 @@ impl Map2 for WhereCond<'_> {
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev
|
||||
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
|
||||
.w()?;
|
||||
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
||||
let t = &t.slice(layout_t.start_offset()..);
|
||||
let f = &f.slice(layout_f.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -967,16 +1041,13 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
|
||||
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||
SlicePtrOrNull::Null
|
||||
} else {
|
||||
SlicePtrOrNull::Ptr(
|
||||
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||
.w()?,
|
||||
)
|
||||
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
|
||||
};
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(elem_count)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, elem_count);
|
||||
barg!(builder, dims.len());
|
||||
@ -1007,10 +1078,7 @@ impl Map2Any for Cmp {
|
||||
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||
SlicePtrOrNull::Null
|
||||
} else {
|
||||
SlicePtrOrNull::Ptr(
|
||||
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||
.w()?,
|
||||
)
|
||||
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
|
||||
};
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
@ -1024,7 +1092,7 @@ impl Map2Any for Cmp {
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
|
||||
let out = unsafe { dev.alloc::<u8>(elem_count)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, elem_count);
|
||||
barg!(builder, dims.len());
|
||||
@ -1208,7 +1276,6 @@ fn gemm_config<T>(
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
};
|
||||
|
||||
Ok(StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
@ -1243,6 +1310,36 @@ impl BackendStorage for CudaStorage {
|
||||
&self.device
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
|
||||
let dev = &self.device;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src_o = layout.start_offset();
|
||||
let ((src, _guard_src), kernel_name) = match &mut self.slice {
|
||||
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
|
||||
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
|
||||
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
|
||||
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
|
||||
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
|
||||
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
|
||||
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
|
||||
};
|
||||
|
||||
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el_count);
|
||||
barg!(builder, dims.len());
|
||||
ds.builder_arg(&mut builder);
|
||||
s.builder_arg(&mut builder);
|
||||
barg!(builder, src);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
@ -1269,7 +1366,7 @@ impl BackendStorage for CudaStorage {
|
||||
let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let out = unsafe { dev.alloc::<u8>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<u8>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1280,7 +1377,7 @@ impl BackendStorage for CudaStorage {
|
||||
CudaStorageSlice::U8(out)
|
||||
}
|
||||
DType::U32 => {
|
||||
let out = unsafe { dev.alloc::<u32>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<u32>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1291,7 +1388,7 @@ impl BackendStorage for CudaStorage {
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
DType::I64 => {
|
||||
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<i64>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1302,7 +1399,7 @@ impl BackendStorage for CudaStorage {
|
||||
CudaStorageSlice::I64(out)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<bf16>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1313,7 +1410,7 @@ impl BackendStorage for CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
DType::F16 => {
|
||||
let out = unsafe { dev.alloc::<f16>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<f16>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1324,7 +1421,7 @@ impl BackendStorage for CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
DType::F32 => {
|
||||
let out = unsafe { dev.alloc::<f32>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<f32>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1335,7 +1432,7 @@ impl BackendStorage for CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
DType::F64 => {
|
||||
let out = unsafe { dev.alloc::<f64>(el) }.w()?;
|
||||
let out = unsafe { dev.alloc::<f64>(el)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el);
|
||||
barg!(builder, dims.len());
|
||||
@ -1445,6 +1542,7 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
fn conv1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
@ -1473,12 +1571,11 @@ impl BackendStorage for CudaStorage {
|
||||
let n = params.c_out;
|
||||
let k = params.k_size * params.c_in;
|
||||
let m = l_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let col_l = Layout::contiguous((b * m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = unsafe {
|
||||
@ -1486,10 +1583,9 @@ impl BackendStorage for CudaStorage {
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||
@ -1497,6 +1593,72 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cudnn")]
|
||||
fn conv1d(
|
||||
&self,
|
||||
inp_l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
if !kernel_l.is_contiguous() {
|
||||
let slice = Conv1D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
|
||||
return Ok(Self { slice, device });
|
||||
}
|
||||
let l_out = params.l_out();
|
||||
let dst_el = params.c_out * l_out * params.b_size;
|
||||
let slice = match (&self.slice, &kernel.slice) {
|
||||
(S::U8(inp), S::U8(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
|
||||
crate::cudnn::launch_conv1d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::U8(out)
|
||||
}
|
||||
(S::BF16(inp), S::BF16(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
|
||||
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
|
||||
// version.
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
|
||||
crate::cudnn::launch_conv1d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::BF16(out)
|
||||
}
|
||||
(S::F16(inp), S::F16(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
|
||||
crate::cudnn::launch_conv1d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F16(out)
|
||||
}
|
||||
(S::F32(inp), S::F32(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
|
||||
crate::cudnn::launch_conv1d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F32(out)
|
||||
}
|
||||
(S::F64(inp), S::F64(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
|
||||
crate::cudnn::launch_conv1d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F64(out)
|
||||
}
|
||||
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?,
|
||||
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?,
|
||||
};
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
@ -1587,12 +1749,11 @@ impl BackendStorage for CudaStorage {
|
||||
let n = params.c_out;
|
||||
let k = params.k_h * params.k_w * params.c_in;
|
||||
let m = h_out * w_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let col_l = Layout::contiguous((b * m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = unsafe {
|
||||
@ -1600,10 +1761,9 @@ impl BackendStorage for CudaStorage {
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||
.transpose(1, 2)?
|
||||
@ -1632,7 +1792,7 @@ impl BackendStorage for CudaStorage {
|
||||
(S::U8(inp), S::U8(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
|
||||
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
|
||||
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::U8(out)
|
||||
@ -1640,7 +1800,7 @@ impl BackendStorage for CudaStorage {
|
||||
(S::BF16(inp), S::BF16(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
|
||||
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
|
||||
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
|
||||
// version.
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
|
||||
@ -1651,7 +1811,7 @@ impl BackendStorage for CudaStorage {
|
||||
(S::F16(inp), S::F16(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
|
||||
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
|
||||
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F16(out)
|
||||
@ -1659,7 +1819,7 @@ impl BackendStorage for CudaStorage {
|
||||
(S::F32(inp), S::F32(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
|
||||
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
|
||||
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F32(out)
|
||||
@ -1667,7 +1827,7 @@ impl BackendStorage for CudaStorage {
|
||||
(S::F64(inp), S::F64(k)) => {
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(kernel_l.start_offset()..);
|
||||
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
|
||||
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
|
||||
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
S::F64(out)
|
||||
@ -1738,20 +1898,29 @@ impl BackendStorage for CudaStorage {
|
||||
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn index_add(
|
||||
&self,
|
||||
@ -1765,7 +1934,7 @@ impl BackendStorage for CudaStorage {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
}
|
||||
|
||||
@ -1783,7 +1952,7 @@ impl BackendStorage for CudaStorage {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||
let mut out = unsafe { dev.alloc::<bf16>(elem_count)? };
|
||||
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||
.w()?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
@ -1792,7 +1961,7 @@ impl BackendStorage for CudaStorage {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
let mut out = unsafe { dev.alloc::<f16>(elem_count)? };
|
||||
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||
.w()?;
|
||||
CudaStorageSlice::F16(out)
|
||||
@ -1801,7 +1970,7 @@ impl BackendStorage for CudaStorage {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
let mut out = unsafe { dev.alloc::<f32>(elem_count)? };
|
||||
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||
.w()?;
|
||||
CudaStorageSlice::F32(out)
|
||||
@ -1810,7 +1979,7 @@ impl BackendStorage for CudaStorage {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
|
||||
let mut out = unsafe { dev.alloc::<f64>(elem_count)? };
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
@ -1883,7 +2052,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
@ -1899,7 +2068,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
@ -1915,7 +2084,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
@ -1931,7 +2100,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
@ -1947,7 +2116,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
@ -1963,7 +2132,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
@ -1979,7 +2148,7 @@ impl BackendStorage for CudaStorage {
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||
dev.memcpy_dtod(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
|
||||
let mut builder = func.builder();
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// Helper functions to plug cuda kernels in candle.
|
||||
use crate::{Layout, Result, Shape, WithDType};
|
||||
use crate::{Layout, Result, WithDType};
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||
|
||||
@ -96,7 +96,7 @@ pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -105,19 +105,19 @@ pub trait Map2InPlace {
|
||||
fn map(
|
||||
&self,
|
||||
dst: &mut S,
|
||||
dst_s: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &S,
|
||||
src_l: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
}
|
||||
}
|
||||
|
@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: NdArray> NdArray for Vec<S> {
|
||||
impl<S: WithDType> NdArray for Vec<S> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<&[S]> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let n = self.len();
|
||||
let m = self[0].len();
|
||||
for v in self.iter() {
|
||||
if v.len() != m {
|
||||
crate::bail!("two elements have different len {m} {}", v.len())
|
||||
}
|
||||
}
|
||||
Ok(Shape::from((n, m)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
|
||||
S::to_cpu_storage_owned(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<Vec<S>> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let n = self.len();
|
||||
let m = self[0].len();
|
||||
for v in self.iter() {
|
||||
if v.len() != m {
|
||||
crate::bail!("two elements have different len {m} {}", v.len())
|
||||
}
|
||||
}
|
||||
Ok(Shape::from((n, m)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let len: usize = self.iter().map(|v| v.len()).sum();
|
||||
let mut dst = Vec::with_capacity(len);
|
||||
for v in self.iter() {
|
||||
dst.extend(v.iter().copied());
|
||||
}
|
||||
S::to_cpu_storage_owned(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
// This allocates intermediary memory and shouldn't be necessary.
|
||||
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
|
||||
CpuStorage::concat(storages.as_slice()).unwrap()
|
||||
if self.is_empty() {
|
||||
return S::to_cpu_storage_owned(vec![]);
|
||||
}
|
||||
let len: usize = self
|
||||
.iter()
|
||||
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
|
||||
.sum();
|
||||
let mut dst = Vec::with_capacity(len);
|
||||
for v1 in self.iter() {
|
||||
for v2 in v1.iter() {
|
||||
dst.extend(v2.iter().copied());
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
if self.is_empty() {
|
||||
crate::bail!("empty array")
|
||||
}
|
||||
let shape0 = self[0].shape()?;
|
||||
let n = self.len();
|
||||
for v in self.iter() {
|
||||
let shape = v.shape()?;
|
||||
if shape != shape0 {
|
||||
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||
}
|
||||
}
|
||||
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let len: usize = self
|
||||
.iter()
|
||||
.map(|v| {
|
||||
v.iter()
|
||||
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
|
||||
.sum::<usize>()
|
||||
})
|
||||
.sum();
|
||||
let mut dst = Vec::with_capacity(len);
|
||||
for v1 in self.iter() {
|
||||
for v2 in v1.iter() {
|
||||
for v3 in v2.iter() {
|
||||
dst.extend(v3.iter().copied());
|
||||
}
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(dst)
|
||||
}
|
||||
}
|
||||
|
||||
@ -292,23 +396,6 @@ impl Device {
|
||||
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
|
||||
}
|
||||
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
|
@ -107,6 +107,7 @@ pub trait WithDType:
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
fn to_f64(self) -> f64;
|
||||
fn to_scalar(self) -> crate::scalar::Scalar;
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
@ -131,6 +132,10 @@ macro_rules! with_dtype {
|
||||
$to_f64(self)
|
||||
}
|
||||
|
||||
fn to_scalar(self) -> crate::scalar::Scalar {
|
||||
crate::scalar::Scalar::$dtype(self)
|
||||
}
|
||||
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||
CpuStorageRef::$dtype(data)
|
||||
}
|
||||
@ -175,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
|
||||
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
||||
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
||||
|
||||
pub trait IntDType: WithDType {
|
||||
pub trait IntDType: WithDType + num_traits::Bounded {
|
||||
fn is_true(&self) -> bool;
|
||||
fn as_usize(&self) -> usize;
|
||||
}
|
||||
|
@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
@ -226,8 +226,8 @@ where
|
||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||
///
|
||||
/// let d = a.i((2.., ..))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2]);
|
||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||
/// assert_eq!(d.shape().dims(), &[1, 3]);
|
||||
/// assert_eq!(d.to_vec2::<f32>()?, &[[6., 7., 8.]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
|
||||
|
@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
|
||||
self.binary(name, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||
use crate::scalar::Scalar;
|
||||
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
|
||||
self_: &mut MetalStorage,
|
||||
s: S,
|
||||
l: &Layout,
|
||||
) -> Result<()> {
|
||||
let device = self_.device();
|
||||
let dtype = self_.dtype;
|
||||
let shape = l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label("const-set");
|
||||
let dst = buffer_o(&self_.buffer, l, self_.dtype);
|
||||
|
||||
match (el_count % 2, dtype, l.is_contiguous()) {
|
||||
(0, DType::BF16 | DType::F16, true) => {
|
||||
use candle_metal_kernels::unary::contiguous_tiled;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous_tiled::const_set::HALF,
|
||||
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
|
||||
_ => crate::bail!("internal bug in const_set"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_contiguous_tiled(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
s,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, true) => {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous::const_set::HALF,
|
||||
DType::BF16 => contiguous::const_set::BFLOAT,
|
||||
DType::F32 => contiguous::const_set::FLOAT,
|
||||
DType::I64 => contiguous::const_set::I64,
|
||||
DType::U32 => contiguous::const_set::U32,
|
||||
DType::U8 => contiguous::const_set::U8,
|
||||
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
s,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, false) => {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => strided::const_set::HALF,
|
||||
DType::BF16 => strided::const_set::BFLOAT,
|
||||
DType::F32 => strided::const_set::FLOAT,
|
||||
DType::I64 => strided::const_set::I64,
|
||||
DType::U32 => strided::const_set::U32,
|
||||
DType::U8 => strided::const_set::U8,
|
||||
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
l.dims(),
|
||||
s,
|
||||
l.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
match (self.dtype, s) {
|
||||
(DType::U8, Scalar::U8(s)) => set(self, s, l),
|
||||
(DType::U32, Scalar::U32(s)) => set(self, s, l),
|
||||
(DType::I64, Scalar::I64(s)) => set(self, s, l),
|
||||
(DType::F16, Scalar::F16(s)) => set(self, s, l),
|
||||
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
|
||||
(DType::F32, Scalar::F32(s)) => set(self, s, l),
|
||||
(DType::F64, Scalar::F64(s)) => set(self, s, l),
|
||||
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
|
||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::F32) => "s_u8_f32",
|
||||
(DType::U8, DType::F16) => "s_u8_f16",
|
||||
(DType::U8, DType::BF16) => "s_u8_bf16",
|
||||
(DType::U32, DType::U32) => "s_u32_u32",
|
||||
(DType::U32, DType::F32) => "s_u32_f32",
|
||||
(DType::U32, DType::F16) => "s_u32_f16",
|
||||
(DType::U32, DType::BF16) => "s_u32_bf16",
|
||||
(DType::I64, DType::F32) => "s_i64_f32",
|
||||
(DType::I64, DType::F16) => "s_i64_f16",
|
||||
(DType::I64, DType::BF16) => "s_i64_bf16",
|
||||
_ => Err(MetalError::UnexpectedDType {
|
||||
msg: "scatter ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
l.dims(),
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter_add(
|
||||
candle_metal_kernels::call_scatter(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(acc)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
@ -1513,50 +1655,32 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("matmul");
|
||||
if self.dtype == DType::BF16 {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
candle_metal_kernels::GemmDType::BF16,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let dtype = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||
dtype => {
|
||||
return Err(MetalError::Message(format!(
|
||||
"mlx matmul doesn't support {dtype:?}"
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
let dtype = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||
dtype => {
|
||||
return Err(
|
||||
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
|
||||
)
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
@ -1965,40 +2089,6 @@ impl BackendDevice for MetalDevice {
|
||||
))
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let name = match dtype {
|
||||
DType::U8 => "fill_u8",
|
||||
DType::U32 => "fill_u32",
|
||||
DType::I64 => "fill_i64",
|
||||
DType::F16 => "fill_f16",
|
||||
DType::BF16 => "fill_bf16",
|
||||
DType::F32 => "fill_f32",
|
||||
DType::F64 => {
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
||||
return self.storage_from_cpu_storage(&cpu_storage);
|
||||
}
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_const_fill(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
shape.elem_count(),
|
||||
&buffer,
|
||||
1.,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(MetalStorage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
|
@ -80,6 +80,7 @@ pub enum Op {
|
||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||
Matmul(Tensor, Tensor),
|
||||
Gather(Tensor, Tensor, usize),
|
||||
Scatter(Tensor, Tensor, Tensor, usize),
|
||||
ScatterAdd(Tensor, Tensor, Tensor, usize),
|
||||
IndexSelect(Tensor, Tensor, usize),
|
||||
IndexAdd(Tensor, Tensor, Tensor, usize),
|
||||
|
@ -816,7 +816,7 @@ impl PthTensors {
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||
path: P,
|
||||
key: Option<&str>,
|
||||
|
@ -73,7 +73,7 @@ fn dequantize_f32(
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let nb = elem_count.div_ceil(256);
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||
@ -99,7 +99,7 @@ fn dequantize_f32(
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
@ -133,7 +133,7 @@ fn dequantize_f16(
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let nb = elem_count.div_ceil(256);
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||
@ -159,7 +159,7 @@ fn dequantize_f16(
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
@ -216,7 +216,7 @@ fn dequantize_mul_mat_vec(
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows)? };
|
||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (block_num_y as u32, 1, 1),
|
||||
@ -256,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
|
||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes =
|
||||
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||
|
||||
let kernel_name = match dtype {
|
||||
@ -274,12 +274,12 @@ fn mul_mat_vec_via_q8_1(
|
||||
};
|
||||
let kernel_name = format!("{kernel_name}{b_size}");
|
||||
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
|
||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||
let (nblocks, nwarps) = match b_size {
|
||||
1 => (nrows as u32, 4),
|
||||
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
||||
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
||||
2..=4 => ((nrows as u32).div_ceil(2), 4),
|
||||
5..=8 => ((nrows as u32).div_ceil(2), 2),
|
||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||
};
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
@ -329,7 +329,7 @@ fn mul_mat_via_q8_1(
|
||||
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes =
|
||||
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
||||
|
||||
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
||||
@ -346,7 +346,7 @@ fn mul_mat_via_q8_1(
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (
|
||||
ceil_div(x_rows, mmq_y) as u32,
|
||||
@ -378,7 +378,7 @@ impl QCudaStorage {
|
||||
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
||||
let padded_size_in_bytes =
|
||||
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
|
||||
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
|
||||
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
|
||||
Ok(QCudaStorage {
|
||||
data: PaddedCudaSlice {
|
||||
inner,
|
||||
@ -425,8 +425,7 @@ impl QCudaStorage {
|
||||
|
||||
let buffer = self
|
||||
.device
|
||||
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
|
||||
.w()?;
|
||||
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
let block_len = elem_count / self.dtype.block_size();
|
||||
match self.dtype {
|
||||
@ -457,9 +456,7 @@ impl QCudaStorage {
|
||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||
// Run the quantization on cpu.
|
||||
let src = match &src.slice {
|
||||
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
||||
self.device.memcpy_dtov(data).w()?
|
||||
}
|
||||
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
|
||||
_ => crate::bail!("only f32 can be quantized"),
|
||||
};
|
||||
let src_len = src.len();
|
||||
@ -469,10 +466,9 @@ impl QCudaStorage {
|
||||
let data = qcpu_storage.data()?;
|
||||
let padded_len =
|
||||
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
|
||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
|
||||
self.device
|
||||
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
||||
.w()?;
|
||||
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
|
||||
self.data = PaddedCudaSlice {
|
||||
inner,
|
||||
len: data.len(),
|
||||
@ -606,10 +602,8 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
};
|
||||
let dtype = T::DTYPE;
|
||||
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
||||
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
|
||||
device
|
||||
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
|
||||
.w()?;
|
||||
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
|
||||
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
|
||||
Ok(QStorage::Cuda(QCudaStorage {
|
||||
data: PaddedCudaSlice {
|
||||
inner,
|
||||
@ -631,9 +625,9 @@ mod test {
|
||||
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
||||
let y_size_in_bytes =
|
||||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||
Ok(())
|
||||
}
|
||||
@ -643,7 +637,7 @@ mod test {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let ncols = 256;
|
||||
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||
@ -656,7 +650,7 @@ mod test {
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
assert_eq!(vs.len(), 1);
|
||||
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||
// Q8 means 1/256 precision.
|
||||
@ -671,7 +665,7 @@ mod test {
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
assert_eq!(vs.len(), 1);
|
||||
assert_eq!(vs[0], 5561851.0);
|
||||
Ok(())
|
||||
@ -682,7 +676,7 @@ mod test {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let ncols = 256;
|
||||
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_via_q8_1(
|
||||
@ -696,7 +690,7 @@ mod test {
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
|
||||
/*
|
||||
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||
@ -723,7 +717,7 @@ mod test {
|
||||
let dev = CudaDevice::new(0)?;
|
||||
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
||||
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs)?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_via_q8_1(
|
||||
@ -737,7 +731,7 @@ mod test {
|
||||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,74 @@
|
||||
//! TensorScalar Enum and Trait
|
||||
//!
|
||||
use crate::{Result, Tensor, WithDType};
|
||||
use crate::{DType, Result, Tensor, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Scalar {
|
||||
U8(u8),
|
||||
U32(u32),
|
||||
I64(i64),
|
||||
BF16(bf16),
|
||||
F16(f16),
|
||||
F32(f32),
|
||||
F64(f64),
|
||||
}
|
||||
|
||||
impl<T: WithDType> From<T> for Scalar {
|
||||
fn from(value: T) -> Self {
|
||||
value.to_scalar()
|
||||
}
|
||||
}
|
||||
|
||||
impl Scalar {
|
||||
pub fn zero(dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::U8 => Scalar::U8(0),
|
||||
DType::U32 => Scalar::U32(0),
|
||||
DType::I64 => Scalar::I64(0),
|
||||
DType::BF16 => Scalar::BF16(bf16::ZERO),
|
||||
DType::F16 => Scalar::F16(f16::ZERO),
|
||||
DType::F32 => Scalar::F32(0.0),
|
||||
DType::F64 => Scalar::F64(0.0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn one(dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::U8 => Scalar::U8(1),
|
||||
DType::U32 => Scalar::U32(1),
|
||||
DType::I64 => Scalar::I64(1),
|
||||
DType::BF16 => Scalar::BF16(bf16::ONE),
|
||||
DType::F16 => Scalar::F16(f16::ONE),
|
||||
DType::F32 => Scalar::F32(1.0),
|
||||
DType::F64 => Scalar::F64(1.0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Scalar::U8(_) => DType::U8,
|
||||
Scalar::U32(_) => DType::U32,
|
||||
Scalar::I64(_) => DType::I64,
|
||||
Scalar::BF16(_) => DType::BF16,
|
||||
Scalar::F16(_) => DType::F16,
|
||||
Scalar::F32(_) => DType::F32,
|
||||
Scalar::F64(_) => DType::F64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f64(&self) -> f64 {
|
||||
match self {
|
||||
Scalar::U8(v) => *v as f64,
|
||||
Scalar::U32(v) => *v as f64,
|
||||
Scalar::I64(v) => *v as f64,
|
||||
Scalar::BF16(v) => v.to_f64(),
|
||||
Scalar::F16(v) => v.to_f64(),
|
||||
Scalar::F32(v) => *v as f64,
|
||||
Scalar::F64(v) => *v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum TensorScalar {
|
||||
Tensor(Tensor),
|
||||
|
@ -76,7 +76,7 @@ mod cuda {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
|
||||
let func = if self.asc {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
||||
} else {
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, ReduceOp};
|
||||
use crate::scalar::Scalar;
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||
|
||||
@ -73,6 +74,14 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => storage.const_set(v, l),
|
||||
Storage::Cuda(storage) => storage.const_set(v, l),
|
||||
Storage::Metal(storage) => storage.const_set(v, l),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
@ -619,32 +628,56 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&self,
|
||||
pub(crate) fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
self.same_device(indexes, "scatter-set")?;
|
||||
self.same_device(source, "scatter-set")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<()> {
|
||||
self.same_device(indexes, "scatter-add")?;
|
||||
self.same_device(source, "scatter-add")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn index_add(
|
||||
|
@ -3,7 +3,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::shape::{Dim, Dims, ShapeWithOneHole};
|
||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
@ -185,7 +185,9 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
let none = BackpropOp::none();
|
||||
let shape = shape.into();
|
||||
let storage = device.ones(&shape, dtype)?;
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||
let layout = Layout::contiguous(shape.clone());
|
||||
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
|
||||
@ -202,6 +204,18 @@ impl Tensor {
|
||||
Self::ones_impl(shape, dtype, device, false)
|
||||
}
|
||||
|
||||
pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
|
||||
self.storage_mut().const_set(value, self.layout())
|
||||
}
|
||||
|
||||
pub fn zero_set(&self) -> Result<()> {
|
||||
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
|
||||
}
|
||||
|
||||
pub fn one_set(&self) -> Result<()> {
|
||||
self.const_set(crate::scalar::Scalar::one(self.dtype()))
|
||||
}
|
||||
|
||||
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
|
||||
///
|
||||
/// ```rust
|
||||
@ -368,8 +382,7 @@ impl Tensor {
|
||||
Self::new_impl(array, shape, device, false)
|
||||
}
|
||||
|
||||
/// Returns a new tensor with all the elements having the same specified value. Note that
|
||||
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
||||
/// Returns a new tensor with all the elements having the same specified value.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
|
||||
@ -384,7 +397,12 @@ impl Tensor {
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
|
||||
let none = BackpropOp::none();
|
||||
let shape = shape.into();
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
|
||||
let layout = Layout::contiguous(shape.clone());
|
||||
storage.const_set(value.to_scalar(), &layout)?;
|
||||
Ok(from_storage(storage, shape, none, false))
|
||||
}
|
||||
|
||||
/// Creates a new 1D tensor from an iterator.
|
||||
@ -452,17 +470,13 @@ impl Tensor {
|
||||
Self::from_vec_impl(data, len, device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
||||
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let buffer_size = data.len();
|
||||
if buffer_size != shape.elem_count() {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let shape = shape.into_shape(data.len())?;
|
||||
let storage = device.storage_owned(data)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
@ -481,7 +495,7 @@ impl Tensor {
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
@ -502,17 +516,12 @@ impl Tensor {
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||
array: &[D],
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let n: usize = shape.elem_count();
|
||||
let buffer_size: usize = array.len();
|
||||
if buffer_size != n {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let shape = shape.into_shape(array.len())?;
|
||||
let storage = device.storage_from_slice(array)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, false))
|
||||
@ -1226,6 +1235,83 @@ impl Tensor {
|
||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||
}
|
||||
|
||||
/// Computes the dot product of two 1D tensors.
|
||||
///
|
||||
/// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
|
||||
/// - Panics if shapes are not compatible
|
||||
/// - Not supported for integer dtypes
|
||||
///
|
||||
/// # Example (vectors)
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
|
||||
/// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
|
||||
/// let res = t1.dot(&t2)?;
|
||||
/// assert_eq!(res.to_scalar::<f64>()?, 32.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn dot(&self, rhs: &Self) -> Result<Self> {
|
||||
if self.dims().len() != 1 || rhs.dims().len() != 1 {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "dot",
|
||||
});
|
||||
}
|
||||
|
||||
(self * rhs).and_then(|ret| ret.sum_all())
|
||||
}
|
||||
|
||||
/// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
|
||||
/// - Output is `sqrt(sum(x^2))`.
|
||||
/// - Always returns a scalar (`[]` shape).
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||
/// let norm = t.norm()?;
|
||||
/// assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn norm(&self) -> Result<Self> {
|
||||
if self.dtype().is_int() {
|
||||
bail!("norm not supported for integer dtypes");
|
||||
}
|
||||
|
||||
self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
|
||||
}
|
||||
|
||||
/// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
|
||||
///
|
||||
/// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
|
||||
/// - **No broadcasting**: Panics if `self` is not 2D or if `rhs` is not 1D with matching size.
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
/// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||
/// let res = mat.mv(&vec)?;
|
||||
/// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn mv(&self, rhs: &Self) -> Result<Self> {
|
||||
// Strict shape checks
|
||||
let lhs_dims = self.dims();
|
||||
let rhs_dims = rhs.dims();
|
||||
if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: self.shape().clone(),
|
||||
rhs: rhs.shape().clone(),
|
||||
op: "mv",
|
||||
});
|
||||
}
|
||||
|
||||
// Direct matmul after ensuring rhs is column vector
|
||||
self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
|
||||
}
|
||||
|
||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
@ -1349,8 +1435,7 @@ impl Tensor {
|
||||
self.index_select(ids, 0)
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
@ -1367,7 +1452,7 @@ impl Tensor {
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (self, src)",
|
||||
op: "scatter (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
@ -1375,13 +1460,44 @@ impl Tensor {
|
||||
}
|
||||
if indexes.dims() != source.dims() {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (indexes, src)",
|
||||
op: "scatter (indexes, src)",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let storage = self.storage().scatter_add(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let layout = Layout::contiguous(shape);
|
||||
storage.scatter_set(
|
||||
&layout,
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||
Op::Scatter(t1, t2, t3, dim)
|
||||
});
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||
if self.same_storage(source) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
let dim = dim.to_index(self.shape(), "scatter-set")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
self.storage_mut().scatter_set(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
@ -1389,12 +1505,48 @@ impl Tensor {
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let layout = Layout::contiguous(shape);
|
||||
storage.scatter_add(
|
||||
&layout,
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||
Op::ScatterAdd(t1, t2, t3, dim)
|
||||
});
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||
if self.same_storage(source) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
self.storage_mut().scatter_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
|
||||
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "slice-scatter")?;
|
||||
@ -2197,7 +2349,7 @@ impl Tensor {
|
||||
///
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
||||
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
||||
let shape = s.into_shape(self.elem_count())?;
|
||||
if shape.elem_count() != self.elem_count() {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
|
@ -241,7 +241,7 @@ impl Tensor {
|
||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||
///
|
||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||
/// Note that this modifies `self` in place and as such is not compatible with
|
||||
/// back-propagation.
|
||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||
|
@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
let res = {
|
||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
|
||||
};
|
||||
assert_eq!(res.dims(), [3, 2, 5]);
|
||||
// Same as pytorch default padding: use zeros.
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
|
||||
let w = w.transpose(0, 1)?;
|
||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||
@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
let res = {
|
||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||
t.conv2d(&w, 0, 1, 1, 1)?
|
||||
};
|
||||
assert_eq!(res.dims(), [3, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||
[
|
||||
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
|
||||
|
@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_dot() -> Result<()> {
|
||||
let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
|
||||
let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
|
||||
let expected = Tensor::new(32., &Device::Cpu)?;
|
||||
let dot_ret = lhs.dot(&rhs)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_mv() -> Result<()> {
|
||||
let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||
let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
|
||||
let mv_ret = mat.mv(&vec)?;
|
||||
candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/candle/issues/1948
|
||||
fn squeeze_mm(device: &Device) -> Result<()> {
|
||||
let seq_len = 8_usize;
|
||||
|
@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
|
||||
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
if !device.is_metal() {
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
}
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
|
||||
[
|
||||
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn full(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
|
||||
tensor.const_set(42u32.into())?;
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<u32>()?,
|
||||
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
|
||||
);
|
||||
tensor.i((.., 2))?.const_set(1337u32.into())?;
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<u32>()?,
|
||||
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
|
||||
);
|
||||
tensor.i((2, ..))?.const_set(1u32.into())?;
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<u32>()?,
|
||||
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn const_set(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
||||
[[42, 42, 42], [42, 42, 42]],
|
||||
@ -823,9 +845,37 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_fail() -> Result<()> {
|
||||
// Check that an error is properly reported on out of bounds.
|
||||
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
|
||||
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
|
||||
let hs = t.index_select(&ids, 0);
|
||||
assert!(hs.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The test below triggers an unwinding panic as there is a panic within the
|
||||
// #[cfg(feature = "cuda")]
|
||||
// #[test]
|
||||
// #[should_panic]
|
||||
// fn index_select_fail_gpu() {
|
||||
// // Check that a panic happens for out of bounds in cuda
|
||||
// if let Ok(device) = Device::new_cuda(0) {
|
||||
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
|
||||
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
|
||||
// let _ = t.index_select(&ids, 0);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
fn cmp(device: &Device) -> Result<()> {
|
||||
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
|
||||
@ -980,7 +1030,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add(device: &Device) -> Result<()> {
|
||||
fn scatter(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
@ -1004,6 +1054,17 @@ fn scatter_add(device: &Device) -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let hs = init.scatter(&ids, &t, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0, 1.0, 1.0],
|
||||
[5.0, 1.0, 1.0, 3.0, 4.0],
|
||||
[1.0, 8.0, 1.0, 7.0, 1.0],
|
||||
[10.0, 1.0, 9.0, 1.0, 11.0]
|
||||
]
|
||||
);
|
||||
|
||||
let init = Tensor::ones((6, 3), DType::F32, device)?;
|
||||
let hs = init.scatter_add(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
@ -1017,6 +1078,56 @@ fn scatter_add(device: &Device) -> Result<()> {
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
let hs = init.scatter(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 10.0, 5.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[9.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
|
||||
let hs = {
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[0u32, u32::MAX, 2],
|
||||
[3, 4, u32::MAX],
|
||||
[3, 3, 1],
|
||||
[u32::MAX, u32::MAX, 4],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
init.scatter(&ids, &t, 0)?
|
||||
};
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[1.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
|
||||
init.scatter_set(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
init.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 10.0, 5.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[9.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1050,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
|
||||
let hs = t.gather(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
||||
|
||||
let hs = {
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[0u32, 0u32],
|
||||
[2u32, u32::MAX],
|
||||
[u32::MAX, 1u32],
|
||||
[0u32, 2u32],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
t.gather(&ids, 1)?
|
||||
};
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
|
||||
);
|
||||
|
||||
// Random data
|
||||
|
||||
// Dim: 0
|
||||
@ -1484,6 +1612,7 @@ fn zero_dim(device: &Device) -> Result<()> {
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
|
||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
@ -1515,12 +1644,7 @@ test_device!(
|
||||
);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
||||
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
||||
test_device!(
|
||||
scatter_add,
|
||||
scatter_add_cpu,
|
||||
scatter_add_gpu,
|
||||
scatter_add_metal
|
||||
);
|
||||
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
|
||||
test_device!(
|
||||
slice_scatter,
|
||||
slice_scatter_cpu,
|
||||
@ -1733,3 +1857,34 @@ fn test_flip_3d_channels() -> Result<()> {
|
||||
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_new() -> Result<()> {
|
||||
let t1 = Tensor::new(vec![1f32, 2.0, 3.0], &Device::Cpu)?;
|
||||
assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
|
||||
let t2 = Tensor::new(vec![vec![1f32, 2., 3.], vec![4., 5., 6.]], &Device::Cpu)?;
|
||||
assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);
|
||||
let t3 = Tensor::new(
|
||||
vec![
|
||||
vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
|
||||
vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
|
||||
],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
assert_eq!(
|
||||
t3.to_vec3::<f32>()?,
|
||||
[
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
|
||||
[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_norm() -> Result<()> {
|
||||
let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||
let norm = t.norm()?;
|
||||
assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
|
||||
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
|
||||
let magic_number = read_u32(reader)?;
|
||||
if magic_number != expected {
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("incorrect magic number {magic_number} != {expected}"),
|
||||
))?;
|
||||
Err(io::Error::other(format!(
|
||||
"incorrect magic number {magic_number} != {expected}"
|
||||
)))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
@ -69,6 +69,7 @@ metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal", "rubato"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
snac = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
|
||||
[[example]]
|
||||
@ -83,6 +84,10 @@ required-features = ["pyo3"]
|
||||
name = "onnx"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx-llm"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
@ -107,6 +112,10 @@ required-features = ["candle-datasets"]
|
||||
name = "mimi"
|
||||
required-features = ["mimi"]
|
||||
|
||||
[[example]]
|
||||
name = "snac"
|
||||
required-features = ["snac"]
|
||||
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["encodec"]
|
||||
|
14
candle-examples/examples/csm/README.md
Normal file
14
candle-examples/examples/csm/README.md
Normal file
@ -0,0 +1,14 @@
|
||||
# Conversational Speech Model (CSM)
|
||||
|
||||
CSM is a speech generation model from Sesame,
|
||||
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
|
||||
|
||||
It can generate a conversational speech between two different speakers.
|
||||
The speakers turn are delimited by the `|` character in the prompt.
|
||||
|
||||
```bash
|
||||
cargo run --example csm --features cuda -r -- \
|
||||
--voices candle-examples/examples/csm/voices.safetensors \
|
||||
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
||||
```
|
||||
|
@ -34,9 +34,18 @@ struct Args {
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long, default_value = "[0]Hey how are you doing?")]
|
||||
/// The prompt to be used for the generation, use a | to separate the speakers.
|
||||
#[arg(long, default_value = "Hey how are you doing today?")]
|
||||
prompt: String,
|
||||
|
||||
/// The voices to be used, in safetensors format.
|
||||
#[arg(long)]
|
||||
voices: String,
|
||||
|
||||
/// The output file using the wav format.
|
||||
#[arg(long, default_value = "out.wav")]
|
||||
out_file: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.7)]
|
||||
temperature: f64,
|
||||
@ -162,7 +171,7 @@ fn main() -> Result<()> {
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (mut model, device) = {
|
||||
let dtype = DType::F32;
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
(model, device)
|
||||
@ -177,45 +186,58 @@ fn main() -> Result<()> {
|
||||
let cb = config.audio_num_codebooks;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
if args.prompt.ends_with(".safetensors") {
|
||||
let prompt = candle::safetensors::load(args.prompt, &device)?;
|
||||
let mut tokens = prompt
|
||||
.get("tokens")
|
||||
.expect("no tokens in prompt")
|
||||
.to_dtype(DType::U32)?;
|
||||
let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
|
||||
println!("tokens:\n{tokens:?}");
|
||||
println!("mask:\n{mask:?}");
|
||||
let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
|
||||
let mut const_mask = vec![1u8; cb];
|
||||
const_mask.push(0);
|
||||
let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
|
||||
let mut pos = 0;
|
||||
let mut all_tokens = vec![];
|
||||
for i in 0.. {
|
||||
let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||
|
||||
let voices = candle::safetensors::load(args.voices, &device)?;
|
||||
let mut lp = candle_transformers::generation::LogitsProcessor::new(
|
||||
args.seed,
|
||||
Some(args.temperature),
|
||||
None,
|
||||
);
|
||||
let tokens = voices
|
||||
.get("tokens")
|
||||
.expect("no tokens in prompt")
|
||||
.to_dtype(DType::U32)?;
|
||||
let mask = voices.get("mask").expect("no mask in prompt").clone();
|
||||
|
||||
let mut pos = 0;
|
||||
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||
pos += tokens.dim(1)?;
|
||||
|
||||
let mut all_pcms = vec![];
|
||||
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
|
||||
println!("{prompt:?}");
|
||||
let speaker_idx = turn_idx % 2;
|
||||
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
|
||||
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
|
||||
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
|
||||
|
||||
let mut generated_tokens = vec![];
|
||||
loop {
|
||||
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||
pos += tokens.dim(1)?;
|
||||
frame.push(0);
|
||||
if frame.iter().all(|&x| x == 0) {
|
||||
let is_done = frame.iter().all(|&x| x == 0);
|
||||
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
|
||||
print!("\rframe {pos}");
|
||||
if is_done {
|
||||
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||
pos += tokens.dim(1)?;
|
||||
break;
|
||||
}
|
||||
println!("frame {i} {pos}:\n{frame:?}");
|
||||
tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
|
||||
all_tokens.push(tokens.clone());
|
||||
mask = const_mask.clone();
|
||||
generated_tokens.push(tokens.clone());
|
||||
}
|
||||
let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||
println!("all_tokens:\n{all_tokens:?}");
|
||||
let pcm = mimi_model.decode(&all_tokens)?;
|
||||
println!();
|
||||
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||
let pcm = mimi_model.decode(&generated_tokens)?;
|
||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create("out.wav")?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
} else {
|
||||
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
||||
println!("{prompt:?}");
|
||||
all_pcms.push(pcm);
|
||||
}
|
||||
let pcm = Tensor::cat(&all_pcms, 0)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
println!("writing output file {}", args.out_file);
|
||||
let mut output = std::fs::File::create(args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
Binary file not shown.
@ -68,7 +68,7 @@ impl CustomOp1 for LayerNorm {
|
||||
Some((o1, o2)) => slice.slice(o1..o2),
|
||||
};
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
let func =
|
||||
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||
let cfg = LaunchConfig {
|
||||
|
@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
||||
|
||||
enum TaskType {
|
||||
Ner(DebertaV2NERModel),
|
||||
TextClassification(DebertaV2SeqClassificationModel),
|
||||
Ner(Box<DebertaV2NERModel>),
|
||||
TextClassification(Box<DebertaV2SeqClassificationModel>),
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone, ValueEnum)]
|
||||
@ -169,21 +169,16 @@ impl Args {
|
||||
|
||||
match self.task {
|
||||
ArgsTask::Ner => Ok((
|
||||
TaskType::Ner(DebertaV2NERModel::load(
|
||||
vb,
|
||||
&config,
|
||||
Some(id2label.clone()),
|
||||
)?),
|
||||
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
|
||||
config,
|
||||
tokenizer,
|
||||
id2label,
|
||||
)),
|
||||
ArgsTask::TextClassification => Ok((
|
||||
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
|
||||
vb,
|
||||
&config,
|
||||
Some(id2label.clone()),
|
||||
)?),
|
||||
TaskType::TextClassification(
|
||||
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
|
||||
.into(),
|
||||
),
|
||||
config,
|
||||
tokenizer,
|
||||
id2label,
|
||||
|
@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model we
|
||||
are downloaded from the hub on the first run.
|
||||
|
||||
```bash
|
||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
|
||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||
@ -20,3 +20,25 @@ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
> Tensor[[1, 7, 768], f32]
|
||||
|
||||
```
|
||||
|
||||
## Masked Token
|
||||
|
||||
DistilBert is used to compute the top K choices for a masked token.
|
||||
|
||||
```bash
|
||||
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
|
||||
|
||||
> Input: The capital of France is [MASK].
|
||||
> Predictions for [MASK] at position 6:
|
||||
> 1: marseille (probability: 12.14%)
|
||||
> 2: paris (probability: 10.84%)
|
||||
> 3: toulouse (probability: 8.57%)
|
||||
> 4: lyon (probability: 7.61%)
|
||||
> 5: montpellier (probability: 5.18%)
|
||||
> 6: bordeaux (probability: 4.88%)
|
||||
> 7: nantes (probability: 4.82%)
|
||||
> 8: lille (probability: 4.07%)
|
||||
> 9: strasbourg (probability: 3.12%)
|
||||
> 10: cannes (probability: 3.04%)
|
||||
|
||||
```
|
@ -3,15 +3,48 @@ extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
||||
use candle_transformers::models::distilbert::{
|
||||
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
|
||||
};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use anyhow::{Context, Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::path::PathBuf;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum ModelType {
|
||||
Masked(Box<DistilBertForMaskedLM>),
|
||||
UnMasked(Box<DistilBertModel>),
|
||||
}
|
||||
|
||||
impl ModelType {
|
||||
fn device(&self) -> &Device {
|
||||
match self {
|
||||
ModelType::Masked(model) => &model.bert.device,
|
||||
ModelType::UnMasked(model) => &model.device,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "distilbert")]
|
||||
DistilBert,
|
||||
|
||||
#[value(name = "distilbertformaskedlm")]
|
||||
DistilbertForMaskedLM,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -23,10 +56,14 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long, default_value = "distilbert")]
|
||||
model: Which,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// Revision or branch
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
@ -42,94 +79,248 @@ struct Args {
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
/// Number of top predictions to show for each mask
|
||||
#[arg(long, default_value = "5")]
|
||||
top_k: usize,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
|
||||
let (model_id, revision) = self.resolve_model_and_revision();
|
||||
let (config_path, tokenizer_path, weights_path) =
|
||||
self.download_model_files(&model_id, &revision)?;
|
||||
|
||||
let config = std::fs::read_to_string(config_path)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
||||
|
||||
let vb = self.load_variables(&weights_path, &device)?;
|
||||
let model = self.create_model(&config, vb)?;
|
||||
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
|
||||
fn resolve_model_and_revision(&self) -> (String, String) {
|
||||
let default_model = "distilbert-base-uncased".to_string();
|
||||
let default_revision = "main".to_string();
|
||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||
|
||||
match (self.model_id.clone(), self.revision.clone()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(Some(model_id), None) => (model_id, default_revision),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
fn download_model_files(
|
||||
&self,
|
||||
model_id: &str,
|
||||
revision: &str,
|
||||
) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
||||
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
|
||||
let vb = if self.use_pth {
|
||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
let model = DistilBertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
|
||||
Ok((config, tokenizer, weights))
|
||||
}
|
||||
|
||||
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
|
||||
if self.use_pth {
|
||||
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
|
||||
} else {
|
||||
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
|
||||
}
|
||||
}
|
||||
|
||||
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||
match self.model {
|
||||
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
|
||||
DistilBertForMaskedLM::load(vb, config)?.into(),
|
||||
)),
|
||||
Which::DistilBert => Ok(ModelType::UnMasked(
|
||||
DistilBertModel::load(vb, config)?.into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let _guard = setup_tracing(&args);
|
||||
|
||||
let (model, tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = model.device();
|
||||
|
||||
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
|
||||
let output = model.forward(&token_ids, &mask)?;
|
||||
|
||||
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
fn setup_tracing(args: &Args) -> Option<impl Drop> {
|
||||
if args.tracing {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
}
|
||||
}
|
||||
|
||||
let tokenizer = tokenizer
|
||||
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
|
||||
let mut binding = tokenizer.clone();
|
||||
let tokenizer_configured = binding
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
|
||||
let tokens = tokenizer_configured
|
||||
.encode(args.prompt.clone(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mask = get_mask(tokens.len(), device);
|
||||
|
||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
||||
let mask = match args.model {
|
||||
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
|
||||
Which::DistilBert => attention_mask(tokens.len(), device)?,
|
||||
};
|
||||
|
||||
let ys = model.forward(&token_ids, &mask)?;
|
||||
println!("{ys}");
|
||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
|
||||
|
||||
Ok((token_ids, mask))
|
||||
}
|
||||
|
||||
fn process_output(
|
||||
model: &ModelType,
|
||||
output: &Tensor,
|
||||
token_ids: &Tensor,
|
||||
tokenizer: &Tokenizer,
|
||||
args: &Args,
|
||||
) -> Result<()> {
|
||||
match model {
|
||||
ModelType::UnMasked(_) => {
|
||||
println!("embeddings");
|
||||
println!("{output}");
|
||||
}
|
||||
ModelType::Masked(_) => {
|
||||
process_masked_output(output, token_ids, tokenizer, args)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||
fn process_masked_output(
|
||||
output: &Tensor,
|
||||
token_ids: &Tensor,
|
||||
tokenizer: &Tokenizer,
|
||||
args: &Args,
|
||||
) -> Result<()> {
|
||||
let input_ids_vec = token_ids.to_vec2::<u32>()?;
|
||||
let mask_token_id = tokenizer
|
||||
.token_to_id("[MASK]")
|
||||
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||
|
||||
println!("\nInput: {}", args.prompt);
|
||||
|
||||
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
|
||||
if token_id == mask_token_id {
|
||||
println!("Predictions for [MASK] at position {}:", token_idx);
|
||||
|
||||
let pos_logits = output.get(0)?.get(token_idx)?;
|
||||
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
|
||||
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
|
||||
|
||||
let values = top_values.to_vec1::<f32>()?;
|
||||
let indices = top_indices.to_vec1::<u32>()?;
|
||||
|
||||
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
|
||||
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
|
||||
println!(
|
||||
" {}: {:15} (probability: {:.2}%)",
|
||||
i + 1,
|
||||
token,
|
||||
prob * 100.0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
|
||||
let n = tensor.dims().iter().product::<usize>();
|
||||
let k = std::cmp::min(k, n);
|
||||
|
||||
let values = tensor.to_vec1::<f32>()?;
|
||||
let mut value_indices: Vec<(f32, usize)> = values
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, val)| (val, idx))
|
||||
.collect();
|
||||
|
||||
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
|
||||
let top_k_indices: Vec<u32> = value_indices
|
||||
.iter()
|
||||
.take(k)
|
||||
.map(|(_, idx)| *idx as u32)
|
||||
.collect();
|
||||
|
||||
let device = tensor.device();
|
||||
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
|
||||
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
|
||||
|
||||
Ok((top_values, top_indices))
|
||||
}
|
||||
|
||||
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Ok(Tensor::from_slice(&mask, (size, size), device)?)
|
||||
}
|
||||
|
||||
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
|
||||
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
|
||||
let seq_len = tokens.get_attention_mask().to_vec().len();
|
||||
|
||||
let mask_token_id = tokenizer
|
||||
.token_to_id("[MASK]")
|
||||
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||
|
||||
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
|
||||
|
||||
let ids = tokens.get_ids();
|
||||
for _ in 0..seq_len {
|
||||
for id in ids.iter() {
|
||||
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
|
||||
attention_mask_vec.push(mask_value);
|
||||
}
|
||||
}
|
||||
|
||||
let shape = (1, 1, seq_len, seq_len);
|
||||
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
|
||||
|
||||
Ok(mask)
|
||||
}
|
||||
|
@ -124,6 +124,17 @@ impl TextGeneration {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
println!(
|
||||
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||
);
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
@ -146,7 +157,7 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
@ -350,6 +361,31 @@ fn main() -> Result<()> {
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
|
||||
let prompt = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B
|
||||
| Which::BaseV2_2B
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B => args.prompt,
|
||||
Which::InstructV3_1B => {
|
||||
format!(
|
||||
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
||||
args.prompt
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
pipeline.run(&prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -7,7 +7,10 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::helium::{Config, Model};
|
||||
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
|
||||
use candle_transformers::models::llama::{
|
||||
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
|
||||
};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Model {
|
||||
V1 { model: ModelV1, cache: CacheV1 },
|
||||
Preview(ModelPreview),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
|
||||
let model = match self {
|
||||
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
|
||||
Model::Preview(m) => m.forward(input, start_pos)?,
|
||||
};
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Config {
|
||||
V1(ConfigV1),
|
||||
Preview(ConfigPreview),
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn bos_token_id(&self) -> Option<u32> {
|
||||
match self {
|
||||
Config::V1(c) => c.bos_token_id,
|
||||
Config::Preview(c) => Some(c.bos_token_id),
|
||||
}
|
||||
}
|
||||
|
||||
fn eos_token_id(&self) -> Option<LlamaEosToks> {
|
||||
match self {
|
||||
Config::V1(c) => c.eos_token_id.clone(),
|
||||
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
@ -46,7 +87,7 @@ impl TextGeneration {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (top_k, top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(None, None) => Sampling::GumbelSoftmax { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
@ -106,7 +147,15 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
|
||||
let is_eos = self
|
||||
.config
|
||||
.eos_token_id()
|
||||
.as_ref()
|
||||
.is_some_and(|v| match v {
|
||||
LlamaEosToks::Single(eos) => *eos == next_token,
|
||||
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
|
||||
});
|
||||
if Some(next_token) == self.config.bos_token_id() || is_eos {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
@ -131,6 +180,8 @@ impl TextGeneration {
|
||||
enum Which {
|
||||
#[value(name = "v1-preview")]
|
||||
V1Preview,
|
||||
#[value(name = "v1")]
|
||||
V1,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -144,9 +195,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
@ -171,7 +219,7 @@ struct Args {
|
||||
sample_len: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "v1-preview")]
|
||||
#[arg(long, default_value = "v1")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
@ -230,6 +278,7 @@ fn main() -> Result<()> {
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::V1Preview => "kyutai/helium-1-preview-2b",
|
||||
Which::V1 => "kyutai/helium-1-2b",
|
||||
};
|
||||
name.to_string()
|
||||
}
|
||||
@ -254,18 +303,27 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = match args.config {
|
||||
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||
None => {
|
||||
let config_file = repo.get("config.json")?;
|
||||
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||
}
|
||||
let config_file = match args.config {
|
||||
Some(config_file) => std::path::PathBuf::from(config_file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let config = match args.which {
|
||||
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
|
||||
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = {
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
let model = match &config {
|
||||
Config::V1(c) => {
|
||||
let c = c.clone().into_config(false);
|
||||
let model = ModelV1::load(vb, &c)?;
|
||||
let cache = CacheV1::new(true, dtype, &c, &device)?;
|
||||
Model::V1 { model, cache }
|
||||
}
|
||||
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
|
||||
};
|
||||
(model, device)
|
||||
};
|
||||
|
||||
|
@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let tokenizer = common_args.tokenizer()?;
|
||||
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
if let candle::Device::Cuda(d) = &device {
|
||||
unsafe {
|
||||
d.disable_event_tracking();
|
||||
}
|
||||
};
|
||||
|
||||
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
|
||||
let is_safetensors = config_path
|
||||
|
@ -21,7 +21,7 @@ impl Config {
|
||||
}
|
||||
|
||||
fn dt_rank(&self) -> usize {
|
||||
(self.d_model + 15) / 16
|
||||
self.d_model.div_ceil(16)
|
||||
}
|
||||
|
||||
fn d_conv(&self) -> usize {
|
||||
|
@ -3,7 +3,7 @@
|
||||
OLMo is a series of Open Language Models designed to enable the science of language models.
|
||||
|
||||
- **Project Page:** https://allenai.org/olmo
|
||||
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
|
||||
- **Papers:** [OLMo](https://arxiv.org/abs/2402.00838) [OLMo 2](https://arxiv.org/abs/2501.00656)
|
||||
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
|
||||
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
|
||||
<!-- - **Press release:** TODO -->
|
||||
|
@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
||||
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
OLMo(OLMo),
|
||||
OLMo2(OLMo2),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
@ -82,6 +84,7 @@ impl TextGeneration {
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
||||
Model::OLMo2(m) => m.forward(&input, start_pos)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
@ -129,6 +132,8 @@ enum Which {
|
||||
W7bTwin2T,
|
||||
#[value(name = "1.7-7b")]
|
||||
V1_7W7b,
|
||||
#[value(name = "2-1b")]
|
||||
V2W1b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -220,6 +225,7 @@ fn main() -> Result<()> {
|
||||
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
||||
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
||||
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
||||
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
@ -238,33 +244,36 @@ fn main() -> Result<()> {
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.model {
|
||||
Which::W1b => {
|
||||
Which::W1b | Which::V2W1b => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
},
|
||||
};
|
||||
|
||||
let config_filename = repo.get("config.json")?;
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
config
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = {
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = OLMo::new(&config, vb)?;
|
||||
Model::OLMo(model)
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = match args.model {
|
||||
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let model = OLMo::new(&config, vb)?;
|
||||
Model::OLMo(model)
|
||||
}
|
||||
Which::V2W1b => {
|
||||
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let model = OLMo2::new(&config, vb)?;
|
||||
Model::OLMo2(model)
|
||||
}
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
11
candle-examples/examples/onnx-llm/README.md
Normal file
11
candle-examples/examples/onnx-llm/README.md
Normal file
@ -0,0 +1,11 @@
|
||||
## Using ONNX models in Candle
|
||||
|
||||
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based LLM models in Candle.
|
||||
|
||||
This script only implements SmolLM-135M right now.
|
||||
|
||||
You can run the examples with following commands:
|
||||
|
||||
```bash
|
||||
cargo run --example onnx-llm --features onnx
|
||||
```
|
209
candle-examples/examples/onnx-llm/main.rs
Normal file
209
candle-examples/examples/onnx-llm/main.rs
Normal file
@ -0,0 +1,209 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Tensor};
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
use serde::Deserialize;
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_attention_heads: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
SmolLM135M,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
/// The prompt to be used.
|
||||
#[arg(long, default_value = "My favorite theorem is ")]
|
||||
prompt: String,
|
||||
|
||||
/// The model to be used.
|
||||
#[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The number of tokens to generate.
|
||||
#[arg(long, default_value_t = 100)]
|
||||
max_tokens: usize,
|
||||
|
||||
/// The temperature used for sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f32,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let (model_id, tokenizer_id) = match args.which {
|
||||
Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
|
||||
};
|
||||
|
||||
let api = Api::new()?;
|
||||
let model_repo = api.model(model_id.to_string());
|
||||
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
||||
|
||||
let model_path = model_repo.get("onnx/model.onnx")?;
|
||||
let config_file = model_repo.get("config.json")?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
|
||||
|
||||
let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||
|
||||
let tokens_u32 = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
.map_err(anyhow::Error::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();
|
||||
|
||||
println!("Loading ONNX model from {:?}", model_path);
|
||||
let model = candle_onnx::read_file(model_path)?;
|
||||
|
||||
let mut generated_tokens = tokens.clone();
|
||||
print!("{}", args.prompt);
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature as f64;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
|
||||
let num_layers = config.num_hidden_layers;
|
||||
|
||||
for _ in 0..args.max_tokens {
|
||||
let mut inputs = std::collections::HashMap::new();
|
||||
|
||||
if let Some(past_kv) = &past_key_values {
|
||||
let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
|
||||
let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
|
||||
inputs.insert("input_ids".to_string(), input_tensor);
|
||||
|
||||
let seq_len = generated_tokens.len();
|
||||
let attention_mask = vec![vec![1i64; seq_len]];
|
||||
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
|
||||
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
|
||||
|
||||
let position_ids = vec![vec![(seq_len - 1) as i64]];
|
||||
let position_ids_tensor = Tensor::new(position_ids, &device)?;
|
||||
inputs.insert("position_ids".to_string(), position_ids_tensor);
|
||||
|
||||
for (i, (key, value)) in past_kv.iter().enumerate() {
|
||||
inputs.insert(format!("past_key_values.{}.key", i), key.clone());
|
||||
inputs.insert(format!("past_key_values.{}.value", i), value.clone());
|
||||
}
|
||||
} else {
|
||||
let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
|
||||
inputs.insert("input_ids".to_string(), input_tensor);
|
||||
|
||||
let seq_len = generated_tokens.len();
|
||||
let attention_mask = vec![vec![1i64; seq_len]];
|
||||
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
|
||||
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
|
||||
|
||||
let position_ids: Vec<i64> = (0..seq_len as i64).collect();
|
||||
let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
|
||||
inputs.insert("position_ids".to_string(), position_ids_tensor);
|
||||
|
||||
// Create empty key and value tensors
|
||||
for i in 0..num_layers {
|
||||
let batch_size = 1;
|
||||
let num_heads = config.num_key_value_heads;
|
||||
let head_dim = config.hidden_size / config.num_attention_heads;
|
||||
let seq_len = 0;
|
||||
|
||||
let empty_key = Tensor::zeros(
|
||||
&[batch_size, num_heads, seq_len, head_dim],
|
||||
DType::F32,
|
||||
&device,
|
||||
)?;
|
||||
let empty_value = Tensor::zeros(
|
||||
&[batch_size, num_heads, seq_len, head_dim],
|
||||
DType::F32,
|
||||
&device,
|
||||
)?;
|
||||
|
||||
inputs.insert(format!("past_key_values.{}.key", i), empty_key);
|
||||
inputs.insert(format!("past_key_values.{}.value", i), empty_value);
|
||||
}
|
||||
}
|
||||
|
||||
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
|
||||
let logits = outputs.get("logits").unwrap();
|
||||
|
||||
let mut new_past_kv = Vec::with_capacity(num_layers);
|
||||
for i in 0..num_layers {
|
||||
let key = outputs
|
||||
.get(&format!("present.{}.key", i))
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
|
||||
let value = outputs
|
||||
.get(&format!("present.{}.value", i))
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
|
||||
new_past_kv.push((key.clone(), value.clone()));
|
||||
}
|
||||
past_key_values = Some(new_past_kv);
|
||||
|
||||
let logits_dim = logits.dims();
|
||||
let seq_len = logits_dim[1];
|
||||
|
||||
let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
|
||||
generated_tokens.push(next_token_id as i64);
|
||||
|
||||
if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
|
||||
print!("{}", token_str);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
|
||||
if next_token_id == eos_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nGeneration complete!");
|
||||
Ok(())
|
||||
}
|
@ -5,12 +5,14 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{IndexOp, D};
|
||||
use candle_examples::save_image;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
SqueezeNet,
|
||||
EfficientNet,
|
||||
EsrGan,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
@ -28,10 +30,21 @@ struct Args {
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
let image = match args.which {
|
||||
Which::SqueezeNet | Which::EfficientNet => {
|
||||
candle_examples::imagenet::load_image224(&args.image)?
|
||||
}
|
||||
Which::EsrGan => candle_examples::imagenet::load_image_with_std_mean(
|
||||
&args.image,
|
||||
128,
|
||||
&[0.0f32, 0.0, 0.0],
|
||||
&[1.0f32, 1.0, 1.0],
|
||||
)?,
|
||||
};
|
||||
let image = match args.which {
|
||||
Which::SqueezeNet => image,
|
||||
Which::EfficientNet => image.permute((1, 2, 0))?,
|
||||
Which::EsrGan => image,
|
||||
};
|
||||
|
||||
println!("loaded image {image:?}");
|
||||
@ -45,6 +58,9 @@ pub fn main() -> anyhow::Result<()> {
|
||||
Which::EfficientNet => hf_hub::api::sync::Api::new()?
|
||||
.model("onnx/EfficientNet-Lite4".into())
|
||||
.get("efficientnet-lite4-11.onnx")?,
|
||||
Which::EsrGan => hf_hub::api::sync::Api::new()?
|
||||
.model("qualcomm/Real-ESRGAN-x4plus".into())
|
||||
.get("Real-ESRGAN-x4plus.onnx")?,
|
||||
},
|
||||
};
|
||||
|
||||
@ -57,21 +73,40 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let prs = match args.which {
|
||||
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
|
||||
Which::EfficientNet => output,
|
||||
Which::EsrGan => output,
|
||||
};
|
||||
let prs = prs.i(0)?.to_vec1::<f32>()?;
|
||||
|
||||
// Sort the predictions and take the top 5
|
||||
let mut top: Vec<_> = prs.iter().enumerate().collect();
|
||||
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||
let top = top.into_iter().take(5).collect::<Vec<_>>();
|
||||
match args.which {
|
||||
Which::EfficientNet | Which::SqueezeNet => {
|
||||
let prs = prs.i(0)?.to_vec1::<f32>()?;
|
||||
|
||||
// Print the top predictions
|
||||
for &(i, p) in &top {
|
||||
println!(
|
||||
"{:50}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[i],
|
||||
p * 100.0
|
||||
);
|
||||
// Sort the predictions and take the top 5
|
||||
let mut top: Vec<_> = prs.iter().enumerate().collect();
|
||||
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||
let top = top.into_iter().take(5).collect::<Vec<_>>();
|
||||
|
||||
// Print the top predictions
|
||||
for &(i, p) in &top {
|
||||
println!(
|
||||
"{:50}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[i],
|
||||
p * 100.0
|
||||
);
|
||||
}
|
||||
}
|
||||
Which::EsrGan => {
|
||||
let max_pixel_val = candle::Tensor::try_from(255.0f32)?
|
||||
.to_device(prs.device())?
|
||||
.broadcast_as(prs.shape())?;
|
||||
let out = (prs * max_pixel_val)?.i(0)?.to_dtype(candle::DType::U8)?;
|
||||
|
||||
let pb = std::path::PathBuf::from(args.image);
|
||||
let input_file_name = pb.file_name().unwrap();
|
||||
let mut output_file_name = std::ffi::OsString::from("super_");
|
||||
output_file_name.push(input_file_name);
|
||||
|
||||
save_image(&out, output_file_name)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
14
candle-examples/examples/orpheus/README.md
Normal file
14
candle-examples/examples/orpheus/README.md
Normal file
@ -0,0 +1,14 @@
|
||||
# Orpheus
|
||||
|
||||
Orpheus is a 3B text-to-speech model based on Llama.
|
||||
|
||||
- Weights on HuggingFace
|
||||
[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
|
||||
- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
|
||||
|
||||
|
||||
```bash
|
||||
cargo run --example orpheus --features cuda -r
|
||||
```
|
||||
|
||||
|
329
candle-examples/examples/orpheus/main.rs
Normal file
329
candle-examples/examples/orpheus/main.rs
Normal file
@ -0,0 +1,329 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
|
||||
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
|
||||
const STOP_TOKEN_ID: u32 = 128258;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long, default_value = "Hey, how are you doing today?")]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.6)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
model_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// The output wav file.
|
||||
#[arg(long, default_value = "out.wav")]
|
||||
out_file: String,
|
||||
|
||||
#[arg(long, default_value = "3b-0.1-ft")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long, default_value = "tara")]
|
||||
voice: Voice,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Voice {
|
||||
#[value(name = "tara")]
|
||||
Tara,
|
||||
#[value(name = "leah")]
|
||||
Leah,
|
||||
#[value(name = "jess")]
|
||||
Jess,
|
||||
#[value(name = "leo")]
|
||||
Leo,
|
||||
#[value(name = "dan")]
|
||||
Dan,
|
||||
#[value(name = "mia")]
|
||||
Mia,
|
||||
#[value(name = "zac")]
|
||||
Zac,
|
||||
#[value(name = "zoe")]
|
||||
Zoe,
|
||||
}
|
||||
|
||||
impl Voice {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Voice::Tara => "tara",
|
||||
Voice::Leah => "leah",
|
||||
Voice::Jess => "jess",
|
||||
Voice::Leo => "leo",
|
||||
Voice::Dan => "dan",
|
||||
Voice::Mia => "mia",
|
||||
Voice::Zac => "zac",
|
||||
Voice::Zoe => "zoe",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "3b-0.1-ft")]
|
||||
ThreeB0_1Ft,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
let prompt = args.prompt.clone();
|
||||
let mut model = Model::load(args)?;
|
||||
model.run(&prompt)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct Model {
|
||||
model: Llama,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: candle_transformers::generation::LogitsProcessor,
|
||||
cache: Cache,
|
||||
device: Device,
|
||||
verbose_prompt: bool,
|
||||
snac: SnacModel,
|
||||
out_file: String,
|
||||
voice: Voice,
|
||||
}
|
||||
|
||||
fn load_snac(device: &Device) -> Result<SnacModel> {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
|
||||
let config = m.get("config.json")?;
|
||||
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
|
||||
let m = api.model("lmz/candle-snac".to_string());
|
||||
let model = m.get("snac_24khz.safetensors")?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
|
||||
let model = SnacModel::new(&config, vb)?;
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn load(args: Args) -> Result<Self> {
|
||||
let start = std::time::Instant::now();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => match args.which {
|
||||
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
|
||||
},
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(r) => r,
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||
model_id,
|
||||
hf_hub::RepoType::Model,
|
||||
revision,
|
||||
));
|
||||
let model_files = match args.model_file {
|
||||
Some(m) => vec![m.into()],
|
||||
None => match args.which {
|
||||
Which::ThreeB0_1Ft => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
};
|
||||
let config = match args.config_file {
|
||||
Some(m) => m.into(),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let tokenizer = match args.tokenizer_file {
|
||||
Some(m) => m.into(),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
|
||||
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
let model = Llama::load(vb, &config)?;
|
||||
let logits_processor = {
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k.as_ref(), args.top_p.as_ref()) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(&k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(&p)) => Sampling::TopP { p, temperature },
|
||||
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
let cache = Cache::new(true, dtype, &config, &device)?;
|
||||
let snac = load_snac(&device)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
cache,
|
||||
device,
|
||||
verbose_prompt: args.verbose_prompt,
|
||||
snac,
|
||||
voice: args.voice,
|
||||
out_file: args.out_file,
|
||||
})
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str) -> Result<()> {
|
||||
println!("running the model on '{}'", prompt);
|
||||
let device = &self.device;
|
||||
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
|
||||
let mut tokens = [
|
||||
&[128259],
|
||||
tokens.get_ids(),
|
||||
&[128009, 128260, 128261, 128257],
|
||||
]
|
||||
.concat();
|
||||
if self.verbose_prompt {
|
||||
println!("{:?}", tokens);
|
||||
}
|
||||
let mut cache = self.cache.clone();
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut index_pos = 0;
|
||||
let mut audio_tokens = vec![];
|
||||
for index in 0..2000 {
|
||||
let (context_size, context_index) = if index > 0 {
|
||||
(1, index_pos)
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, context_index, &mut cache)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
|
||||
match tok.strip_prefix("<custom_token_") {
|
||||
Some(tok) => match tok.strip_suffix('>') {
|
||||
Some(tok) => {
|
||||
let tok = tok.parse::<u32>()?;
|
||||
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
|
||||
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
|
||||
audio_tokens.push(tok);
|
||||
}
|
||||
None => {
|
||||
println!("{index}: unexpected custom token {next_token} {tok}");
|
||||
}
|
||||
},
|
||||
None => {
|
||||
println!("{index}: unexpected token {next_token} {tok}");
|
||||
}
|
||||
}
|
||||
}
|
||||
if next_token == STOP_TOKEN_ID {
|
||||
println!("reached stop token");
|
||||
break;
|
||||
}
|
||||
tokens.push(next_token);
|
||||
}
|
||||
println!("generated {} audio tokens", audio_tokens.len());
|
||||
let mut codes0 = vec![];
|
||||
let mut codes1 = vec![];
|
||||
let mut codes2 = vec![];
|
||||
for audio_tokens in audio_tokens.chunks_exact(7) {
|
||||
codes0.push(audio_tokens[0]);
|
||||
for i in [1, 4] {
|
||||
codes1.push(audio_tokens[i]);
|
||||
}
|
||||
for i in [2, 3, 5, 6] {
|
||||
codes2.push(audio_tokens[i]);
|
||||
}
|
||||
}
|
||||
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
|
||||
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
|
||||
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
|
||||
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
|
||||
println!("decoded to pcm {pcm:?}");
|
||||
let mut output = std::fs::File::create(&self.out_file)?;
|
||||
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -147,9 +147,9 @@ enum WhichModel {
|
||||
V3,
|
||||
#[value(name = "3-medium")]
|
||||
V3Medium,
|
||||
#[value(name = "2-old")]
|
||||
V4Mini,
|
||||
#[value(name = "4-mini")]
|
||||
V4Mini,
|
||||
#[value(name = "2-old")]
|
||||
V2Old,
|
||||
PuffinPhiV2,
|
||||
PhiHermes,
|
||||
|
18
candle-examples/examples/quantized-gemma/README.md
Normal file
18
candle-examples/examples/quantized-gemma/README.md
Normal file
@ -0,0 +1,18 @@
|
||||
# candle-quantized-gemma
|
||||
|
||||
Candle implementation of quantized Gemma.
|
||||
|
||||
## Running an example
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
|
||||
|
||||
> ```python
|
||||
> def fibonacci(n):
|
||||
> """Calculates the nth Fibonacci number using recursion."""
|
||||
> if n <= 1:
|
||||
> return n
|
||||
> else:
|
||||
> return fibonacci(n-1) + fibonacci(n-2
|
||||
> ```
|
||||
```
|
344
candle-examples/examples/quantized-gemma/main.rs
Normal file
344
candle-examples/examples/quantized-gemma/main.rs
Normal file
@ -0,0 +1,344 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::Tensor;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_transformers::models::quantized_gemma3::ModelWeights;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "gemma3-4b-it")]
|
||||
Gemma3_4bIt,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGUF file to load, typically a .gguf file generated by quantization
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||
/// is preserved.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The tokenizer config in json format.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Process prompt elements separately.
|
||||
#[arg(long)]
|
||||
split_prompt: bool,
|
||||
|
||||
/// Run on CPU rather than GPU even if a GPU is available.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "gemma3-4b-it")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer_path = match &self.tokenizer {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let repo = "google/gemma-3-4b-it";
|
||||
println!("DEBUG: Downloading tokenizer from {}", repo);
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||
|
||||
Ok(tokenizer)
|
||||
}
|
||||
|
||||
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||
let model_path = match &self.model {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let (repo, filename) = match self.which {
|
||||
Which::Gemma3_4bIt => (
|
||||
"google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||
"gemma-3-4b-it-q4_0.gguf",
|
||||
),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"main".to_string(),
|
||||
))
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
} else if size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||
} else if size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Prompt {
|
||||
Interactive,
|
||||
Chat,
|
||||
One(String),
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let mut model = {
|
||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
println!(
|
||||
"DEBUG: Tokenizer vocabulary size: {}",
|
||||
tos.tokenizer().get_vocab(true).len()
|
||||
);
|
||||
|
||||
let prompt = match args.prompt.as_deref() {
|
||||
Some("chat") => Prompt::Chat,
|
||||
Some("interactive") => Prompt::Interactive,
|
||||
Some(s) => Prompt::One(s.to_string()),
|
||||
None => Prompt::One(DEFAULT_PROMPT.to_string()),
|
||||
};
|
||||
|
||||
let mut pre_prompt_tokens = vec![];
|
||||
for _ in 0.. {
|
||||
let prompt_str = match &prompt {
|
||||
Prompt::One(prompt) => prompt.clone(),
|
||||
Prompt::Interactive | Prompt::Chat => {
|
||||
print!("> ");
|
||||
std::io::stdout().flush()?;
|
||||
let mut prompt = String::new();
|
||||
std::io::stdin().read_line(&mut prompt)?;
|
||||
if prompt.ends_with('\n') {
|
||||
prompt.pop();
|
||||
if prompt.ends_with('\r') {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
// Format for Gemma 3 chat/instruction format
|
||||
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
|
||||
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
let max_seq_len = 8192; // Gemma 3 context length
|
||||
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
|
||||
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
|
||||
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
|
||||
} else {
|
||||
prompt_tokens
|
||||
};
|
||||
let mut all_tokens = vec![];
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in prompt_tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
// For Gemma 3, use the correct end of sequence token
|
||||
let eos_token = *tos
|
||||
.tokenizer()
|
||||
.get_vocab(true)
|
||||
.get("<end_of_turn>")
|
||||
.unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
prompt_tokens.len(),
|
||||
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
|
||||
match prompt {
|
||||
Prompt::One(_) => break,
|
||||
Prompt::Interactive => {}
|
||||
Prompt::Chat => {
|
||||
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -8,4 +8,8 @@
|
||||
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
|
||||
```
|
||||
|
||||
0.5b, 1.5b, 7b and 72b models are available via `--model` argument.
|
||||
0.5b, 1.5b, 7b and 72b models are available via `--which` argument.
|
||||
|
||||
```bash
|
||||
cargo run --release --example quantized-qwen2-instruct -- --which 0.5b --prompt "Write a function to count prime numbers up to N."
|
||||
```
|
||||
|
17
candle-examples/examples/quantized-qwen3/README.md
Normal file
17
candle-examples/examples/quantized-qwen3/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
# candle-quantized-qwen3
|
||||
|
||||
[Qwen3]((https://qwenlm.github.io/blog/qwen3/)) is an upgraded version of Qwen2.5, released by Alibaba Cloud.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
cargo run --example quantized-qwen3 --release -- --prompt "Write a function to count prime numbers up to N."
|
||||
```
|
||||
|
||||
|
||||
0.6b is used by default, 1.7b, 4b, 8b, 14b, and 32b models are available via `--which` argument.
|
||||
|
||||
```bash
|
||||
cargo run --example quantized-qwen3 --release -- --which 4b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
|
||||
```
|
||||
|
314
candle-examples/examples/quantized-qwen3/main.rs
Normal file
314
candle-examples/examples/quantized-qwen3/main.rs
Normal file
@ -0,0 +1,314 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::Tensor;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_transformers::models::quantized_qwen3::ModelWeights as Qwen3;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "0.6b")]
|
||||
W3_0_6b,
|
||||
#[value(name = "1.7b")]
|
||||
W3_1_7b,
|
||||
#[value(name = "4b")]
|
||||
W3_4b,
|
||||
#[value(name = "8b")]
|
||||
W3_8b,
|
||||
#[value(name = "14b")]
|
||||
W3_14b,
|
||||
#[value(name = "32b")]
|
||||
W3_32b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||
/// is preserved.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The tokenizer config in json format.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Process prompt elements separately.
|
||||
#[arg(long)]
|
||||
split_prompt: bool,
|
||||
|
||||
/// Run on CPU rather than GPU even if a GPU is available.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "0.6b")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer_path = match &self.tokenizer {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let repo = match self.which {
|
||||
Which::W3_0_6b => "Qwen/Qwen3-0.6B",
|
||||
Which::W3_1_7b => "Qwen/Qwen3-1.7B",
|
||||
Which::W3_4b => "Qwen/Qwen3-4B",
|
||||
Which::W3_8b => "Qwen/Qwen3-8B",
|
||||
Which::W3_14b => "Qwen/Qwen3-14B",
|
||||
Which::W3_32b => "Qwen/Qwen3-32B",
|
||||
};
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
|
||||
}
|
||||
|
||||
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||
let model_path = match &self.model {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let (repo, filename, revision) = match self.which {
|
||||
Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
|
||||
Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
|
||||
Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
|
||||
Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),
|
||||
Which::W3_14b => ("unsloth/Qwen3-14B-GGUF", "Qwen3-14B-Q4_K_M.gguf", "main"),
|
||||
Which::W3_32b => ("unsloth/Qwen3-32B-GGUF", "Qwen3-32B-Q4_K_M.gguf", "main"),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
} else if size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||
} else if size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let mut model = {
|
||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
Qwen3::from_gguf(model, &mut file, &device)?
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
let prompt_str = args
|
||||
.prompt
|
||||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||
|
||||
let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
|
||||
print!("formatted prompt: {}", &prompt_str);
|
||||
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
|
||||
let tokens = tokens.get_ids();
|
||||
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
|
||||
let mut all_tokens = vec![];
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
|
||||
all_tokens.push(next_token);
|
||||
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
tokens.len(),
|
||||
tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
|
||||
print(i)
|
||||
```
|
||||
|
||||
The qwen3 MoE variant is also an option.
|
||||
|
||||
```bash
|
||||
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
|
||||
> In morning's hush, where daisies sleep,
|
||||
> A fleeting dance through sunlit deep—
|
||||
> They flutter soft on gossamer thread,
|
||||
> The messengers of spring’s own head.
|
||||
>
|
||||
> With painted sails and delicate grace,
|
||||
> They drift from bloom to blossom's face.
|
||||
> Each wing a tale in hues unseen,
|
||||
> Of ancient dreams and secrets between.
|
||||
>
|
||||
> No sound they make, yet still they speak—
|
||||
> Of time that flies, of life so brief.
|
||||
> A fleeting kiss on summer’s breath,
|
||||
> A whisper lost before death.
|
||||
>
|
||||
> Yet in their flight, the soul takes wing,
|
||||
> And for a moment, all is spring.
|
||||
> For though they fade, they never die—
|
||||
> Their beauty lives where hearts can fly.
|
||||
> 161 tokens generated (3.00 token/s)
|
||||
```
|
||||
|
@ -9,6 +9,8 @@ use clap::Parser;
|
||||
|
||||
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
|
||||
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -20,6 +22,8 @@ use tokenizers::Tokenizer;
|
||||
enum Model {
|
||||
Base(ModelBase),
|
||||
Moe(ModelMoe),
|
||||
Base3(Model3),
|
||||
Moe3(ModelMoe3),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@ -27,6 +31,8 @@ impl Model {
|
||||
match self {
|
||||
Self::Moe(ref mut m) => m.forward(xs, s),
|
||||
Self::Base(ref mut m) => m.forward(xs, s),
|
||||
Self::Base3(ref mut m) => m.forward(xs, s),
|
||||
Self::Moe3(ref mut m) => m.forward(xs, s),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -85,6 +91,10 @@ impl TextGeneration {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|im_end|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
@ -107,7 +117,7 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
if next_token == eos_token || next_token == eos_token2 {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
@ -152,6 +162,16 @@ enum WhichModel {
|
||||
W2_7b,
|
||||
#[value(name = "2-72b")]
|
||||
W2_72b,
|
||||
#[value(name = "3-0.6b")]
|
||||
W3_0_6b,
|
||||
#[value(name = "3-1.7b")]
|
||||
W3_1_7b,
|
||||
#[value(name = "3-4b")]
|
||||
W3_4b,
|
||||
#[value(name = "3-8b")]
|
||||
W3_8b,
|
||||
#[value(name = "3-moe-a3b")]
|
||||
W3MoeA3b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -254,6 +274,11 @@ fn main() -> Result<()> {
|
||||
WhichModel::W14b => ("1.5", "14B"),
|
||||
WhichModel::W72b => ("1.5", "72B"),
|
||||
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
|
||||
WhichModel::W3_0_6b => ("3", "0.6B"),
|
||||
WhichModel::W3_1_7b => ("3", "1.7B"),
|
||||
WhichModel::W3_4b => ("3", "4B"),
|
||||
WhichModel::W3_8b => ("3", "8B"),
|
||||
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
|
||||
};
|
||||
format!("Qwen/Qwen{version}-{size}")
|
||||
}
|
||||
@ -273,7 +298,11 @@ fn main() -> Result<()> {
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.model {
|
||||
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
|
||||
WhichModel::W0_5b
|
||||
| WhichModel::W2_0_5b
|
||||
| WhichModel::W2_1_5b
|
||||
| WhichModel::W1_8b
|
||||
| WhichModel::W3_0_6b => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
WhichModel::W4b
|
||||
@ -282,7 +311,11 @@ fn main() -> Result<()> {
|
||||
| WhichModel::W14b
|
||||
| WhichModel::W72b
|
||||
| WhichModel::W2_72b
|
||||
| WhichModel::MoeA27b => {
|
||||
| WhichModel::MoeA27b
|
||||
| WhichModel::W3_1_7b
|
||||
| WhichModel::W3_4b
|
||||
| WhichModel::W3_8b
|
||||
| WhichModel::W3MoeA3b => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
@ -304,6 +337,14 @@ fn main() -> Result<()> {
|
||||
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Moe(ModelMoe::new(&config, vb)?)
|
||||
}
|
||||
WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
|
||||
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Base3(Model3::new(&config, vb)?)
|
||||
}
|
||||
WhichModel::W3MoeA3b => {
|
||||
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Moe3(ModelMoe3::new(&config, vb)?)
|
||||
}
|
||||
_ => {
|
||||
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
Model::Base(ModelBase::new(&config, vb)?)
|
||||
|
275
candle-examples/examples/snac/audio_io.rs
Normal file
275
candle-examples/examples/snac/audio_io.rs
Normal file
@ -0,0 +1,275 @@
|
||||
use anyhow::{Context, Result};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub const SAMPLE_RATE: usize = 24_000;
|
||||
|
||||
pub(crate) struct AudioOutputData_ {
|
||||
resampled_data: std::collections::VecDeque<f32>,
|
||||
resampler: rubato::FastFixedIn<f32>,
|
||||
output_buffer: Vec<f32>,
|
||||
input_buffer: Vec<f32>,
|
||||
input_len: usize,
|
||||
}
|
||||
|
||||
impl AudioOutputData_ {
|
||||
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
||||
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
||||
let resampler = rubato::FastFixedIn::new(
|
||||
resample_ratio,
|
||||
f64::max(resample_ratio, 1.0),
|
||||
rubato::PolynomialDegree::Septic,
|
||||
1024,
|
||||
1,
|
||||
)?;
|
||||
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
||||
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
||||
Ok(Self {
|
||||
resampled_data,
|
||||
resampler,
|
||||
input_buffer,
|
||||
output_buffer,
|
||||
input_len: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
use rubato::Resampler;
|
||||
self.output_buffer.fill(0.);
|
||||
self.input_buffer.fill(0.);
|
||||
self.resampler.reset();
|
||||
self.resampled_data.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
||||
let mut data = Vec::with_capacity(self.resampled_data.len());
|
||||
while let Some(elem) = self.resampled_data.pop_back() {
|
||||
data.push(elem);
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.resampled_data.is_empty()
|
||||
}
|
||||
|
||||
// Assumes that the input buffer is large enough.
|
||||
fn push_input_buffer(&mut self, samples: &[f32]) {
|
||||
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
||||
self.input_len += samples.len()
|
||||
}
|
||||
|
||||
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pos_in = 0;
|
||||
loop {
|
||||
let rem = self.input_buffer.len() - self.input_len;
|
||||
let pos_end = usize::min(pos_in + rem, samples.len());
|
||||
self.push_input_buffer(&samples[pos_in..pos_end]);
|
||||
pos_in = pos_end;
|
||||
if self.input_len < self.input_buffer.len() {
|
||||
break;
|
||||
}
|
||||
let (_, out_len) = self.resampler.process_into_buffer(
|
||||
&[&self.input_buffer],
|
||||
&mut [&mut self.output_buffer],
|
||||
None,
|
||||
)?;
|
||||
for &elem in self.output_buffer[..out_len].iter() {
|
||||
self.resampled_data.push_front(elem)
|
||||
}
|
||||
self.input_len = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
||||
|
||||
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio output stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_output_device()
|
||||
.context("no output device available")?;
|
||||
let mut supported_configs_range = device.supported_output_configs()?;
|
||||
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
||||
// On macOS, it's commonly the case that there are only stereo outputs.
|
||||
None => device
|
||||
.supported_output_configs()?
|
||||
.next()
|
||||
.context("no audio output available")?,
|
||||
Some(config_range) => config_range,
|
||||
};
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
let channels = config.channels as usize;
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
SAMPLE_RATE,
|
||||
config.sample_rate.0 as usize,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
data.fill(0.);
|
||||
let mut ad = ad.lock().unwrap();
|
||||
let mut last_elem = 0f32;
|
||||
for (idx, elem) in data.iter_mut().enumerate() {
|
||||
if idx % channels == 0 {
|
||||
match ad.resampled_data.pop_back() {
|
||||
None => break,
|
||||
Some(v) => {
|
||||
last_elem = v;
|
||||
*elem = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*elem = last_elem
|
||||
}
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio input stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.context("no input device available")?;
|
||||
let mut supported_configs_range = device.supported_input_configs()?;
|
||||
let config_range = supported_configs_range
|
||||
.find(|c| c.channels() == 1)
|
||||
.context("no audio input available")?;
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
config.sample_rate.0 as usize,
|
||||
SAMPLE_RATE,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mut ad = ad.lock().unwrap();
|
||||
if let Err(err) = ad.push_samples(data) {
|
||||
eprintln!("error processing audio input {err:?}")
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
use symphonia::core::audio::Signal;
|
||||
use symphonia::core::conv::FromSample;
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
|
||||
let src = std::fs::File::open(path)?;
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
let mut format = probed.format;
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &Default::default())
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pcm_out =
|
||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||
|
||||
let mut resampler =
|
||||
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
|
||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||
let mut pos_in = 0;
|
||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||
let (in_len, out_len) =
|
||||
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
||||
pos_in += in_len;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
if pos_in < pcm_in.len() {
|
||||
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
||||
Some(&[&pcm_in[pos_in..]]),
|
||||
&mut output_buffer,
|
||||
None,
|
||||
)?;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
Ok(pcm_out)
|
||||
}
|
197
candle-examples/examples/snac/main.rs
Normal file
197
candle-examples/examples/snac/main.rs
Normal file
@ -0,0 +1,197 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::snac::{Config, Model};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
mod audio_io;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Action {
|
||||
AudioToAudio,
|
||||
AudioToCode,
|
||||
CodeToAudio,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "24khz")]
|
||||
S24khz,
|
||||
#[value(name = "32khz")]
|
||||
S32khz,
|
||||
#[value(name = "44khz")]
|
||||
S44khz,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn sample_rate(&self) -> u32 {
|
||||
match self {
|
||||
Which::S24khz => 24000,
|
||||
Which::S32khz => 32000,
|
||||
Which::S44khz => 44000,
|
||||
}
|
||||
}
|
||||
|
||||
fn config_repo(&self) -> &'static str {
|
||||
match self {
|
||||
Which::S24khz => "hubertsiuzdak/snac_24khz",
|
||||
Which::S32khz => "hubertsiuzdak/snac_32khz",
|
||||
Which::S44khz => "hubertsiuzdak/snac_44khz",
|
||||
}
|
||||
}
|
||||
|
||||
fn model_file(&self) -> &'static str {
|
||||
match self {
|
||||
Which::S24khz => "snac_24khz.safetensors",
|
||||
Which::S32khz => "snac_32khz.safetensors",
|
||||
Which::S44khz => "snac_44khz.safetensors",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The action to be performed, specifies the format for the input and output data.
|
||||
action: Action,
|
||||
|
||||
/// The input file, either an audio file or some snac tokens stored as safetensors.
|
||||
in_file: String,
|
||||
|
||||
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
||||
out_file: String,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "24khz")]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The config file, in safetensor format.
|
||||
#[arg(long)]
|
||||
config: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model_sample_rate = args.which.sample_rate();
|
||||
let config = match args.config {
|
||||
Some(c) => std::path::PathBuf::from(c),
|
||||
None => Api::new()?
|
||||
.model(args.which.config_repo().to_string())
|
||||
.get("config.json")?,
|
||||
};
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.model("lmz/candle-snac".to_string())
|
||||
.get(args.which.model_file())?,
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
let num_codebooks = model.num_codebooks();
|
||||
(0..num_codebooks)
|
||||
.map(|i| {
|
||||
codes
|
||||
.get(&format!("codes-{i}"))
|
||||
.expect("no codes in input file")
|
||||
.clone()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let pcm = if args.in_file == "-" {
|
||||
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
||||
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
||||
let mut pcms = vec![];
|
||||
let stdin = std::thread::spawn(|| {
|
||||
let mut s = String::new();
|
||||
std::io::stdin().read_line(&mut s)
|
||||
});
|
||||
while !stdin.is_finished() {
|
||||
let input = input_audio.lock().unwrap().take_all();
|
||||
if input.is_empty() {
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
pcms.push(input)
|
||||
}
|
||||
drop(stream);
|
||||
pcms.concat()
|
||||
} else {
|
||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||
if sample_rate != model_sample_rate {
|
||||
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
|
||||
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
|
||||
} else {
|
||||
pcm
|
||||
}
|
||||
};
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
};
|
||||
for codes in codes.iter() {
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
}
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (i, codes) in codes.iter().enumerate() {
|
||||
tensors.insert(format!("codes-{i}"), codes.clone());
|
||||
}
|
||||
candle::safetensors::save(&tensors, "codes.safetensors")?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let codes = codes.iter().collect::<Vec<_>>();
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
if args.out_file == "-" {
|
||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||
{
|
||||
let mut ad = ad.lock().unwrap();
|
||||
ad.push_samples(&pcm)?;
|
||||
}
|
||||
loop {
|
||||
let ad = ad.lock().unwrap();
|
||||
if ad.is_empty() {
|
||||
break;
|
||||
}
|
||||
// That's very weird, calling thread::sleep here triggers the stream to stop
|
||||
// playing (the callback doesn't seem to be called anymore).
|
||||
// std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
}
|
||||
drop(stream)
|
||||
} else {
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -28,3 +28,26 @@ Ranking Results:
|
||||
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
|
||||
--------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
Text-Classification:
|
||||
```bash
|
||||
cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
|
||||
```
|
||||
```markdown
|
||||
Formality Scores:
|
||||
Text 1: "I like you. I love you"
|
||||
formal: 0.9933
|
||||
informal: 0.0067
|
||||
|
||||
Text 2: "Hey, what's up?"
|
||||
formal: 0.8812
|
||||
informal: 0.1188
|
||||
|
||||
Text 3: "Siema, co porabiasz?"
|
||||
formal: 0.9358
|
||||
informal: 0.0642
|
||||
|
||||
Text 4: "I feel deep regret and sadness about the situation in international politics."
|
||||
formal: 0.9987
|
||||
informal: 0.0013
|
||||
```
|
@ -2,6 +2,7 @@ use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::ops::softmax;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::xlm_roberta::{
|
||||
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
|
||||
@ -17,12 +18,14 @@ enum Model {
|
||||
BgeRerankerBaseV2,
|
||||
XLMRobertaBase,
|
||||
XLMRobertaLarge,
|
||||
XLMRFormalityClassifier,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
enum Task {
|
||||
FillMask,
|
||||
Reranker,
|
||||
TextClassification,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -83,6 +86,12 @@ fn main() -> Result<()> {
|
||||
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
|
||||
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
|
||||
},
|
||||
Task::TextClassification => match args.model {
|
||||
Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
|
||||
_ => anyhow::bail!(
|
||||
"XLM-RoBERTa models are not supported for text classification task"
|
||||
),
|
||||
},
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
@ -217,6 +226,36 @@ fn main() -> Result<()> {
|
||||
});
|
||||
println!("{:-<80}", "");
|
||||
}
|
||||
Task::TextClassification => {
|
||||
let sentences = vec![
|
||||
"I like you. I love you".to_string(),
|
||||
"Hey, what's up?".to_string(),
|
||||
"Siema, co porabiasz?".to_string(),
|
||||
"I feel deep regret and sadness about the situation in international politics."
|
||||
.to_string(),
|
||||
];
|
||||
let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
|
||||
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
|
||||
|
||||
let attention_mask =
|
||||
get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
|
||||
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
|
||||
|
||||
let logits = model
|
||||
.forward(&input_ids, &attention_mask, &token_type_ids)?
|
||||
.to_dtype(candle::DType::F32)?;
|
||||
|
||||
let probabilities = softmax(&logits, 1)?;
|
||||
let probs_vec = probabilities.to_vec2::<f32>()?;
|
||||
|
||||
println!("Formality Scores:");
|
||||
for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
|
||||
println!("Text {}: \"{}\"", i + 1, text);
|
||||
println!(" formal: {:.4}", probs[0]);
|
||||
println!(" informal: {:.4}", probs[1]);
|
||||
println!();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
||||
padding,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = if bias {
|
||||
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||
|
@ -92,6 +92,7 @@ impl ConvBlock {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.9.0-alpha.1"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,14 +11,17 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.1" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
bindgen_cuda = "0.1.1"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cudnn = ["candle/cudnn"]
|
||||
|
@ -2,7 +2,6 @@ mod ffi;
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle::cuda_backend::WrapErr;
|
||||
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
|
||||
@ -142,10 +141,8 @@ impl FlashAttn {
|
||||
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
let elem_count = out_shape.elem_count();
|
||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||
let softmax_lse = dev
|
||||
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
|
||||
.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(elem_count)? };
|
||||
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
|
||||
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
@ -607,8 +604,8 @@ impl FlashAttnVarLen {
|
||||
let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
|
||||
|
||||
let elem_count = out_shape.elem_count();
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
|
||||
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
|
||||
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.9.0-alpha.1"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -53,7 +53,7 @@ __device__ void conv1d(
|
||||
|
||||
template <typename T>
|
||||
__device__ void im2col1d(
|
||||
const size_t dst_numel,
|
||||
const size_t numel,
|
||||
const size_t l_out,
|
||||
const size_t l_k,
|
||||
const size_t stride,
|
||||
@ -63,10 +63,10 @@ __device__ void im2col1d(
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// dst: (b_size, l_out, c_in, l_k)
|
||||
// src: (b_size, c_in, l_in)
|
||||
if (dst_i >= dst_numel) {
|
||||
if (thread_i >= numel) {
|
||||
return;
|
||||
}
|
||||
const size_t *src_dims = info;
|
||||
@ -74,26 +74,26 @@ __device__ void im2col1d(
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
|
||||
const size_t dst_s2 = l_k;
|
||||
const size_t dst_s1 = c_in * dst_s2;
|
||||
const size_t dst_s1 = c_in;
|
||||
const size_t dst_s0 = l_out * dst_s1;
|
||||
|
||||
size_t tmp_dst_i = dst_i;
|
||||
size_t tmp_dst_i = thread_i;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t l_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= l_idx * dst_s1;
|
||||
const size_t c_idx = tmp_dst_i / dst_s2;
|
||||
tmp_dst_i -= c_idx * dst_s2;
|
||||
const size_t l_k_idx = tmp_dst_i;
|
||||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_l_idx -= padding;
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
||||
dst[dst_i] = src[src_i];
|
||||
const size_t c_idx = tmp_dst_i;
|
||||
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
|
||||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||
size_t dst_i = thread_i * l_k + l_k_idx;
|
||||
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
}
|
||||
else {
|
||||
src_l_idx -= padding;
|
||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
||||
dst[dst_i] = src[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include<stdint.h>
|
||||
#include "cuda_fp16.h"
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
template<typename T>
|
||||
__device__ void fill_with(T *buf, T value, const size_t numel) {
|
||||
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
|
||||
COPY2D_OP(uint32_t, copy2d_u32)
|
||||
COPY2D_OP(int64_t, copy2d_i64)
|
||||
|
||||
#define CONST_SET_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME inp, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
out[i] = inp; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
out[strided_i] = inp; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
CONST_SET_OP(float, const_set_f32)
|
||||
CONST_SET_OP(double, const_set_f64)
|
||||
CONST_SET_OP(uint8_t, const_set_u8)
|
||||
CONST_SET_OP(uint32_t, const_set_u32)
|
||||
CONST_SET_OP(int64_t, const_set_i64)
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
COPY2D_OP(__half, copy2d_f16)
|
||||
CONST_SET_OP(__half, const_set_f16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
|
||||
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
|
||||
#endif
|
||||
|
@ -3,6 +3,28 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__
|
||||
constexpr T max_value();
|
||||
|
||||
template <>
|
||||
__host__ __device__
|
||||
constexpr int64_t max_value<int64_t>() {
|
||||
return 0x7FFFFFFFFFFFFFFFLL;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__
|
||||
constexpr uint32_t max_value<uint32_t>() {
|
||||
return 0xFFFFFFFFu;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__
|
||||
constexpr uint8_t max_value<uint8_t>() {
|
||||
return 0xFFu;
|
||||
}
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void index_select(
|
||||
const size_t numel,
|
||||
@ -23,9 +45,14 @@ __device__ void index_select(
|
||||
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
||||
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
||||
unsigned int right_i = dst_i % right_size;
|
||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||
out[dst_i] = inp[strided_i];
|
||||
if (ids[id_i] == max_value<I>()) {
|
||||
out[dst_i] = static_cast<T>(0);
|
||||
} else {
|
||||
assert(ids[id_i] < src_dim_size);
|
||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||
out[dst_i] = inp[strided_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,10 +83,15 @@ __device__ void gather(
|
||||
) {
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
size_t post = i % right_size;
|
||||
size_t idx = ids[i];
|
||||
size_t pre = i / (right_size * ids_dim_size);
|
||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||
out[i] = inp[src_i];
|
||||
const I idx = ids[i];
|
||||
if (ids[i] == max_value<I>()) {
|
||||
out[i] = static_cast<T>(0);
|
||||
} else {
|
||||
assert(idx < src_dim_size);
|
||||
size_t pre = i / (right_size * ids_dim_size);
|
||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||
out[i] = inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,10 +123,13 @@ __device__ void index_add(
|
||||
const size_t pre = i / right_size;
|
||||
const size_t post = i % right_size;
|
||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||
const size_t idx = ids[j];
|
||||
const I idx = ids[j];
|
||||
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
if (idx < max_value<I>()) {
|
||||
assert(idx < dst_dim_size);
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -111,6 +146,32 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t right_size \
|
||||
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void scatter(
|
||||
const I *ids,
|
||||
const T *inp,
|
||||
T *out,
|
||||
const size_t left_size,
|
||||
const size_t src_dim_size,
|
||||
const size_t dst_dim_size,
|
||||
const size_t right_size
|
||||
) {
|
||||
const size_t numel = left_size * right_size;
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
const size_t pre = i / right_size;
|
||||
const size_t post = i % right_size;
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||
const I idx = ids[src_i];
|
||||
if (idx < max_value<I>()) {
|
||||
assert(idx < dst_dim_size);
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] = inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void scatter_add(
|
||||
const I *ids,
|
||||
@ -127,13 +188,27 @@ __device__ void scatter_add(
|
||||
const size_t post = i % right_size;
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||
const size_t idx = ids[src_i];
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
const I idx = ids[src_i];
|
||||
if (idx < max_value<I>()) {
|
||||
assert(idx < dst_dim_size);
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t left_size, \
|
||||
const size_t src_dim_size, \
|
||||
const size_t dst_dim_size, \
|
||||
const size_t right_size \
|
||||
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||
|
||||
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
@ -159,6 +234,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
|
||||
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
|
||||
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
|
||||
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
|
||||
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
|
||||
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
|
||||
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -174,6 +252,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
|
||||
SA_OP(__half, int64_t, sa_i64_f16)
|
||||
SA_OP(__half, uint32_t, sa_u32_f16)
|
||||
SA_OP(__half, uint8_t, sa_u8_f16)
|
||||
S_OP(__half, int64_t, s_i64_f16)
|
||||
S_OP(__half, uint32_t, s_u32_f16)
|
||||
S_OP(__half, uint8_t, s_u8_f16)
|
||||
#endif
|
||||
|
||||
IS_OP(float, int64_t, is_i64_f32)
|
||||
@ -247,3 +328,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
|
||||
SA_OP(uint8_t, uint8_t, sa_u8_u8)
|
||||
SA_OP(uint32_t, uint8_t, sa_u8_u32)
|
||||
SA_OP(int64_t, uint8_t, sa_u8_i64)
|
||||
|
||||
S_OP(float, int64_t, s_i64_f32)
|
||||
S_OP(double, int64_t, s_i64_f64)
|
||||
S_OP(uint8_t, int64_t, s_i64_u8)
|
||||
S_OP(int64_t, int64_t, s_i64_i64)
|
||||
S_OP(uint32_t, int64_t, s_i64_u32)
|
||||
|
||||
S_OP(float, uint32_t, s_u32_f32)
|
||||
S_OP(double, uint32_t, s_u32_f64)
|
||||
S_OP(uint8_t, uint32_t, s_u32_u8)
|
||||
S_OP(int64_t, uint32_t, s_u32_i64)
|
||||
S_OP(uint32_t, uint32_t, s_u32_u32)
|
||||
|
||||
S_OP(float, uint8_t, s_u8_f32)
|
||||
S_OP(double, uint8_t, s_u8_f64)
|
||||
S_OP(uint8_t, uint8_t, s_u8_u8)
|
||||
S_OP(uint32_t, uint8_t, s_u8_u32)
|
||||
S_OP(int64_t, uint8_t, s_u8_i64)
|
||||
|
@ -219,11 +219,15 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
|
||||
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t stride_b) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
uint32_t rope_idx = idx % (td / 2);
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
rope_idx += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
|
||||
@ -232,7 +236,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
|
||||
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d, const uint32_t stride_b) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= bh * td) return;
|
||||
|
||||
@ -243,6 +247,10 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
|
||||
uint32_t i1 = i_bh * td + i_t * d + i_d;
|
||||
uint32_t i2 = i1 + d / 2;
|
||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
|
||||
@ -259,7 +267,8 @@ __device__ void rope_thd(
|
||||
const uint32_t b,
|
||||
const uint32_t t,
|
||||
const uint32_t h,
|
||||
const uint32_t d
|
||||
const uint32_t d,
|
||||
const uint32_t stride_b
|
||||
) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (2 * idx >= b * t * h * d) return;
|
||||
@ -270,6 +279,10 @@ __device__ void rope_thd(
|
||||
uint32_t i1 = i_bth * d + i_d;
|
||||
uint32_t i2 = i1 + d / 2;
|
||||
uint32_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
uint32_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * ((t * d) / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
|
||||
@ -546,8 +559,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
const TYPENAME *sin, \
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
|
||||
const uint32_t td, \
|
||||
const uint32_t stride_b) { \
|
||||
ropei<TYPENAME>(src, cos, sin, dst, bh, td, stride_b); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, \
|
||||
@ -556,8 +570,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
TYPENAME *dst, \
|
||||
const uint32_t bh, \
|
||||
const uint32_t td, \
|
||||
const uint32_t d) { \
|
||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
|
||||
const uint32_t d, \
|
||||
const uint32_t stride_b) { \
|
||||
rope<TYPENAME>(src, cos, sin, dst, bh, td, d, stride_b); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME_THD( \
|
||||
const TYPENAME *src, \
|
||||
@ -567,8 +582,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
const uint32_t b, \
|
||||
const uint32_t t, \
|
||||
const uint32_t h, \
|
||||
const uint32_t d) { \
|
||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
|
||||
const uint32_t d, \
|
||||
const uint32_t stride_b) { \
|
||||
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d, stride_b); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.9.0-alpha.1"
|
||||
version = "0.9.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
@ -4,20 +4,20 @@ using namespace metal;
|
||||
|
||||
template<typename T> METAL_FUNC void fill_with(
|
||||
device T *out,
|
||||
constant float &value,
|
||||
constant T &value,
|
||||
constant size_t &numel,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= numel) {
|
||||
return;
|
||||
}
|
||||
out[tid] = static_cast<T>(value);
|
||||
out[tid] = value;
|
||||
}
|
||||
|
||||
#define FILL_OP(NAME, T) \
|
||||
kernel void fill_##NAME( \
|
||||
device T *out, \
|
||||
constant float &value, \
|
||||
constant T &value, \
|
||||
constant size_t &numel, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
|
@ -1,6 +1,24 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
inline T max_value();
|
||||
|
||||
template <>
|
||||
inline int64_t max_value<int64_t>() {
|
||||
return 0x7FFFFFFFFFFFFFFF;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint32_t max_value<uint32_t>() {
|
||||
return 0xFFFFFFFFu;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline uint8_t max_value<uint8_t>() {
|
||||
return 0xFF;
|
||||
}
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
@ -35,17 +53,21 @@ METAL_FUNC void index(
|
||||
return;
|
||||
}
|
||||
const size_t id_i = (tid / right_size) % ids_size;
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size / ids_size;
|
||||
/*
|
||||
// Force prevent out of bounds indexing
|
||||
// since there doesn't seem to be a good way to force crash
|
||||
// No need to check for zero we're only allowing unsized.
|
||||
*/
|
||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
||||
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
||||
output[tid] = input[strided_src_i];
|
||||
if (input_ids[id_i] == max_value<INDEX_TYPENAME>()) {
|
||||
output[tid] = static_cast<TYPENAME>(0);
|
||||
} else {
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size / ids_size;
|
||||
/*
|
||||
// Force prevent out of bounds indexing
|
||||
// since there doesn't seem to be a good way to force crash
|
||||
// No need to check for zero we're only allowing unsized.
|
||||
*/
|
||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
||||
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
||||
output[tid] = input[strided_src_i];
|
||||
}
|
||||
}
|
||||
|
||||
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
@ -83,10 +105,14 @@ METAL_FUNC void gather(
|
||||
return;
|
||||
}
|
||||
const INDEX_TYPENAME input_i = input_ids[tid];
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size / ids_size;
|
||||
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||
output[tid] = input[src_i];
|
||||
if (input_i == max_value<INDEX_TYPENAME>()) {
|
||||
output[tid] = static_cast<TYPENAME>(0);
|
||||
} else {
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size / ids_size;
|
||||
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||
output[tid] = input[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
@ -104,6 +130,33 @@ kernel void NAME( \
|
||||
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||
}
|
||||
|
||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||
METAL_FUNC void scatter(
|
||||
constant size_t &dst_size,
|
||||
constant size_t &left_size,
|
||||
constant size_t &src_dim_size,
|
||||
constant size_t &right_size,
|
||||
constant size_t &dst_dim_size,
|
||||
const device TYPENAME *input,
|
||||
const device INDEX_TYPENAME *input_ids,
|
||||
device TYPENAME *output,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (tid >= dst_size) {
|
||||
return;
|
||||
}
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size;
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] = input[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||
METAL_FUNC void scatter_add(
|
||||
constant size_t &dst_size,
|
||||
@ -124,11 +177,28 @@ METAL_FUNC void scatter_add(
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] += input[src_i];
|
||||
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] += input[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
constant size_t &left_size, \
|
||||
constant size_t &src_dim_size, \
|
||||
constant size_t &right_size, \
|
||||
constant size_t &dst_dim_size, \
|
||||
const device TYPENAME *input, \
|
||||
const device INDEX_TYPENAME *input_ids, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
||||
}
|
||||
|
||||
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
@ -164,9 +234,11 @@ METAL_FUNC void index_add(
|
||||
const size_t left_rank_i = tid / right_size;
|
||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||
const INDEX_TYPENAME idx = input_ids[j];
|
||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] += input[src_i];
|
||||
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] += input[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -235,6 +307,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
SCATTER_OP(s_u32_f32, uint32_t, float)
|
||||
SCATTER_OP(s_u8_f32, uint8_t, float)
|
||||
SCATTER_OP(s_i64_f32, int64_t, float)
|
||||
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
|
||||
SCATTER_OP(s_u32_f16, uint32_t, half)
|
||||
SCATTER_OP(s_u8_f16, uint8_t, half)
|
||||
SCATTER_OP(s_i64_f16, int64_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
|
||||
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
// i64
|
||||
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
|
||||
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
||||
|
@ -161,7 +161,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip, silu, sign, sigmoid
|
||||
tanh, recip, silu, sign, sigmoid, const_set
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -419,6 +419,82 @@ pub fn call_copy2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous_tiled(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous_tiled::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let tile_size = 2;
|
||||
let tiles = length.div_ceil(tile_size);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_strided(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: unary::strided::Kernel,
|
||||
shape: &[usize],
|
||||
input: impl EncoderParam,
|
||||
strides: &[usize],
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous_tiled(
|
||||
device: &Device,
|
||||
@ -915,6 +991,7 @@ pub fn call_rope_i(
|
||||
kernel_name: &'static str,
|
||||
bh: usize,
|
||||
td: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -933,6 +1010,7 @@ pub fn call_rope_i(
|
||||
(
|
||||
bh,
|
||||
td,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -958,6 +1036,7 @@ pub fn call_rope_thd(
|
||||
t: usize,
|
||||
h: usize,
|
||||
d: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -978,6 +1057,7 @@ pub fn call_rope_thd(
|
||||
t,
|
||||
h,
|
||||
d,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -1002,6 +1082,7 @@ pub fn call_rope(
|
||||
bh: usize,
|
||||
td: usize,
|
||||
d: usize,
|
||||
stride_b: usize,
|
||||
src: &Buffer,
|
||||
src_offset: usize,
|
||||
cos: &Buffer,
|
||||
@ -1021,6 +1102,7 @@ pub fn call_rope(
|
||||
bh,
|
||||
td,
|
||||
d,
|
||||
stride_b,
|
||||
(src, src_offset),
|
||||
(cos, cos_offset),
|
||||
(sin, sin_offset),
|
||||
@ -1371,7 +1453,7 @@ pub fn call_gather(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_scatter_add(
|
||||
pub fn call_scatter(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
@ -1381,7 +1463,7 @@ pub fn call_scatter_add(
|
||||
dim: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||
@ -1406,7 +1488,7 @@ pub fn call_scatter_add(
|
||||
dst_dim_size,
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
&output
|
||||
)
|
||||
);
|
||||
|
||||
@ -1414,7 +1496,7 @@ pub fn call_scatter_add(
|
||||
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
@ -2570,7 +2652,7 @@ pub fn call_const_fill(
|
||||
name: &'static str,
|
||||
length: usize,
|
||||
output: &Buffer,
|
||||
v: f32,
|
||||
v: impl EncoderParam,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||
let encoder = ep.encoder();
|
||||
|
@ -1097,6 +1097,7 @@ template<typename T>
|
||||
METAL_FUNC void ropei(
|
||||
constant size_t &bh,
|
||||
constant size_t &td,
|
||||
constant size_t &stride_b,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
@ -1107,6 +1108,10 @@ METAL_FUNC void ropei(
|
||||
return;
|
||||
}
|
||||
size_t rope_idx = tid % (td / 2);
|
||||
if (stride_b > 0) {
|
||||
size_t b_idx = (2 * tid) / stride_b;
|
||||
rope_idx += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[rope_idx];
|
||||
T s = sin[rope_idx];
|
||||
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
|
||||
@ -1118,6 +1123,7 @@ METAL_FUNC void rope(
|
||||
constant size_t &bh,
|
||||
constant size_t &td,
|
||||
constant size_t &d,
|
||||
constant size_t &stride_b,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
@ -1134,6 +1140,10 @@ METAL_FUNC void rope(
|
||||
size_t i1 = i_bh * td + i_t * d + i_d;
|
||||
size_t i2 = i1 + d / 2;
|
||||
size_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
size_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * (td / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
dst[i1] = src[i1] * c - src[i2] * s;
|
||||
@ -1146,6 +1156,7 @@ METAL_FUNC void rope_thd(
|
||||
constant size_t &t,
|
||||
constant size_t &h,
|
||||
constant size_t &d,
|
||||
constant size_t &stride_b,
|
||||
device const T *src,
|
||||
device const T *cos,
|
||||
device const T *sin,
|
||||
@ -1160,8 +1171,12 @@ METAL_FUNC void rope_thd(
|
||||
const size_t i_t = (i_bth / h) % t;
|
||||
const size_t i1 = i_bth * d + i_d;
|
||||
const size_t i2 = i1 + d / 2;
|
||||
const size_t i_cs = i_t * (d / 2) + i_d;
|
||||
T c = cos[i_cs];
|
||||
size_t i_cs = i_t * (d / 2) + i_d;
|
||||
if (stride_b > 0) {
|
||||
const size_t b_idx = (2 * idx) / stride_b;
|
||||
i_cs += b_idx * ((t * d) / 2);
|
||||
}
|
||||
T c = cos[i_cs];
|
||||
T s = sin[i_cs];
|
||||
dst[i1] = src[i1] * c - src[i2] * s;
|
||||
dst[i2] = src[i1] * s + src[i2] * c;
|
||||
@ -1171,38 +1186,41 @@ METAL_FUNC void rope_thd(
|
||||
kernel void FN_NAME_I( \
|
||||
constant size_t &bh, \
|
||||
constant size_t &td, \
|
||||
constant size_t &stride_b, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
|
||||
ropei<TYPENAME>(bh, td, stride_b, src, cos, sin, dst, tid); \
|
||||
}\
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &bh, \
|
||||
constant size_t &td, \
|
||||
constant size_t &d, \
|
||||
constant size_t &stride_b, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint idx [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
|
||||
rope<TYPENAME>(bh, td, d, stride_b, src, cos, sin, dst, idx); \
|
||||
}\
|
||||
kernel void FN_NAME_THD( \
|
||||
constant size_t &b, \
|
||||
constant size_t &t, \
|
||||
constant size_t &h, \
|
||||
constant size_t &d, \
|
||||
constant size_t &stride_b, \
|
||||
device const TYPENAME *src, \
|
||||
device const TYPENAME *cos, \
|
||||
device const TYPENAME *sin, \
|
||||
device TYPENAME *dst, \
|
||||
uint idx [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
|
||||
rope_thd<TYPENAME>(b, t, h, d, stride_b, src, cos, sin, dst, idx); \
|
||||
}\
|
||||
|
||||
RMSNORM(rmsnorm_f32, float)
|
||||
|
@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let input_buffer = new_buffer(&device, input);
|
||||
let ids_buffer = new_buffer(&device, ids);
|
||||
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
|
||||
call_scatter_add(
|
||||
call_scatter(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
@ -1584,7 +1584,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
dim,
|
||||
BufferOffset::zero_offset(&input_buffer),
|
||||
BufferOffset::zero_offset(&ids_buffer),
|
||||
&output,
|
||||
BufferOffset::zero_offset(&output),
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
|
||||
|
||||
#[test]
|
||||
fn const_fill() {
|
||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
|
||||
let dev = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = dev.new_command_queue();
|
||||
@ -2357,11 +2357,15 @@ fn const_fill() {
|
||||
command_buffer.wait_until_completed();
|
||||
read_to_vec::<T>(&buffer, len)
|
||||
}
|
||||
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
|
||||
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
|
||||
name: &'static str,
|
||||
f: F,
|
||||
) {
|
||||
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
||||
let value = rand::thread_rng().gen_range(1. ..19.);
|
||||
let value = f(value);
|
||||
let v = constant_fill::<T>(name, len, value);
|
||||
assert_eq!(v, vec![f(value); len])
|
||||
assert_eq!(v, vec![value; len])
|
||||
}
|
||||
test::<u8, _>("fill_u8", |v| v as u8);
|
||||
test::<u32, _>("fill_u32", |v| v as u32);
|
||||
|
@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
|
||||
|
||||
#define TILE_SIZE 2
|
||||
|
||||
#define CONST_SET(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = input; \
|
||||
} \
|
||||
kernel void FN_NAME##_##strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
|
||||
} \
|
||||
kernel void FN_NAME##_##tiled( \
|
||||
constant size_t &dim, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
for (uint i = 0; i < TILE_SIZE; i++) { \
|
||||
const uint idx = tid * TILE_SIZE + i; \
|
||||
output[idx] = input; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
|
||||
COPY2D(copy2d_u8, uint8_t)
|
||||
COPY2D(copy2d_u32, uint32_t)
|
||||
|
||||
CONST_SET(float, const_set_f32)
|
||||
CONST_SET(half, const_set_f16)
|
||||
CONST_SET(uint8_t, const_set_u8)
|
||||
CONST_SET(uint32_t, const_set_u32)
|
||||
|
||||
UNARY_OP(cos)
|
||||
UNARY_OP(sin)
|
||||
UNARY_OP(sqr)
|
||||
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
|
||||
#if __METAL_VERSION__ >= 220
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
COPY2D(copy2d_i64, int64_t)
|
||||
CONST_SET(int64_t, const_set_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
||||
|
||||
COPY2D(copy2d_bf16, bfloat)
|
||||
CONST_SET(bfloat, const_set_bf16)
|
||||
#endif
|
||||
|
@ -88,9 +88,13 @@ primitive!(bool);
|
||||
primitive!(usize);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(u8);
|
||||
primitive!(u32);
|
||||
primitive!(u64);
|
||||
primitive!(f32);
|
||||
primitive!(f64);
|
||||
primitive!(half::bf16);
|
||||
primitive!(half::f16);
|
||||
|
||||
pub struct BufferOffset<'a> {
|
||||
pub buffer: &'a Buffer,
|
||||
|
@ -33,6 +33,7 @@ criterion = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
|
||||
|
||||
|
@ -71,6 +71,8 @@ impl candle::Module for PReLU {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let weight = if self.is_scalar {
|
||||
self.weight.reshape(())?
|
||||
} else if xs.shape() == self.weight.shape() {
|
||||
self.weight.clone()
|
||||
} else if xs.rank() >= 2 {
|
||||
let num_channels = xs.dim(1)?;
|
||||
let num_weights = self.weight.elem_count();
|
||||
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
|
||||
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
|
||||
}
|
||||
let mut s = vec![1; xs.rank()];
|
||||
s[1] = self.weight.elem_count();
|
||||
s[1] = num_weights;
|
||||
self.weight.reshape(s)?
|
||||
} else {
|
||||
self.weight.clone()
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! Convolution Layers.
|
||||
use crate::BatchNorm;
|
||||
use candle::{Result, Tensor};
|
||||
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Conv1dConfig {
|
||||
@ -8,6 +8,7 @@ pub struct Conv1dConfig {
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
pub groups: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl Default for Conv1dConfig {
|
||||
@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -52,12 +54,13 @@ impl Conv1d {
|
||||
|
||||
impl crate::Module for Conv1d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv1d(
|
||||
let x = x.conv1d_with_algo(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
self.config.cudnn_fwd_algo,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
@ -147,6 +150,7 @@ pub struct Conv2dConfig {
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
pub groups: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl Default for Conv2dConfig {
|
||||
@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -211,12 +216,13 @@ impl Conv2d {
|
||||
|
||||
impl crate::Module for Conv2d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv2d(
|
||||
let x = x.conv2d_with_algo(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
self.config.cudnn_fwd_algo,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! Cache Implementations
|
||||
//!
|
||||
use candle::{Device, Result, Tensor};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
@ -294,6 +294,27 @@ impl RotatingCache {
|
||||
Tensor::from_slice(&mask, (size1, size2), device)
|
||||
}
|
||||
|
||||
/// Returns the positions corresponding to all the elements that will be retured
|
||||
/// *after* adding `seq_len` to the cache.
|
||||
pub fn positions(&self, seq_len: usize) -> Vec<usize> {
|
||||
if seq_len <= self.max_seq_len {
|
||||
let upd_offset = (self.offset + seq_len) % self.max_seq_len;
|
||||
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
|
||||
(0..cache_out_len)
|
||||
.map(|i| {
|
||||
let pos_cache = self.current_seq_len + seq_len + i - upd_offset;
|
||||
if i < upd_offset {
|
||||
pos_cache
|
||||
} else {
|
||||
pos_cache - self.max_seq_len
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
(self.current_seq_len..(self.current_seq_len + seq_len)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
|
||||
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
|
||||
let mask = if seq_len == 1 {
|
||||
@ -362,12 +383,338 @@ impl RotatingKvCache {
|
||||
self.k.current_seq_len()
|
||||
}
|
||||
|
||||
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
|
||||
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
|
||||
self.k.attn_mask(seq_len, device)
|
||||
}
|
||||
|
||||
/// Returns the positions corresponding to all the elements that will be retured
|
||||
/// *after* adding `seq_len` to the cache.
|
||||
pub fn positions(&self, seq_len: usize) -> Vec<usize> {
|
||||
self.k.positions(seq_len)
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.k.reset();
|
||||
self.v.reset();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IndicesAndMask {
|
||||
indices: Tensor,
|
||||
mask: Tensor,
|
||||
}
|
||||
|
||||
impl IndicesAndMask {
|
||||
pub fn mask(&self) -> &Tensor {
|
||||
&self.mask
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScatteredKvCache {
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
context: usize,
|
||||
}
|
||||
|
||||
impl ScatteredKvCache {
|
||||
pub fn append(
|
||||
&mut self,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
iam: &IndicesAndMask,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
if self.context <= k.dim(2)? {
|
||||
return Ok((k.clone(), v.clone()));
|
||||
}
|
||||
let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
|
||||
let indices = indices.broadcast_as(k.shape())?.contiguous()?;
|
||||
self.k.scatter_set(&indices, k, 2)?;
|
||||
self.v.scatter_set(&indices, v, 2)?;
|
||||
Ok((self.k.clone(), self.v.clone()))
|
||||
}
|
||||
|
||||
pub fn k(&self) -> &Tensor {
|
||||
&self.k
|
||||
}
|
||||
|
||||
pub fn v(&self) -> &Tensor {
|
||||
&self.v
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScatteredCacheBuilder {
|
||||
context: usize,
|
||||
// The current position in the stream, this can be larger than context.
|
||||
positions: Vec<usize>,
|
||||
// The index where the next element will be stored.
|
||||
indices: Vec<usize>,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl ScatteredCacheBuilder {
|
||||
pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
|
||||
let positions = vec![0; batch_size];
|
||||
let indices = vec![0; batch_size];
|
||||
Ok(Self {
|
||||
positions,
|
||||
indices,
|
||||
context,
|
||||
dtype,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
|
||||
let batch_size = self.batch_size();
|
||||
let shape = (batch_size, num_heads, self.context, head_dim);
|
||||
let k = Tensor::zeros(shape, self.dtype, self.device())?;
|
||||
let v = Tensor::zeros(shape, self.dtype, self.device())?;
|
||||
Ok(ScatteredKvCache {
|
||||
k,
|
||||
v,
|
||||
context: self.context,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn positions(&self) -> &[usize] {
|
||||
&self.positions
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.positions.fill(0);
|
||||
self.indices.fill(0);
|
||||
}
|
||||
|
||||
pub fn batch_size(&self) -> usize {
|
||||
self.positions.len()
|
||||
}
|
||||
|
||||
pub fn reset_batch_index(&mut self, batch_index: usize) {
|
||||
self.positions[batch_index] = 0;
|
||||
self.indices[batch_index] = 0;
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
pub fn indices_and_mask(
|
||||
&mut self,
|
||||
seq_len: usize,
|
||||
batch_mask: &[bool],
|
||||
) -> Result<IndicesAndMask> {
|
||||
// mask shape is (b, h, t, k)
|
||||
let context = self.context;
|
||||
if self.context <= seq_len {
|
||||
return self.indices_and_mask_abs(seq_len, batch_mask);
|
||||
}
|
||||
let mut attention_masks = Vec::with_capacity(self.batch_size());
|
||||
let mut cache_indices = Vec::with_capacity(self.batch_size());
|
||||
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
|
||||
if !batch_mask {
|
||||
let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
|
||||
let indices = vec![self.indices[batch_i] as u32; seq_len];
|
||||
attention_masks.push(masks);
|
||||
cache_indices.push(indices);
|
||||
} else {
|
||||
let start_index = self.indices[batch_i];
|
||||
let start_pos = self.positions[batch_i];
|
||||
let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
|
||||
let mut indices = Vec::with_capacity(seq_len);
|
||||
let mut all_pos = vec![usize::MAX; context];
|
||||
if start_pos < context {
|
||||
for i in 0..start_pos {
|
||||
all_pos[i] = i;
|
||||
}
|
||||
} else {
|
||||
let offset = start_pos - start_index;
|
||||
for i in 0..context {
|
||||
all_pos[i] = if i < start_index {
|
||||
i + offset
|
||||
} else {
|
||||
i + offset - context
|
||||
};
|
||||
}
|
||||
}
|
||||
for seq_i in 0..seq_len {
|
||||
let index = self.indices[batch_i];
|
||||
all_pos[index] = seq_i + start_pos;
|
||||
indices.push(index as u32);
|
||||
self.indices[batch_i] += 1;
|
||||
self.positions[batch_i] += 1;
|
||||
if self.indices[batch_i] >= self.context {
|
||||
self.indices[batch_i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
for seq_i in 0..seq_len {
|
||||
let my_pos = seq_i + start_pos;
|
||||
let mask = all_pos
|
||||
.iter()
|
||||
.map(|&pos| {
|
||||
if pos <= my_pos {
|
||||
0.0
|
||||
} else {
|
||||
f32::NEG_INFINITY
|
||||
}
|
||||
})
|
||||
.collect::<Vec<f32>>();
|
||||
masks.push(mask);
|
||||
}
|
||||
|
||||
attention_masks.push(masks);
|
||||
cache_indices.push(indices);
|
||||
}
|
||||
}
|
||||
// Flattening the attention mask then using Tensor::from_vec rather using Tensor::new ends
|
||||
// up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
|
||||
let attention_masks = attention_masks
|
||||
.into_iter()
|
||||
.flat_map(|m| m.into_iter().flatten())
|
||||
.collect::<Vec<f32>>();
|
||||
let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
|
||||
.to_dtype(self.dtype)?;
|
||||
let indices = Tensor::new(cache_indices, self.device())?;
|
||||
Ok(IndicesAndMask { indices, mask })
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
fn indices_and_mask_abs(
|
||||
&mut self,
|
||||
seq_len: usize,
|
||||
batch_mask: &[bool],
|
||||
) -> Result<IndicesAndMask> {
|
||||
let mask = self.get_mask_abs(seq_len, seq_len)?;
|
||||
let mut cache_indices = Vec::with_capacity(self.batch_size());
|
||||
for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
|
||||
if !batch_mask {
|
||||
let indices = vec![self.indices[batch_i] as u32; seq_len];
|
||||
cache_indices.push(indices);
|
||||
} else {
|
||||
let mut indices = Vec::with_capacity(seq_len);
|
||||
for _ in 0..seq_len {
|
||||
let index = self.indices[batch_i];
|
||||
indices.push(index as u32);
|
||||
self.indices[batch_i] += 1;
|
||||
self.positions[batch_i] += 1;
|
||||
if self.indices[batch_i] >= self.context {
|
||||
self.indices[batch_i] = 0;
|
||||
}
|
||||
}
|
||||
cache_indices.push(indices);
|
||||
}
|
||||
}
|
||||
let indices = Tensor::new(cache_indices, self.device())?;
|
||||
Ok(IndicesAndMask { indices, mask })
|
||||
}
|
||||
|
||||
fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
|
||||
let context = self.context;
|
||||
let mask: Vec<_> = (0..size1)
|
||||
.flat_map(|i| {
|
||||
(0..size2).map(move |j| {
|
||||
if size1 + j > size2 + i || size1 + j + context < size2 + i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size1, size2), self.device())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle::IndexOp;
|
||||
|
||||
#[test]
|
||||
fn test_scattered_kv_cache() -> Result<()> {
|
||||
let device = Device::Cpu;
|
||||
let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
|
||||
let inf = f32::INFINITY;
|
||||
|
||||
let iam = cache.indices_and_mask(1, &[true, false])?;
|
||||
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
|
||||
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
|
||||
assert_eq!(
|
||||
mask,
|
||||
[[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
|
||||
);
|
||||
|
||||
let iam = cache.indices_and_mask(1, &[true, false])?;
|
||||
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
|
||||
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
|
||||
assert_eq!(
|
||||
mask,
|
||||
[[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
|
||||
);
|
||||
|
||||
let iam = cache.indices_and_mask(3, &[false, true])?;
|
||||
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
|
||||
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
|
||||
assert_eq!(
|
||||
mask,
|
||||
[
|
||||
[
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
],
|
||||
[
|
||||
[0.0, -inf, -inf, -inf, -inf],
|
||||
[0.0, 0.0, -inf, -inf, -inf],
|
||||
[0.0, 0.0, 0.0, -inf, -inf]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
let iam = cache.indices_and_mask(3, &[true, true])?;
|
||||
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
|
||||
assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
|
||||
assert_eq!(
|
||||
mask,
|
||||
[
|
||||
[
|
||||
[0.0, 0.0, 0.0, -inf, -inf],
|
||||
[0.0, 0.0, 0.0, 0.0, -inf],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
],
|
||||
[
|
||||
[-inf, 0.0, 0.0, 0.0, -inf],
|
||||
[-inf, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
let iam = cache.indices_and_mask(1, &[true, false])?;
|
||||
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
|
||||
assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
|
||||
assert_eq!(
|
||||
mask,
|
||||
[[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
|
||||
);
|
||||
|
||||
let iam = cache.indices_and_mask(2, &[true, false])?;
|
||||
let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
|
||||
assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
|
||||
assert_eq!(
|
||||
mask,
|
||||
[
|
||||
[[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
|
||||
[[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ pub mod ops;
|
||||
pub mod optim;
|
||||
pub mod rnn;
|
||||
pub mod rotary_emb;
|
||||
pub mod sampling;
|
||||
pub mod sequential;
|
||||
pub mod var_builder;
|
||||
pub mod var_map;
|
||||
|
@ -41,12 +41,36 @@ impl Linear {
|
||||
|
||||
impl super::Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let w = match *x.dims() {
|
||||
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
// When possible, we avoid using a broadcasted matmul as it is much slower
|
||||
// than the standard matmul for the cuda and cpu backends.
|
||||
let x = match *x.dims() {
|
||||
[b1, b2, m, k] => {
|
||||
if x.is_contiguous() {
|
||||
let w = self.weight.t()?;
|
||||
x.reshape((b1 * b2 * m, k))?
|
||||
.matmul(&w)?
|
||||
.reshape((b1, b2, m, ()))?
|
||||
} else {
|
||||
let w = self.weight.broadcast_left((b1, b2))?.t()?;
|
||||
x.matmul(&w)?
|
||||
}
|
||||
}
|
||||
[bsize, m, k] => {
|
||||
if x.is_contiguous() {
|
||||
let w = self.weight.t()?;
|
||||
x.reshape((bsize * m, k))?
|
||||
.matmul(&w)?
|
||||
.reshape((bsize, m, ()))?
|
||||
} else {
|
||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||
x.matmul(&w)?
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let w = self.weight.t()?;
|
||||
x.matmul(&w)?
|
||||
}
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
|
@ -112,7 +112,7 @@ impl candle::CustomOp1 for Sigmoid {
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el_count)? };
|
||||
|
||||
let mut builder = func.builder();
|
||||
candle::builder_arg!(builder, el_count, dims.len());
|
||||
@ -373,7 +373,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
@ -561,7 +561,7 @@ impl candle::CustomOp2 for RmsNorm {
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
@ -800,7 +800,7 @@ impl candle::CustomOp3 for LayerNorm {
|
||||
let func =
|
||||
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user