mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
20 Commits
0.9.0-alph
...
0.9.0
Author | SHA1 | Date | |
---|---|---|---|
fbaf0b0e32 | |||
a2e925462c | |||
3827685524 | |||
3aeb9575c7 | |||
6ff0a6999c | |||
82def7ae38 | |||
99bd69f383 | |||
a4c56a958e | |||
b2904a830b | |||
21055b5697 | |||
9dbaf958dc | |||
ce5f8dd129 | |||
9954981327 | |||
7f0f83a7c1 | |||
76e565c4ab | |||
e4e7b0b2da | |||
b01ebbad8a | |||
1d1d6d4fe6 | |||
2653002f29 | |||
a52b76ae82 |
26
Cargo.toml
26
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.9.0-alpha.2"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.2" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.2" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.2" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.2" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.2" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.2" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.2" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.4.1"
|
||||
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
ug = "0.3.1"
|
||||
ug-cuda = "0.3.1"
|
||||
ug-metal = "0.3.1"
|
||||
ug = "0.4.0"
|
||||
ug-cuda = "0.4.0"
|
||||
ug-metal = "0.4.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "1.1.1", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
@ -290,6 +290,8 @@ Cheatsheet:
|
||||
|
||||
### Why should I use Candle?
|
||||
|
||||
<!--- ANCHOR: goals --->
|
||||
|
||||
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
|
||||
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
|
||||
binaries.
|
||||
@ -299,6 +301,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut
|
||||
|
||||
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
|
||||
|
||||
<!--- ANCHOR_END: goals --->
|
||||
|
||||
### Other ML frameworks
|
||||
|
||||
|
13
candle-book/CONTRIBUTING.md
Normal file
13
candle-book/CONTRIBUTING.md
Normal file
@ -0,0 +1,13 @@
|
||||
# Candle Book
|
||||
|
||||
The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.
|
||||
|
||||
## Installation
|
||||
|
||||
To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).
|
||||
|
||||
## Viewing the book
|
||||
|
||||
To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).
|
||||
|
||||
The book is built automatically in github CI.
|
@ -1,6 +1,7 @@
|
||||
# Introduction
|
||||
|
||||
{{#include ../../README.md:goals}}
|
||||
|
||||
{{#include ../../README.md:features}}
|
||||
|
||||
|
||||
This book will introduce step by step how to use `candle`.
|
||||
This book will introduce step by step how to use `candle`.
|
@ -5,7 +5,10 @@
|
||||
# User Guide
|
||||
|
||||
- [Installation](guide/installation.md)
|
||||
- [Hello World - MNIST](guide/hello_world.md)
|
||||
- [Tutorial - MNIST](guide/mnist/intro.md)
|
||||
- [Modeling](guide/mnist/modeling.md)
|
||||
- [Training](guide/mnist/training.md)
|
||||
- [Saving And Loading](guide/mnist/saving_loading.md)
|
||||
- [PyTorch cheatsheet](guide/cheatsheet.md)
|
||||
|
||||
# Reference Guide
|
||||
|
@ -1,8 +1,23 @@
|
||||
# Installation
|
||||
|
||||
**With Cuda support**:
|
||||
## 1. Create a new rust app or library
|
||||
|
||||
1. First, make sure that Cuda is correctly installed.
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
```
|
||||
|
||||
## 2. Add the correct candle version
|
||||
|
||||
### Standard
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
### CUDA
|
||||
|
||||
First, make sure that Cuda is correctly installed.
|
||||
- `nvcc --version` should print information about your Cuda compiler driver.
|
||||
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
|
||||
like:
|
||||
@ -17,43 +32,36 @@ You can also compile the Cuda kernels for a specific compute cap using the
|
||||
|
||||
If any of the above commands errors out, please make sure to update your Cuda version.
|
||||
|
||||
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
|
||||
|
||||
Start by creating a new cargo:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
```
|
||||
|
||||
Make sure to add the `candle-core` crate with the cuda feature:
|
||||
Add the `candle-core` crate with the cuda feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
|
||||
```
|
||||
|
||||
### MKL
|
||||
|
||||
You can also see the `mkl` feature which can get faster inference on CPU.
|
||||
|
||||
Add the `candle-core` crate with the mkl feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
|
||||
```
|
||||
|
||||
### Metal
|
||||
|
||||
Metal is exclusive to MacOS.
|
||||
|
||||
Add the `candle-core` crate with the metal feature:
|
||||
|
||||
```bash
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
|
||||
```
|
||||
|
||||
## 3. Building
|
||||
|
||||
Run `cargo build` to make sure everything can be correctly built.
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**Without Cuda support**:
|
||||
|
||||
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
|
||||
|
||||
```bash
|
||||
cargo new myapp
|
||||
cd myapp
|
||||
cargo add --git https://github.com/huggingface/candle.git candle-core
|
||||
```
|
||||
|
||||
Finally, run `cargo build` to make sure everything can be correctly built.
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
```
|
||||
|
||||
**With mkl support**
|
||||
|
||||
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)
|
||||
|
17
candle-book/src/guide/mnist/intro.md
Normal file
17
candle-book/src/guide/mnist/intro.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Introduction
|
||||
|
||||
This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.
|
||||
|
||||
Throughout this tutorial, you will learn the basics of:
|
||||
|
||||
- Tensor operations and model construction
|
||||
- Creating and implementing neural network layers
|
||||
- Parameter initialization
|
||||
- Training loop implementation
|
||||
- Saving and loading trained models
|
||||
|
||||
## Getting Started
|
||||
|
||||
Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.
|
172
candle-book/src/guide/mnist/modeling.md
Normal file
172
candle-book/src/guide/mnist/modeling.md
Normal file
@ -0,0 +1,172 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Modeling
|
||||
|
||||
Open `src/main.rs` in your project folder and insert the following code:
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
|
||||
struct Model {
|
||||
first: Tensor,
|
||||
second: Tensor,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = image.matmul(&self.first)?;
|
||||
let x = x.relu()?;
|
||||
x.matmul(&self.second)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; to utilize GPU acceleration.
|
||||
let device = Device::Cpu;
|
||||
|
||||
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
|
||||
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Execute the program with:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> Digit Tensor[dims 1, 10; f32] digit
|
||||
```
|
||||
|
||||
Since random inputs are provided, expect an incoherent output.
|
||||
|
||||
## Implementing a `Linear` Layer
|
||||
|
||||
To create a more sophisticated layer type, add a `bias` to the weight to construct the standard `Linear` layer.
|
||||
|
||||
Replace the entire content of `src/main.rs` with:
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.matmul(&self.weight)?;
|
||||
x.broadcast_add(&self.bias)
|
||||
}
|
||||
}
|
||||
|
||||
struct Model {
|
||||
first: Linear,
|
||||
second: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = self.first.forward(image)?;
|
||||
let x = x.relu()?;
|
||||
self.second.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; for GPU acceleration.
|
||||
// Use Device::Cpu; for CPU computation.
|
||||
let device = Device::cuda_if_available(0)?;
|
||||
|
||||
// Initialize model parameters
|
||||
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
|
||||
let first = Linear { weight, bias };
|
||||
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
|
||||
let second = Linear { weight, bias };
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
|
||||
// Perform inference
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Execute again with:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> Digit Tensor[dims 1, 10; f32] digit
|
||||
```
|
||||
|
||||
## Utilizing `candle_nn`
|
||||
|
||||
Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
|
||||
|
||||
This `Linear` implementation follows PyTorch conventions for improved compatibility with existing models, utilizing the transpose of weights rather than direct weights.
|
||||
|
||||
Let's simplify our implementation. First, add `candle-nn` as a dependency:
|
||||
|
||||
```bash
|
||||
$ cargo add --git https://github.com/huggingface/candle.git candle-nn
|
||||
```
|
||||
|
||||
Now, replace the entire content of `src/main.rs` with:
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_nn::{Linear, Module};
|
||||
|
||||
struct Model {
|
||||
first: Linear,
|
||||
second: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = self.first.forward(image)?;
|
||||
let x = x.relu()?;
|
||||
self.second.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Use Device::new_cuda(0)?; for GPU acceleration.
|
||||
let device = Device::Cpu;
|
||||
|
||||
// Note the dimension change: (784, 100) -> (100, 784)
|
||||
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
|
||||
let first = Linear::new(weight, Some(bias));
|
||||
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
|
||||
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
|
||||
let second = Linear::new(weight, Some(bias));
|
||||
let model = Model { first, second };
|
||||
|
||||
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
|
||||
|
||||
let digit = model.forward(&dummy_image)?;
|
||||
println!("Digit {digit:?} digit");
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Execute the final version:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> Digit Tensor[dims 1, 10; f32] digit
|
||||
```
|
158
candle-book/src/guide/mnist/saving_loading.md
Normal file
158
candle-book/src/guide/mnist/saving_loading.md
Normal file
@ -0,0 +1,158 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Saving and Loading Models
|
||||
|
||||
After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
|
||||
### Saving Model Parameters
|
||||
|
||||
Let's modify our `training_loop` function to include functionality for saving weights:
|
||||
|
||||
```rust
|
||||
fn training_loop(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
// Initialize a VarMap for trainable parameters
|
||||
let varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = Model::new(vs.clone())?;
|
||||
|
||||
let learning_rate = 0.05;
|
||||
let epochs = 10;
|
||||
|
||||
// Initialize stochastic gradient descent optimizer
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
for epoch in 1..epochs {
|
||||
// Standard MNIST forward pass
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
|
||||
// Compute Negative Log Likelihood loss
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
|
||||
// Perform backward pass and update weights
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
// Evaluate model on test set
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
test_accuracy
|
||||
);
|
||||
}
|
||||
|
||||
// Save model weights to disk
|
||||
varmap.save("model_weights.safetensors")?;
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> 1 train loss: 2.40485 test acc: 0.11%
|
||||
> 2 train loss: 2.34161 test acc: 0.14%
|
||||
> 3 train loss: 2.28841 test acc: 0.17%
|
||||
> 4 train loss: 2.24158 test acc: 0.19%
|
||||
> 5 train loss: 2.19898 test acc: 0.23%
|
||||
> 6 train loss: 2.15927 test acc: 0.26%
|
||||
> 7 train loss: 2.12161 test acc: 0.29%
|
||||
> 8 train loss: 2.08549 test acc: 0.32%
|
||||
> 9 train loss: 2.05053 test acc: 0.35%
|
||||
```
|
||||
|
||||
### Loading Model Parameters
|
||||
|
||||
Now that we have saved our model parameters, we can modify the code to load them. The primary change required is to make the `varmap` variable mutable:
|
||||
|
||||
```rust
|
||||
fn training_loop(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
// Create a mutable VarMap for trainable parameters
|
||||
let mut varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = Model::new(vs.clone())?;
|
||||
|
||||
// Load pre-trained weights from file
|
||||
varmap.load("model_weights.safetensors")?;
|
||||
|
||||
let learning_rate = 0.05;
|
||||
let epochs = 10;
|
||||
|
||||
// Initialize stochastic gradient descent optimizer
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
for epoch in 1..epochs {
|
||||
// Standard MNIST forward pass
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
|
||||
// Compute Negative Log Likelihood loss
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
|
||||
// Perform backward pass and update weights
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
// Evaluate model on test set
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
test_accuracy
|
||||
);
|
||||
}
|
||||
|
||||
// Save updated weights back to disk
|
||||
varmap.save("model_weights.safetensors")?;
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> 1 train loss: 2.01645 test acc: 0.38%
|
||||
> 2 train loss: 1.98300 test acc: 0.41%
|
||||
> 3 train loss: 1.95008 test acc: 0.44%
|
||||
> 4 train loss: 1.91754 test acc: 0.47%
|
||||
> 5 train loss: 1.88534 test acc: 0.50%
|
||||
> 6 train loss: 1.85349 test acc: 0.53%
|
||||
> 7 train loss: 1.82198 test acc: 0.56%
|
||||
> 8 train loss: 1.79077 test acc: 0.59%
|
||||
> 9 train loss: 1.75989 test acc: 0.61%
|
||||
```
|
||||
|
||||
Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
|
134
candle-book/src/guide/mnist/training.md
Normal file
134
candle-book/src/guide/mnist/training.md
Normal file
@ -0,0 +1,134 @@
|
||||
# Candle MNIST Tutorial
|
||||
|
||||
## Training Implementation
|
||||
|
||||
First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` constructs a `VarMap`, which is the data structure that stores our trainable parameters.
|
||||
|
||||
```rust
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_nn::{Linear, Module, VarBuilder, VarMap};
|
||||
|
||||
fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
|
||||
let ws = vs.get_with_hints(
|
||||
(out_dim, in_dim),
|
||||
"weight",
|
||||
candle_nn::init::DEFAULT_KAIMING_NORMAL,
|
||||
)?;
|
||||
let bound = 1. / (in_dim as f64).sqrt();
|
||||
let bs = vs.get_with_hints(
|
||||
out_dim,
|
||||
"bias",
|
||||
candle_nn::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
},
|
||||
)?;
|
||||
Ok(Linear::new(ws, Some(bs)))
|
||||
}
|
||||
```
|
||||
|
||||
Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer weights as `first.weight` and `first.bias`, and the second layer weights as `second.weight` and `second.bias`.
|
||||
|
||||
```rust
|
||||
impl Model {
|
||||
fn new(vs: VarBuilder) -> Result<Self> {
|
||||
const IMAGE_DIM: usize = 784;
|
||||
const HIDDEN_DIM: usize = 100;
|
||||
const LABELS: usize = 10;
|
||||
|
||||
let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
|
||||
let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;
|
||||
|
||||
Ok(Self { first, second })
|
||||
}
|
||||
|
||||
fn forward(&self, image: &Tensor) -> Result<Tensor> {
|
||||
let x = self.first.forward(image)?;
|
||||
let x = x.relu()?;
|
||||
self.second.forward(&x)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:
|
||||
|
||||
```bash
|
||||
$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
|
||||
```
|
||||
|
||||
With the dataset available, we can implement our training loop:
|
||||
|
||||
```rust
|
||||
use candle_core::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
|
||||
|
||||
fn training_loop(
|
||||
m: candle_datasets::vision::Dataset,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = Device::cuda_if_available(0)?;
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
// Initialize a VarMap to store trainable parameters
|
||||
let varmap = VarMap::new();
|
||||
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
|
||||
let model = Model::new(vs.clone())?;
|
||||
|
||||
let learning_rate = 0.05;
|
||||
let epochs = 10;
|
||||
|
||||
// Initialize a stochastic gradient descent optimizer to update parameters
|
||||
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
|
||||
let test_images = m.test_images.to_device(&dev)?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
for epoch in 1..epochs {
|
||||
// Perform forward pass on MNIST data
|
||||
let logits = model.forward(&train_images)?;
|
||||
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
|
||||
|
||||
// Compute Negative Log Likelihood loss
|
||||
let loss = loss::nll(&log_sm, &train_labels)?;
|
||||
|
||||
// Perform backward pass and update weights
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
// Evaluate model on test set
|
||||
let test_logits = model.forward(&test_images)?;
|
||||
let sum_ok = test_logits
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
test_accuracy
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
Finally, let's implement our main function:
|
||||
|
||||
```rust
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let m = candle_datasets::vision::mnist::load()?;
|
||||
return training_loop(m);
|
||||
}
|
||||
```
|
||||
|
||||
Let's execute the training process:
|
||||
|
||||
```bash
|
||||
$ cargo run --release
|
||||
|
||||
> 1 train loss: 2.35449 test acc: 0.12%
|
||||
> 2 train loss: 2.30760 test acc: 0.15%
|
||||
> ...
|
||||
```
|
@ -6,28 +6,18 @@ extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("fp32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("tf32: {:?}", start_time.elapsed());
|
||||
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
|
||||
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
|
||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||
drop(_x1);
|
||||
for _ in 0..20 {
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||
device.synchronize()?;
|
||||
println!("conv1d: {:?}", start_time.elapsed());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn scatter_add(
|
||||
&self,
|
||||
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self>;
|
||||
) -> Result<()>;
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<()>;
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn index_add(
|
||||
&self,
|
||||
@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
|
||||
_src_offset: usize,
|
||||
_dst_offset: usize,
|
||||
) -> Result<()>;
|
||||
|
||||
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
|
||||
}
|
||||
|
||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
|
||||
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
/// # Safety
|
||||
/// This function is unsafe as it doesn't initialize the underlying data store.
|
||||
/// The caller should ensure that the data is properly initialized as early as possible
|
||||
|
@ -53,6 +53,7 @@ impl Tensor {
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::IndexAdd(t1, t2, t3, _)
|
||||
| Op::Scatter(t1, t2, t3, _)
|
||||
| Op::ScatterAdd(t1, t2, t3, _)
|
||||
| Op::CustomOp3(t1, t2, t3, _)
|
||||
| Op::WhereCond(t1, t2, t3) => {
|
||||
@ -419,7 +420,7 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
||||
}
|
||||
Op::ScatterAdd(init, indexes, src, dim) => {
|
||||
Op::Scatter(init, indexes, src, dim) => {
|
||||
let init_sum_grad = grads.or_insert(init)?;
|
||||
*init_sum_grad = init_sum_grad.add(&grad)?;
|
||||
|
||||
@ -427,6 +428,16 @@ impl Tensor {
|
||||
let src_sum_grad = grads.or_insert(src)?;
|
||||
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
||||
}
|
||||
Op::ScatterAdd(init, indexes, src, dim) => {
|
||||
let init_sum_grad = grads.or_insert(init)?;
|
||||
let mask = init.ones_like()?;
|
||||
let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
|
||||
*init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
|
||||
|
||||
let src_grad = grad.gather(indexes, *dim)?;
|
||||
let src_sum_grad = grads.or_insert(src)?;
|
||||
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
||||
}
|
||||
Op::IndexAdd(init, indexes, src, dim) => {
|
||||
let init_sum_grad = grads.or_insert(init)?;
|
||||
*init_sum_grad = init_sum_grad.add(&grad)?;
|
||||
|
@ -55,7 +55,7 @@ impl ParamsConvTranspose1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum CudnnFwdAlgo {
|
||||
ImplicitGemm,
|
||||
ImplicitPrecompGemm,
|
||||
@ -152,6 +152,19 @@ impl Tensor {
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d_with_algo(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
@ -175,7 +188,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
cudnn_fwd_algo,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv1d_single_group(kernel, ¶ms)
|
||||
@ -280,6 +293,18 @@ impl Tensor {
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||
}
|
||||
|
||||
pub fn conv2d_with_algo(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
@ -299,7 +324,7 @@ impl Tensor {
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
cudnn_fwd_algo: None,
|
||||
cudnn_fwd_algo,
|
||||
};
|
||||
if groups == 1 {
|
||||
self.conv2d_single_group(kernel, ¶ms)
|
||||
|
@ -7,7 +7,7 @@ use rayon::prelude::*;
|
||||
|
||||
mod utils;
|
||||
pub use utils::{
|
||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
|
||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
|
||||
};
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
@ -554,26 +554,65 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ScatterAdd<'a, I: IntDType> {
|
||||
trait ElemUpdate {
|
||||
fn f<T: WithDType>(dst: &mut T, src: T);
|
||||
}
|
||||
|
||||
struct Set;
|
||||
struct Add;
|
||||
|
||||
impl ElemUpdate for Set {
|
||||
fn f<T: WithDType>(dst: &mut T, src: T) {
|
||||
*dst = src
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemUpdate for Add {
|
||||
fn f<T: WithDType>(dst: &mut T, src: T) {
|
||||
*dst += src
|
||||
}
|
||||
}
|
||||
|
||||
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
dim: usize,
|
||||
_phantom: std::marker::PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
||||
const OP: &'static str = "scatter-add";
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
|
||||
fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
|
||||
Self {
|
||||
ids,
|
||||
ids_l,
|
||||
dim,
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
|
||||
const OP: &'static str = "scatter";
|
||||
fn f<T: WithDType>(
|
||||
&self,
|
||||
dst: &mut [T],
|
||||
dst_l: &Layout,
|
||||
src: &[T],
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
Some((o1, o2)) => &mut dst[o1..o2],
|
||||
};
|
||||
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
let dst_dims = l1.dims();
|
||||
let dst_dims = dst_l.dims();
|
||||
let dst_dim_len = dst_dims[dim];
|
||||
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||
|
||||
@ -602,12 +641,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
||||
.bt())?
|
||||
}
|
||||
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
|
||||
dst[dst_idx] += src[ids_idx]
|
||||
M::f(&mut dst[dst_idx], src[ids_idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dst)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@ -2381,19 +2420,36 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
match ids {
|
||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
match ids {
|
||||
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
|
||||
}
|
||||
}
|
||||
@ -2454,6 +2510,48 @@ impl BackendStorage for CpuStorage {
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||
use crate::scalar::Scalar;
|
||||
fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
|
||||
match l.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
src[start_offset..start_offset + len].fill(s)
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len: 1,
|
||||
} => {
|
||||
for src_index in block_start_index {
|
||||
src[src_index] = s
|
||||
}
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
for src_index in block_start_index {
|
||||
src[src_index..src_index + block_len].fill(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
match (self, s) {
|
||||
(Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
|
||||
(Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
|
||||
(Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
|
||||
(Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
|
||||
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
|
||||
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
|
||||
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
|
||||
(st, s) => crate::bail!(
|
||||
"const_set dtype mismatch, expected {:?} but got {:?}",
|
||||
st.dtype(),
|
||||
s
|
||||
),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for CpuDevice {
|
||||
@ -2628,20 +2726,6 @@ impl BackendDevice for CpuDevice {
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
|
||||
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
|
||||
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
|
||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
|
||||
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
|
||||
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
|
||||
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
|
||||
};
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
|
@ -58,6 +58,30 @@ pub trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
|
||||
|
||||
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: v1.dtype(),
|
||||
rhs: v2.dtype(),
|
||||
op: Self::OP,
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2U8 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||
|
@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||
use cudarc::driver::CudaFunction;
|
||||
use half::{bf16, f16};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -144,6 +144,24 @@ impl CudaDevice {
|
||||
self.stream.clone()
|
||||
}
|
||||
|
||||
/// When turned on, all cuda tensors **created after calling this function** will
|
||||
/// not track uses via cuda events.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// It is up to the user to ensure proper synchronization between multiple streams:
|
||||
/// - Ensure that no tensor is freed before a use on another stream is finished.
|
||||
/// - Ensure that a tensor is not used on another stream before allocation on the
|
||||
/// allocating stream finishes.
|
||||
/// - Ensure that a tensor is not written two concurrently by multiple streams.
|
||||
pub unsafe fn disable_event_tracking(&self) {
|
||||
self.context.disable_event_tracking()
|
||||
}
|
||||
|
||||
pub fn is_event_tracking(&self) -> bool {
|
||||
self.context.is_event_tracking()
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn compile(
|
||||
&self,
|
||||
@ -170,100 +188,6 @@ impl CudaDevice {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u8;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as i64;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = bf16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = f16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as f32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_custom_func(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
@ -486,10 +410,6 @@ impl BackendDevice for CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
self.const_impl(1., shape, dtype)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
|
@ -2,7 +2,7 @@
|
||||
//!
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::scalar::Scalar {
|
||||
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
|
||||
use crate::scalar::Scalar;
|
||||
match self {
|
||||
Scalar::U8(v) => builder.arg(v),
|
||||
Scalar::U32(v) => builder.arg(v),
|
||||
Scalar::I64(v) => builder.arg(v),
|
||||
Scalar::F32(v) => builder.arg(v),
|
||||
Scalar::F64(v) => builder.arg(v),
|
||||
Scalar::F16(v) => builder.arg(v),
|
||||
Scalar::BF16(v) => builder.arg(v),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl SlicePtrOrNull<usize> {
|
||||
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||
let ds = if l.is_contiguous() {
|
||||
@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
|
||||
CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
|
||||
CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "index_select ids should be u8 or u32",
|
||||
msg: "index_select ids should be u8, u32, or i64",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let ids_dim_sz = ids_l.dims()[0];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
barg!(builder, ids);
|
||||
barg!(builder, ids_dim_sz);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||
impl Map2InPlace for Scatter<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
let ids = &self.0;
|
||||
let ids_l = &self.1;
|
||||
let dim = self.2;
|
||||
let (ids_o1, _) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let (name, (ids, _guard)) = match &ids.slice {
|
||||
CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
|
||||
CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
|
||||
CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "scatter ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -1235,6 +1310,36 @@ impl BackendStorage for CudaStorage {
|
||||
&self.device
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
|
||||
let dev = &self.device;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src_o = layout.start_offset();
|
||||
let ((src, _guard_src), kernel_name) = match &mut self.slice {
|
||||
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
|
||||
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
|
||||
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
|
||||
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
|
||||
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
|
||||
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
|
||||
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
|
||||
};
|
||||
|
||||
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el_count);
|
||||
barg!(builder, dims.len());
|
||||
ds.builder_arg(&mut builder);
|
||||
s.builder_arg(&mut builder);
|
||||
barg!(builder, src);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
@ -1793,20 +1898,29 @@ impl BackendStorage for CudaStorage {
|
||||
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn index_add(
|
||||
&self,
|
||||
@ -1820,7 +1934,7 @@ impl BackendStorage for CudaStorage {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// Helper functions to plug cuda kernels in candle.
|
||||
use crate::{Layout, Result, Shape, WithDType};
|
||||
use crate::{Layout, Result, WithDType};
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||
|
||||
@ -96,7 +96,7 @@ pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -105,19 +105,19 @@ pub trait Map2InPlace {
|
||||
fn map(
|
||||
&self,
|
||||
dst: &mut S,
|
||||
dst_s: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &S,
|
||||
src_l: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
}
|
||||
}
|
||||
|
@ -292,23 +292,6 @@ impl Device {
|
||||
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
|
||||
}
|
||||
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuDevice.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
|
@ -107,6 +107,7 @@ pub trait WithDType:
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
fn to_f64(self) -> f64;
|
||||
fn to_scalar(self) -> crate::scalar::Scalar;
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
@ -131,6 +132,10 @@ macro_rules! with_dtype {
|
||||
$to_f64(self)
|
||||
}
|
||||
|
||||
fn to_scalar(self) -> crate::scalar::Scalar {
|
||||
crate::scalar::Scalar::$dtype(self)
|
||||
}
|
||||
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||
CpuStorageRef::$dtype(data)
|
||||
}
|
||||
|
@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
|
||||
self.binary(name, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||
use crate::scalar::Scalar;
|
||||
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
|
||||
self_: &mut MetalStorage,
|
||||
s: S,
|
||||
l: &Layout,
|
||||
) -> Result<()> {
|
||||
let device = self_.device();
|
||||
let dtype = self_.dtype;
|
||||
let shape = l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label("const-set");
|
||||
let dst = buffer_o(&self_.buffer, l, self_.dtype);
|
||||
|
||||
match (el_count % 2, dtype, l.is_contiguous()) {
|
||||
(0, DType::BF16 | DType::F16, true) => {
|
||||
use candle_metal_kernels::unary::contiguous_tiled;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous_tiled::const_set::HALF,
|
||||
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
|
||||
_ => crate::bail!("internal bug in const_set"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_contiguous_tiled(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
s,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, true) => {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous::const_set::HALF,
|
||||
DType::BF16 => contiguous::const_set::BFLOAT,
|
||||
DType::F32 => contiguous::const_set::FLOAT,
|
||||
DType::I64 => contiguous::const_set::I64,
|
||||
DType::U32 => contiguous::const_set::U32,
|
||||
DType::U8 => contiguous::const_set::U8,
|
||||
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
s,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, false) => {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => strided::const_set::HALF,
|
||||
DType::BF16 => strided::const_set::BFLOAT,
|
||||
DType::F32 => strided::const_set::FLOAT,
|
||||
DType::I64 => strided::const_set::I64,
|
||||
DType::U32 => strided::const_set::U32,
|
||||
DType::U8 => strided::const_set::U8,
|
||||
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||
};
|
||||
candle_metal_kernels::call_const_set_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
l.dims(),
|
||||
s,
|
||||
l.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
match (self.dtype, s) {
|
||||
(DType::U8, Scalar::U8(s)) => set(self, s, l),
|
||||
(DType::U32, Scalar::U32(s)) => set(self, s, l),
|
||||
(DType::I64, Scalar::I64(s)) => set(self, s, l),
|
||||
(DType::F16, Scalar::F16(s)) => set(self, s, l),
|
||||
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
|
||||
(DType::F32, Scalar::F32(s)) => set(self, s, l),
|
||||
(DType::F64, Scalar::F64(s)) => set(self, s, l),
|
||||
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
|
||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::F32) => "s_u8_f32",
|
||||
(DType::U8, DType::F16) => "s_u8_f16",
|
||||
(DType::U8, DType::BF16) => "s_u8_bf16",
|
||||
(DType::U32, DType::U32) => "s_u32_u32",
|
||||
(DType::U32, DType::F32) => "s_u32_f32",
|
||||
(DType::U32, DType::F16) => "s_u32_f16",
|
||||
(DType::U32, DType::BF16) => "s_u32_bf16",
|
||||
(DType::I64, DType::F32) => "s_i64_f32",
|
||||
(DType::I64, DType::F16) => "s_i64_f16",
|
||||
(DType::I64, DType::BF16) => "s_i64_bf16",
|
||||
_ => Err(MetalError::UnexpectedDType {
|
||||
msg: "scatter ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
l.dims(),
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter_add(
|
||||
candle_metal_kernels::call_scatter(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(acc)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
@ -1965,40 +2107,6 @@ impl BackendDevice for MetalDevice {
|
||||
))
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let name = match dtype {
|
||||
DType::U8 => "fill_u8",
|
||||
DType::U32 => "fill_u32",
|
||||
DType::I64 => "fill_i64",
|
||||
DType::F16 => "fill_f16",
|
||||
DType::BF16 => "fill_bf16",
|
||||
DType::F32 => "fill_f32",
|
||||
DType::F64 => {
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
||||
return self.storage_from_cpu_storage(&cpu_storage);
|
||||
}
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_const_fill(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
shape.elem_count(),
|
||||
&buffer,
|
||||
1.,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(MetalStorage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
|
@ -80,6 +80,7 @@ pub enum Op {
|
||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||
Matmul(Tensor, Tensor),
|
||||
Gather(Tensor, Tensor, usize),
|
||||
Scatter(Tensor, Tensor, Tensor, usize),
|
||||
ScatterAdd(Tensor, Tensor, Tensor, usize),
|
||||
IndexSelect(Tensor, Tensor, usize),
|
||||
IndexAdd(Tensor, Tensor, Tensor, usize),
|
||||
|
@ -1,6 +1,74 @@
|
||||
//! TensorScalar Enum and Trait
|
||||
//!
|
||||
use crate::{Result, Tensor, WithDType};
|
||||
use crate::{DType, Result, Tensor, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Scalar {
|
||||
U8(u8),
|
||||
U32(u32),
|
||||
I64(i64),
|
||||
BF16(bf16),
|
||||
F16(f16),
|
||||
F32(f32),
|
||||
F64(f64),
|
||||
}
|
||||
|
||||
impl<T: WithDType> From<T> for Scalar {
|
||||
fn from(value: T) -> Self {
|
||||
value.to_scalar()
|
||||
}
|
||||
}
|
||||
|
||||
impl Scalar {
|
||||
pub fn zero(dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::U8 => Scalar::U8(0),
|
||||
DType::U32 => Scalar::U32(0),
|
||||
DType::I64 => Scalar::I64(0),
|
||||
DType::BF16 => Scalar::BF16(bf16::ZERO),
|
||||
DType::F16 => Scalar::F16(f16::ZERO),
|
||||
DType::F32 => Scalar::F32(0.0),
|
||||
DType::F64 => Scalar::F64(0.0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn one(dtype: DType) -> Self {
|
||||
match dtype {
|
||||
DType::U8 => Scalar::U8(1),
|
||||
DType::U32 => Scalar::U32(1),
|
||||
DType::I64 => Scalar::I64(1),
|
||||
DType::BF16 => Scalar::BF16(bf16::ONE),
|
||||
DType::F16 => Scalar::F16(f16::ONE),
|
||||
DType::F32 => Scalar::F32(1.0),
|
||||
DType::F64 => Scalar::F64(1.0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Scalar::U8(_) => DType::U8,
|
||||
Scalar::U32(_) => DType::U32,
|
||||
Scalar::I64(_) => DType::I64,
|
||||
Scalar::BF16(_) => DType::BF16,
|
||||
Scalar::F16(_) => DType::F16,
|
||||
Scalar::F32(_) => DType::F32,
|
||||
Scalar::F64(_) => DType::F64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_f64(&self) -> f64 {
|
||||
match self {
|
||||
Scalar::U8(v) => *v as f64,
|
||||
Scalar::U32(v) => *v as f64,
|
||||
Scalar::I64(v) => *v as f64,
|
||||
Scalar::BF16(v) => v.to_f64(),
|
||||
Scalar::F16(v) => v.to_f64(),
|
||||
Scalar::F32(v) => *v as f64,
|
||||
Scalar::F64(v) => *v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum TensorScalar {
|
||||
Tensor(Tensor),
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, ReduceOp};
|
||||
use crate::scalar::Scalar;
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||
|
||||
@ -73,6 +74,14 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => storage.const_set(v, l),
|
||||
Storage::Cuda(storage) => storage.const_set(v, l),
|
||||
Storage::Metal(storage) => storage.const_set(v, l),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
@ -619,32 +628,56 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&self,
|
||||
pub(crate) fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
self.same_device(indexes, "scatter-set")?;
|
||||
self.same_device(source, "scatter-set")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<()> {
|
||||
self.same_device(indexes, "scatter-add")?;
|
||||
self.same_device(source, "scatter-add")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn index_add(
|
||||
|
@ -3,7 +3,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::shape::{Dim, Dims, ShapeWithOneHole};
|
||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
@ -185,7 +185,9 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
let none = BackpropOp::none();
|
||||
let shape = shape.into();
|
||||
let storage = device.ones(&shape, dtype)?;
|
||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||
let layout = Layout::contiguous(shape.clone());
|
||||
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
|
||||
@ -202,6 +204,18 @@ impl Tensor {
|
||||
Self::ones_impl(shape, dtype, device, false)
|
||||
}
|
||||
|
||||
pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
|
||||
self.storage_mut().const_set(value, self.layout())
|
||||
}
|
||||
|
||||
pub fn zero_set(&self) -> Result<()> {
|
||||
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
|
||||
}
|
||||
|
||||
pub fn one_set(&self) -> Result<()> {
|
||||
self.const_set(crate::scalar::Scalar::one(self.dtype()))
|
||||
}
|
||||
|
||||
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
|
||||
///
|
||||
/// ```rust
|
||||
@ -452,17 +466,13 @@ impl Tensor {
|
||||
Self::from_vec_impl(data, len, device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
||||
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let buffer_size = data.len();
|
||||
if buffer_size != shape.elem_count() {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let shape = shape.into_shape(data.len())?;
|
||||
let storage = device.storage_owned(data)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
@ -481,7 +491,7 @@ impl Tensor {
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
@ -502,17 +512,12 @@ impl Tensor {
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||
array: &[D],
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let n: usize = shape.elem_count();
|
||||
let buffer_size: usize = array.len();
|
||||
if buffer_size != n {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let shape = shape.into_shape(array.len())?;
|
||||
let storage = device.storage_from_slice(array)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, false))
|
||||
@ -1349,8 +1354,7 @@ impl Tensor {
|
||||
self.index_select(ids, 0)
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
@ -1367,7 +1371,7 @@ impl Tensor {
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (self, src)",
|
||||
op: "scatter (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
@ -1375,13 +1379,44 @@ impl Tensor {
|
||||
}
|
||||
if indexes.dims() != source.dims() {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (indexes, src)",
|
||||
op: "scatter (indexes, src)",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let storage = self.storage().scatter_add(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let layout = Layout::contiguous(shape);
|
||||
storage.scatter_set(
|
||||
&layout,
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||
Op::Scatter(t1, t2, t3, dim)
|
||||
});
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||
if self.same_storage(source) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
let dim = dim.to_index(self.shape(), "scatter-set")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
self.storage_mut().scatter_set(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
@ -1389,12 +1424,48 @@ impl Tensor {
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let layout = Layout::contiguous(shape);
|
||||
storage.scatter_add(
|
||||
&layout,
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||
Op::ScatterAdd(t1, t2, t3, dim)
|
||||
});
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||
if self.same_storage(source) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
self.storage_mut().scatter_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
|
||||
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "slice-scatter")?;
|
||||
@ -2197,7 +2268,7 @@ impl Tensor {
|
||||
///
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
||||
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
||||
let shape = s.into_shape(self.elem_count())?;
|
||||
if shape.elem_count() != self.elem_count() {
|
||||
return Err(Error::ShapeMismatchBinaryOp {
|
||||
|
@ -241,7 +241,7 @@ impl Tensor {
|
||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||
///
|
||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||
/// Note that this modifies `self` in place and as such is not compatible with
|
||||
/// back-propagation.
|
||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||
|
@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
|
||||
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
if !device.is_metal() {
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
}
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
|
||||
[
|
||||
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn full(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
|
||||
tensor.const_set(42u32.into())?;
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<u32>()?,
|
||||
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
|
||||
);
|
||||
tensor.i((.., 2))?.const_set(1337u32.into())?;
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<u32>()?,
|
||||
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
|
||||
);
|
||||
tensor.i((2, ..))?.const_set(1u32.into())?;
|
||||
assert_eq!(
|
||||
tensor.to_vec2::<u32>()?,
|
||||
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn const_set(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
||||
[[42, 42, 42], [42, 42, 42]],
|
||||
@ -826,6 +848,31 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_fail() -> Result<()> {
|
||||
// Check that an error is properly reported on out of bounds.
|
||||
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
|
||||
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
|
||||
let hs = t.index_select(&ids, 0);
|
||||
assert!(hs.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The test below triggers an unwinding panic as there is a panic within the
|
||||
// #[cfg(feature = "cuda")]
|
||||
// #[test]
|
||||
// #[should_panic]
|
||||
// fn index_select_fail_gpu() {
|
||||
// // Check that a panic happens for out of bounds in cuda
|
||||
// if let Ok(device) = Device::new_cuda(0) {
|
||||
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
|
||||
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
|
||||
// let _ = t.index_select(&ids, 0);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
fn cmp(device: &Device) -> Result<()> {
|
||||
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
|
||||
@ -980,7 +1027,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add(device: &Device) -> Result<()> {
|
||||
fn scatter(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
@ -1004,6 +1051,17 @@ fn scatter_add(device: &Device) -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let hs = init.scatter(&ids, &t, 1)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0, 1.0, 1.0],
|
||||
[5.0, 1.0, 1.0, 3.0, 4.0],
|
||||
[1.0, 8.0, 1.0, 7.0, 1.0],
|
||||
[10.0, 1.0, 9.0, 1.0, 11.0]
|
||||
]
|
||||
);
|
||||
|
||||
let init = Tensor::ones((6, 3), DType::F32, device)?;
|
||||
let hs = init.scatter_add(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
@ -1017,6 +1075,30 @@ fn scatter_add(device: &Device) -> Result<()> {
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
let hs = init.scatter(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 10.0, 5.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[9.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
init.scatter_set(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
init.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 10.0, 5.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[9.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1484,6 +1566,7 @@ fn zero_dim(device: &Device) -> Result<()> {
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
|
||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
@ -1515,12 +1598,7 @@ test_device!(
|
||||
);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
||||
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
||||
test_device!(
|
||||
scatter_add,
|
||||
scatter_add_cpu,
|
||||
scatter_add_gpu,
|
||||
scatter_add_metal
|
||||
);
|
||||
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
|
||||
test_device!(
|
||||
slice_scatter,
|
||||
slice_scatter_cpu,
|
||||
|
@ -124,6 +124,17 @@ impl TextGeneration {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
println!(
|
||||
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||
);
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
@ -146,7 +157,7 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
@ -350,6 +361,31 @@ fn main() -> Result<()> {
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
|
||||
let prompt = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B
|
||||
| Which::BaseV2_2B
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B => args.prompt,
|
||||
Which::InstructV3_1B => {
|
||||
format!(
|
||||
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
||||
args.prompt
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
pipeline.run(&prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ impl TextGeneration {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (top_k, top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(None, None) => Sampling::GumbelSoftmax { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
|
@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let tokenizer = common_args.tokenizer()?;
|
||||
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
#[cfg(feature = "cuda")]
|
||||
if let candle::Device::Cuda(d) = &device {
|
||||
unsafe {
|
||||
d.disable_event_tracking();
|
||||
}
|
||||
};
|
||||
|
||||
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
|
||||
let is_safetensors = config_path
|
||||
|
18
candle-examples/examples/quantized-gemma/README.md
Normal file
18
candle-examples/examples/quantized-gemma/README.md
Normal file
@ -0,0 +1,18 @@
|
||||
# candle-quantized-gemma
|
||||
|
||||
Candle implementation of quantized Gemma.
|
||||
|
||||
## Running an example
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
|
||||
|
||||
> ```python
|
||||
> def fibonacci(n):
|
||||
> """Calculates the nth Fibonacci number using recursion."""
|
||||
> if n <= 1:
|
||||
> return n
|
||||
> else:
|
||||
> return fibonacci(n-1) + fibonacci(n-2
|
||||
> ```
|
||||
```
|
344
candle-examples/examples/quantized-gemma/main.rs
Normal file
344
candle-examples/examples/quantized-gemma/main.rs
Normal file
@ -0,0 +1,344 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::Tensor;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_transformers::models::quantized_gemma3::ModelWeights;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "gemma3-4b-it")]
|
||||
Gemma3_4bIt,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGUF file to load, typically a .gguf file generated by quantization
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||
/// is preserved.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The tokenizer config in json format.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Process prompt elements separately.
|
||||
#[arg(long)]
|
||||
split_prompt: bool,
|
||||
|
||||
/// Run on CPU rather than GPU even if a GPU is available.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "gemma3-4b-it")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer_path = match &self.tokenizer {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let repo = "google/gemma-3-4b-it";
|
||||
println!("DEBUG: Downloading tokenizer from {}", repo);
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||
|
||||
Ok(tokenizer)
|
||||
}
|
||||
|
||||
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||
let model_path = match &self.model {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let (repo, filename) = match self.which {
|
||||
Which::Gemma3_4bIt => (
|
||||
"google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||
"gemma-3-4b-it-q4_0.gguf",
|
||||
),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"main".to_string(),
|
||||
))
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn format_size(size_in_bytes: usize) -> String {
|
||||
if size_in_bytes < 1_000 {
|
||||
format!("{}B", size_in_bytes)
|
||||
} else if size_in_bytes < 1_000_000 {
|
||||
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||
} else if size_in_bytes < 1_000_000_000 {
|
||||
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||
} else {
|
||||
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Prompt {
|
||||
Interactive,
|
||||
Chat,
|
||||
One(String),
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let mut model = {
|
||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
println!(
|
||||
"DEBUG: Tokenizer vocabulary size: {}",
|
||||
tos.tokenizer().get_vocab(true).len()
|
||||
);
|
||||
|
||||
let prompt = match args.prompt.as_deref() {
|
||||
Some("chat") => Prompt::Chat,
|
||||
Some("interactive") => Prompt::Interactive,
|
||||
Some(s) => Prompt::One(s.to_string()),
|
||||
None => Prompt::One(DEFAULT_PROMPT.to_string()),
|
||||
};
|
||||
|
||||
let mut pre_prompt_tokens = vec![];
|
||||
for _ in 0.. {
|
||||
let prompt_str = match &prompt {
|
||||
Prompt::One(prompt) => prompt.clone(),
|
||||
Prompt::Interactive | Prompt::Chat => {
|
||||
print!("> ");
|
||||
std::io::stdout().flush()?;
|
||||
let mut prompt = String::new();
|
||||
std::io::stdin().read_line(&mut prompt)?;
|
||||
if prompt.ends_with('\n') {
|
||||
prompt.pop();
|
||||
if prompt.ends_with('\r') {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
// Format for Gemma 3 chat/instruction format
|
||||
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
|
||||
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
let max_seq_len = 8192; // Gemma 3 context length
|
||||
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
|
||||
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
|
||||
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
|
||||
} else {
|
||||
prompt_tokens
|
||||
};
|
||||
let mut all_tokens = vec![];
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in prompt_tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
// For Gemma 3, use the correct end of sequence token
|
||||
let eos_token = *tos
|
||||
.tokenizer()
|
||||
.get_vocab(true)
|
||||
.get("<end_of_turn>")
|
||||
.unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
prompt_tokens.len(),
|
||||
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
|
||||
match prompt {
|
||||
Prompt::One(_) => break,
|
||||
Prompt::Interactive => {}
|
||||
Prompt::Chat => {
|
||||
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
||||
padding,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = if bias {
|
||||
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||
|
@ -92,6 +92,7 @@ impl ConvBlock {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.9.0-alpha.2"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.2" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.9.0-alpha.2"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include<stdint.h>
|
||||
#include "cuda_fp16.h"
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
template<typename T>
|
||||
__device__ void fill_with(T *buf, T value, const size_t numel) {
|
||||
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
|
||||
COPY2D_OP(uint32_t, copy2d_u32)
|
||||
COPY2D_OP(int64_t, copy2d_i64)
|
||||
|
||||
#define CONST_SET_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t *info, \
|
||||
const TYPENAME inp, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
const size_t *dims = info; \
|
||||
const size_t *strides = info + num_dims; \
|
||||
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
out[i] = inp; \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
out[strided_i] = inp; \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
CONST_SET_OP(float, const_set_f32)
|
||||
CONST_SET_OP(double, const_set_f64)
|
||||
CONST_SET_OP(uint8_t, const_set_u8)
|
||||
CONST_SET_OP(uint32_t, const_set_u32)
|
||||
CONST_SET_OP(int64_t, const_set_i64)
|
||||
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
COPY2D_OP(__half, copy2d_f16)
|
||||
CONST_SET_OP(__half, const_set_f16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
|
||||
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
|
||||
#endif
|
||||
|
@ -23,6 +23,7 @@ __device__ void index_select(
|
||||
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
||||
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
||||
unsigned int right_i = dst_i % right_size;
|
||||
assert(ids[id_i] < src_dim_size);
|
||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||
out[dst_i] = inp[strided_i];
|
||||
@ -57,6 +58,7 @@ __device__ void gather(
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
size_t post = i % right_size;
|
||||
size_t idx = ids[i];
|
||||
assert(idx < src_dim_size);
|
||||
size_t pre = i / (right_size * ids_dim_size);
|
||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||
out[i] = inp[src_i];
|
||||
@ -92,6 +94,7 @@ __device__ void index_add(
|
||||
const size_t post = i % right_size;
|
||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||
const size_t idx = ids[j];
|
||||
assert(idx < dst_dim_size);
|
||||
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
@ -111,6 +114,30 @@ extern "C" __global__ void FN_NAME( \
|
||||
const size_t right_size \
|
||||
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void scatter(
|
||||
const I *ids,
|
||||
const T *inp,
|
||||
T *out,
|
||||
const size_t left_size,
|
||||
const size_t src_dim_size,
|
||||
const size_t dst_dim_size,
|
||||
const size_t right_size
|
||||
) {
|
||||
const size_t numel = left_size * right_size;
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||
const size_t pre = i / right_size;
|
||||
const size_t post = i % right_size;
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||
const size_t idx = ids[src_i];
|
||||
assert(idx < dst_dim_size);
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] = inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename I>
|
||||
__device__ void scatter_add(
|
||||
const I *ids,
|
||||
@ -128,12 +155,24 @@ __device__ void scatter_add(
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||
const size_t idx = ids[src_i];
|
||||
assert(idx < dst_dim_size);
|
||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||
out[dst_i] += inp[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out, \
|
||||
const size_t left_size, \
|
||||
const size_t src_dim_size, \
|
||||
const size_t dst_dim_size, \
|
||||
const size_t right_size \
|
||||
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||
|
||||
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const INDEX_TYPENAME *ids, \
|
||||
@ -159,6 +198,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
|
||||
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
|
||||
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
|
||||
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
|
||||
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
|
||||
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
|
||||
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -174,6 +216,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
|
||||
SA_OP(__half, int64_t, sa_i64_f16)
|
||||
SA_OP(__half, uint32_t, sa_u32_f16)
|
||||
SA_OP(__half, uint8_t, sa_u8_f16)
|
||||
S_OP(__half, int64_t, s_i64_f16)
|
||||
S_OP(__half, uint32_t, s_u32_f16)
|
||||
S_OP(__half, uint8_t, s_u8_f16)
|
||||
#endif
|
||||
|
||||
IS_OP(float, int64_t, is_i64_f32)
|
||||
@ -247,3 +292,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
|
||||
SA_OP(uint8_t, uint8_t, sa_u8_u8)
|
||||
SA_OP(uint32_t, uint8_t, sa_u8_u32)
|
||||
SA_OP(int64_t, uint8_t, sa_u8_i64)
|
||||
|
||||
S_OP(float, int64_t, s_i64_f32)
|
||||
S_OP(double, int64_t, s_i64_f64)
|
||||
S_OP(uint8_t, int64_t, s_i64_u8)
|
||||
S_OP(int64_t, int64_t, s_i64_i64)
|
||||
S_OP(uint32_t, int64_t, s_i64_u32)
|
||||
|
||||
S_OP(float, uint32_t, s_u32_f32)
|
||||
S_OP(double, uint32_t, s_u32_f64)
|
||||
S_OP(uint8_t, uint32_t, s_u32_u8)
|
||||
S_OP(int64_t, uint32_t, s_u32_i64)
|
||||
S_OP(uint32_t, uint32_t, s_u32_u32)
|
||||
|
||||
S_OP(float, uint8_t, s_u8_f32)
|
||||
S_OP(double, uint8_t, s_u8_f64)
|
||||
S_OP(uint8_t, uint8_t, s_u8_u8)
|
||||
S_OP(uint32_t, uint8_t, s_u8_u32)
|
||||
S_OP(int64_t, uint8_t, s_u8_i64)
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.9.0-alpha.2"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
@ -4,20 +4,20 @@ using namespace metal;
|
||||
|
||||
template<typename T> METAL_FUNC void fill_with(
|
||||
device T *out,
|
||||
constant float &value,
|
||||
constant T &value,
|
||||
constant size_t &numel,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= numel) {
|
||||
return;
|
||||
}
|
||||
out[tid] = static_cast<T>(value);
|
||||
out[tid] = value;
|
||||
}
|
||||
|
||||
#define FILL_OP(NAME, T) \
|
||||
kernel void fill_##NAME( \
|
||||
device T *out, \
|
||||
constant float &value, \
|
||||
constant T &value, \
|
||||
constant size_t &numel, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
|
@ -104,6 +104,31 @@ kernel void NAME( \
|
||||
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||
}
|
||||
|
||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||
METAL_FUNC void scatter(
|
||||
constant size_t &dst_size,
|
||||
constant size_t &left_size,
|
||||
constant size_t &src_dim_size,
|
||||
constant size_t &right_size,
|
||||
constant size_t &dst_dim_size,
|
||||
const device TYPENAME *input,
|
||||
const device INDEX_TYPENAME *input_ids,
|
||||
device TYPENAME *output,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (tid >= dst_size) {
|
||||
return;
|
||||
}
|
||||
const size_t right_rank_i = tid % right_size;
|
||||
const size_t left_rank_i = tid / right_size;
|
||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||
output[dst_i] = input[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||
METAL_FUNC void scatter_add(
|
||||
constant size_t &dst_size,
|
||||
@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
|
||||
}
|
||||
}
|
||||
|
||||
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
constant size_t &left_size, \
|
||||
constant size_t &src_dim_size, \
|
||||
constant size_t &right_size, \
|
||||
constant size_t &dst_dim_size, \
|
||||
const device TYPENAME *input, \
|
||||
const device INDEX_TYPENAME *input_ids, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
||||
}
|
||||
|
||||
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
constant size_t &dst_size, \
|
||||
@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
SCATTER_OP(s_u32_f32, uint32_t, float)
|
||||
SCATTER_OP(s_u8_f32, uint8_t, float)
|
||||
SCATTER_OP(s_i64_f32, int64_t, float)
|
||||
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
|
||||
SCATTER_OP(s_u32_f16, uint32_t, half)
|
||||
SCATTER_OP(s_u8_f16, uint8_t, half)
|
||||
SCATTER_OP(s_i64_f16, int64_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
|
||||
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
// i64
|
||||
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
|
||||
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
||||
|
@ -161,7 +161,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip, silu, sign, sigmoid
|
||||
tanh, recip, silu, sign, sigmoid, const_set
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -419,6 +419,82 @@ pub fn call_copy2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous_tiled(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous_tiled::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let tile_size = 2;
|
||||
let tiles = length.div_ceil(tile_size);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_strided(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: unary::strided::Kernel,
|
||||
shape: &[usize],
|
||||
input: impl EncoderParam,
|
||||
strides: &[usize],
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous_tiled(
|
||||
device: &Device,
|
||||
@ -1371,7 +1447,7 @@ pub fn call_gather(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_scatter_add(
|
||||
pub fn call_scatter(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
@ -1381,7 +1457,7 @@ pub fn call_scatter_add(
|
||||
dim: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||
@ -1406,7 +1482,7 @@ pub fn call_scatter_add(
|
||||
dst_dim_size,
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
&output
|
||||
)
|
||||
);
|
||||
|
||||
@ -1414,7 +1490,7 @@ pub fn call_scatter_add(
|
||||
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
@ -2570,7 +2646,7 @@ pub fn call_const_fill(
|
||||
name: &'static str,
|
||||
length: usize,
|
||||
output: &Buffer,
|
||||
v: f32,
|
||||
v: impl EncoderParam,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||
let encoder = ep.encoder();
|
||||
|
@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let input_buffer = new_buffer(&device, input);
|
||||
let ids_buffer = new_buffer(&device, ids);
|
||||
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
|
||||
call_scatter_add(
|
||||
call_scatter(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
|
||||
|
||||
#[test]
|
||||
fn const_fill() {
|
||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
|
||||
let dev = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = dev.new_command_queue();
|
||||
@ -2357,11 +2357,15 @@ fn const_fill() {
|
||||
command_buffer.wait_until_completed();
|
||||
read_to_vec::<T>(&buffer, len)
|
||||
}
|
||||
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
|
||||
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
|
||||
name: &'static str,
|
||||
f: F,
|
||||
) {
|
||||
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
||||
let value = rand::thread_rng().gen_range(1. ..19.);
|
||||
let value = f(value);
|
||||
let v = constant_fill::<T>(name, len, value);
|
||||
assert_eq!(v, vec![f(value); len])
|
||||
assert_eq!(v, vec![value; len])
|
||||
}
|
||||
test::<u8, _>("fill_u8", |v| v as u8);
|
||||
test::<u32, _>("fill_u32", |v| v as u32);
|
||||
|
@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
|
||||
|
||||
#define TILE_SIZE 2
|
||||
|
||||
#define CONST_SET(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = input; \
|
||||
} \
|
||||
kernel void FN_NAME##_##strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
|
||||
} \
|
||||
kernel void FN_NAME##_##tiled( \
|
||||
constant size_t &dim, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
for (uint i = 0; i < TILE_SIZE; i++) { \
|
||||
const uint idx = tid * TILE_SIZE + i; \
|
||||
output[idx] = input; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
|
||||
COPY2D(copy2d_u8, uint8_t)
|
||||
COPY2D(copy2d_u32, uint32_t)
|
||||
|
||||
CONST_SET(float, const_set_f32)
|
||||
CONST_SET(half, const_set_f16)
|
||||
CONST_SET(uint8_t, const_set_u8)
|
||||
CONST_SET(uint32_t, const_set_u32)
|
||||
|
||||
UNARY_OP(cos)
|
||||
UNARY_OP(sin)
|
||||
UNARY_OP(sqr)
|
||||
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
|
||||
#if __METAL_VERSION__ >= 220
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
COPY2D(copy2d_i64, int64_t)
|
||||
CONST_SET(int64_t, const_set_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
||||
|
||||
COPY2D(copy2d_bf16, bfloat)
|
||||
CONST_SET(bfloat, const_set_bf16)
|
||||
#endif
|
||||
|
@ -88,9 +88,13 @@ primitive!(bool);
|
||||
primitive!(usize);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(u8);
|
||||
primitive!(u32);
|
||||
primitive!(u64);
|
||||
primitive!(f32);
|
||||
primitive!(f64);
|
||||
primitive!(half::bf16);
|
||||
primitive!(half::f16);
|
||||
|
||||
pub struct BufferOffset<'a> {
|
||||
pub buffer: &'a Buffer,
|
||||
|
@ -71,6 +71,8 @@ impl candle::Module for PReLU {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let weight = if self.is_scalar {
|
||||
self.weight.reshape(())?
|
||||
} else if xs.shape() == self.weight.shape() {
|
||||
self.weight.clone()
|
||||
} else if xs.rank() >= 2 {
|
||||
let num_channels = xs.dim(1)?;
|
||||
let num_weights = self.weight.elem_count();
|
||||
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
|
||||
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
|
||||
}
|
||||
let mut s = vec![1; xs.rank()];
|
||||
s[1] = self.weight.elem_count();
|
||||
s[1] = num_weights;
|
||||
self.weight.reshape(s)?
|
||||
} else {
|
||||
self.weight.clone()
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! Convolution Layers.
|
||||
use crate::BatchNorm;
|
||||
use candle::{Result, Tensor};
|
||||
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct Conv1dConfig {
|
||||
@ -8,6 +8,7 @@ pub struct Conv1dConfig {
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
pub groups: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl Default for Conv1dConfig {
|
||||
@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -52,12 +54,13 @@ impl Conv1d {
|
||||
|
||||
impl crate::Module for Conv1d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv1d(
|
||||
let x = x.conv1d_with_algo(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
self.config.cudnn_fwd_algo,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
@ -147,6 +150,7 @@ pub struct Conv2dConfig {
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
pub groups: usize,
|
||||
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||
}
|
||||
|
||||
impl Default for Conv2dConfig {
|
||||
@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -211,12 +216,13 @@ impl Conv2d {
|
||||
|
||||
impl crate::Module for Conv2d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv2d(
|
||||
let x = x.conv2d_with_algo(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
self.config.cudnn_fwd_algo,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
|
@ -294,6 +294,27 @@ impl RotatingCache {
|
||||
Tensor::from_slice(&mask, (size1, size2), device)
|
||||
}
|
||||
|
||||
/// Returns the positions corresponding to all the elements that will be retured
|
||||
/// *after* adding `seq_len` to the cache.
|
||||
pub fn positions(&self, seq_len: usize) -> Vec<usize> {
|
||||
if seq_len <= self.max_seq_len {
|
||||
let upd_offset = (self.offset + seq_len) % self.max_seq_len;
|
||||
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
|
||||
(0..cache_out_len)
|
||||
.map(|i| {
|
||||
let pos_cache = self.current_seq_len + seq_len + i - upd_offset;
|
||||
if i < upd_offset {
|
||||
pos_cache
|
||||
} else {
|
||||
pos_cache - self.max_seq_len
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
(self.current_seq_len..(self.current_seq_len + seq_len)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
|
||||
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
|
||||
let mask = if seq_len == 1 {
|
||||
@ -362,10 +383,17 @@ impl RotatingKvCache {
|
||||
self.k.current_seq_len()
|
||||
}
|
||||
|
||||
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
|
||||
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
|
||||
self.k.attn_mask(seq_len, device)
|
||||
}
|
||||
|
||||
/// Returns the positions corresponding to all the elements that will be retured
|
||||
/// *after* adding `seq_len` to the cache.
|
||||
pub fn positions(&self, seq_len: usize) -> Vec<usize> {
|
||||
self.k.positions(seq_len)
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.k.reset();
|
||||
self.v.reset();
|
||||
|
@ -31,6 +31,7 @@ pub mod ops;
|
||||
pub mod optim;
|
||||
pub mod rnn;
|
||||
pub mod rotary_emb;
|
||||
pub mod sampling;
|
||||
pub mod sequential;
|
||||
pub mod var_builder;
|
||||
pub mod var_map;
|
||||
|
20
candle-nn/src/sampling.rs
Normal file
20
candle-nn/src/sampling.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// Sample according to the Gumbel-Softmax distribution.
|
||||
pub fn gumbel_softmax<D: candle::shape::Dim>(
|
||||
logits: &Tensor,
|
||||
temperature: f64,
|
||||
dim: D,
|
||||
) -> Result<Tensor> {
|
||||
if temperature <= 0.0 {
|
||||
logits.argmax(dim)
|
||||
} else if temperature == 1.0 {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
} else {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
}
|
||||
}
|
@ -39,9 +39,16 @@ fn rotating_kv_cache() -> Result<()> {
|
||||
assert_eq!(cache.current_seq_len(), 0);
|
||||
let data = cache.current_data()?;
|
||||
assert!(data.is_none());
|
||||
assert_eq!(cache.positions(1), &[0]);
|
||||
assert_eq!(cache.positions(2), &[0, 1]);
|
||||
let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
|
||||
assert_eq!(cache.positions(0), &[0, 1, 2]);
|
||||
assert_eq!(cache.positions(1), &[0, 1, 2, 3]);
|
||||
assert_eq!(cache.positions(2), &[0, 1, 2, 3, 4]);
|
||||
assert_eq!(cache.positions(3), &[0, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(cache.positions(4), &[6, 1, 2, 3, 4, 5]);
|
||||
let t = Tensor::new(&[4.], &Device::Cpu)?;
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
|
||||
@ -79,11 +86,17 @@ fn rotating_kv_cache() -> Result<()> {
|
||||
mask.to_vec2::<u8>()?,
|
||||
&[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
|
||||
);
|
||||
assert_eq!(cache.positions(0), &[12, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(cache.positions(2), &[12, 13, 14, 9, 10, 11]);
|
||||
assert_eq!(cache.positions(3), &[12, 13, 14, 15, 10, 11]);
|
||||
assert_eq!(cache.positions(8), &[13, 14, 15, 16, 17, 18, 19, 20]);
|
||||
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
|
||||
let data = cache.append(&t)?;
|
||||
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
assert_eq!(cache.current_seq_len(), 22);
|
||||
assert_eq!(cache.offset(), 0);
|
||||
assert_eq!(cache.positions(0), &[16, 17, 18, 19, 20, 21]);
|
||||
assert_eq!(cache.positions(1), &[22, 17, 18, 19, 20, 21]);
|
||||
|
||||
let mask = cache.attn_mask(1, &Device::Cpu)?;
|
||||
assert!(mask.is_none());
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.9.0-alpha.2"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.2" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.9.0" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,7 +1,9 @@
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::Module;
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use candle_nn::activation::PReLU;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
pub type Value = Tensor;
|
||||
@ -991,6 +993,14 @@ fn simple_eval_(
|
||||
let output = input.relu()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"PRelu" => {
|
||||
// https://onnx.ai/onnx/operators/onnx__PRelu.html
|
||||
let input = get(&node.input[0])?;
|
||||
let slope = get(&node.input[1])?;
|
||||
|
||||
let output = PReLU::new(slope.clone(), false).forward(input)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Ceil" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.ceil()?;
|
||||
|
@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "PRelu"
|
||||
#[test]
|
||||
fn test_prelu_operation() -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "PRelu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: INPUT_Y.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let x: Tensor = Tensor::from_vec(
|
||||
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
|
||||
&[2, 2],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), y);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
// "Constant"
|
||||
// #[test]
|
||||
|
||||
|
@ -13,6 +13,8 @@ pub enum Sampling {
|
||||
TopK { k: usize, temperature: f64 },
|
||||
TopP { p: f64, temperature: f64 },
|
||||
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||
// Note that the rng is not used for the Gumbel-Softmax sampling.
|
||||
GumbelSoftmax { temperature: f64 },
|
||||
}
|
||||
|
||||
pub struct LogitsProcessor {
|
||||
@ -49,6 +51,11 @@ impl LogitsProcessor {
|
||||
Ok(next_token)
|
||||
}
|
||||
|
||||
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
|
||||
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
|
||||
sampled.to_vec0::<u32>()
|
||||
}
|
||||
|
||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
let next_token = distr.sample(&mut self.rng) as u32;
|
||||
@ -127,6 +134,9 @@ impl LogitsProcessor {
|
||||
|
||||
let next_token = match &self.sampling {
|
||||
Sampling::ArgMax => self.sample_argmax(logits)?,
|
||||
Sampling::GumbelSoftmax { temperature } => {
|
||||
self.sample_gumbel_softmax(&logits, *temperature)?
|
||||
}
|
||||
Sampling::All { temperature } => {
|
||||
let prs = prs(*temperature)?;
|
||||
self.sample_multinomial(&prs)?
|
||||
|
@ -124,6 +124,7 @@ impl ResidualConvUnit {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv1 = conv2d(
|
||||
conf.num_features,
|
||||
@ -208,6 +209,7 @@ impl FeatureFusionBlock {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let output_conv = conv2d(
|
||||
conf.num_features,
|
||||
@ -258,6 +260,7 @@ impl Scratch {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
|
||||
let layer1_rn = conv2d_no_bias(
|
||||
@ -319,6 +322,7 @@ impl Scratch {
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let output_conv1 = conv2d(
|
||||
conf.num_features,
|
||||
@ -425,6 +429,7 @@ impl DPTHead {
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
},
|
||||
vb.pp("resize_layers").pp("3"),
|
||||
)?),
|
||||
|
@ -468,6 +468,7 @@ impl EncodecConv1d {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
|
@ -21,6 +21,7 @@ pub struct Config {
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f64,
|
||||
pub rope_local_base_freq: f64,
|
||||
pub vocab_size: usize,
|
||||
pub final_logit_softcapping: Option<f64>,
|
||||
pub attn_logit_softcapping: Option<f64>,
|
||||
@ -67,12 +68,22 @@ struct RotaryEmbedding {
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
fn new(
|
||||
dtype: DType,
|
||||
cfg: &Config,
|
||||
dev: &Device,
|
||||
sliding_window: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let dim = cfg.head_dim;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let rope_freq = if sliding_window.is_some() {
|
||||
cfg.rope_local_base_freq
|
||||
} else {
|
||||
cfg.rope_theta
|
||||
};
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
@ -162,8 +173,8 @@ impl Attention {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
is_sliding: bool,
|
||||
cfg: &Config,
|
||||
sliding_window: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
@ -178,13 +189,13 @@ impl Attention {
|
||||
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
|
||||
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
|
||||
let kv_cache = if is_sliding {
|
||||
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
|
||||
2,
|
||||
cfg.sliding_window,
|
||||
))
|
||||
let kv_cache = if let Some(sliding_window) = sliding_window {
|
||||
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
|
||||
} else {
|
||||
KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
|
||||
KvCache::Normal(candle_nn::kv_cache::KvCache::new(
|
||||
2,
|
||||
cfg.max_position_embeddings,
|
||||
))
|
||||
};
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
@ -302,21 +313,27 @@ struct DecoderLayer {
|
||||
pre_feedforward_layernorm: RmsNorm,
|
||||
post_feedforward_layernorm: RmsNorm,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
sliding_window: Option<usize>,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
is_sliding: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
sliding_window: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(
|
||||
vb.dtype(),
|
||||
cfg,
|
||||
vb.device(),
|
||||
sliding_window,
|
||||
)?);
|
||||
let self_attn = Attention::new(
|
||||
rotary_emb,
|
||||
use_flash_attn,
|
||||
is_sliding,
|
||||
cfg,
|
||||
sliding_window,
|
||||
vb.pp("self_attn"),
|
||||
)?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
@ -344,6 +361,7 @@ impl DecoderLayer {
|
||||
pre_feedforward_layernorm,
|
||||
post_feedforward_layernorm,
|
||||
post_attention_layernorm,
|
||||
sliding_window,
|
||||
})
|
||||
}
|
||||
|
||||
@ -370,6 +388,42 @@ impl DecoderLayer {
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
sliding_window: Option<usize>,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
|
||||
(0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
(0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect()
|
||||
};
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(dtype)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
@ -388,17 +442,15 @@ impl Model {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
|
||||
let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
|
||||
let layer = DecoderLayer::new(
|
||||
rotary_emb.clone(),
|
||||
use_flash_attn,
|
||||
is_sliding,
|
||||
cfg,
|
||||
vb_l.pp(layer_idx),
|
||||
sliding_window.then_some(cfg.sliding_window),
|
||||
)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
@ -417,51 +469,52 @@ impl Model {
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
fn create_attention_masks(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
batch_size: usize,
|
||||
seq_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = match Some(self.sliding_window) {
|
||||
None => (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect(),
|
||||
Some(sliding_window) => (0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
||||
if seq_len <= 1 {
|
||||
return Ok((None, None));
|
||||
}
|
||||
|
||||
let mask = prepare_decoder_attention_mask(
|
||||
batch_size,
|
||||
seq_len,
|
||||
seqlen_offset,
|
||||
None,
|
||||
self.dtype,
|
||||
&self.device,
|
||||
)?;
|
||||
|
||||
let sliding_mask = prepare_decoder_attention_mask(
|
||||
batch_size,
|
||||
seq_len,
|
||||
seqlen_offset,
|
||||
Some(self.sliding_window),
|
||||
self.dtype,
|
||||
&self.device,
|
||||
)?;
|
||||
|
||||
Ok((Some(mask), Some(sliding_mask)))
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let xs = self.embed_tokens.forward(input_ids)?;
|
||||
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||
|
||||
let (attention_mask, sliding_attention_mask) =
|
||||
self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
|
||||
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
let mask = if layer.sliding_window.is_some() {
|
||||
&sliding_attention_mask
|
||||
} else {
|
||||
&attention_mask
|
||||
};
|
||||
xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
let logits = xs
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
|
@ -267,6 +267,7 @@ impl StreamableConv1d {
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
||||
if k_size < stride {
|
||||
|
@ -79,6 +79,7 @@ pub mod phi3;
|
||||
pub mod pixtral;
|
||||
pub mod quantized_blip;
|
||||
pub mod quantized_blip_text;
|
||||
pub mod quantized_gemma3;
|
||||
pub mod quantized_llama;
|
||||
pub mod quantized_llama2_c;
|
||||
pub mod quantized_metavoice;
|
||||
|
466
candle-transformers/src/models/quantized_gemma3.rs
Normal file
466
candle-transformers/src/models/quantized_gemma3.rs
Normal file
@ -0,0 +1,466 @@
|
||||
//! Gemma 3 model implementation with quantization support.
|
||||
//!
|
||||
//! Gemma 3 is a family of multimodal language models developed by Google.
|
||||
//! This implementation provides quantization for reduced memory usage and faster inference.
|
||||
//!
|
||||
//! Key characteristics:
|
||||
//! - Group-Query Attention (GQA) with specialized key-value heads
|
||||
//! - RMSNorm for layer normalization
|
||||
//! - Specialized attention patterns with separate normalization for Q/K/V
|
||||
//! - Feed-forward network with SwiGLU activation
|
||||
//! - Support for 2/3/4/8-bit quantization
|
||||
//!
|
||||
//! References:
|
||||
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
|
||||
//!
|
||||
|
||||
use crate::quantized_nn::RmsNorm;
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::D;
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
|
||||
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
|
||||
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
|
||||
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
|
||||
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QMatMul {
|
||||
inner: candle::quantized::QMatMul,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
feed_forward_gate: QMatMul, // ffn_gate in GGUF
|
||||
feed_forward_up: QMatMul, // ffn_up in GGUF
|
||||
feed_forward_down: QMatMul, // ffn_down in GGUF
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let gate = self.feed_forward_gate.forward(xs)?;
|
||||
let up = self.feed_forward_up.forward(xs)?;
|
||||
let silu = candle_nn::ops::silu(&gate)?;
|
||||
let gated = (silu * up)?;
|
||||
self.feed_forward_down.forward(&gated)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok(Self { sin, cos })
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
index_pos: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
// Attention components
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
attention_wv: QMatMul,
|
||||
attention_wo: QMatMul,
|
||||
|
||||
// Specialized normalization for Q and K
|
||||
attention_q_norm: RmsNorm,
|
||||
attention_k_norm: RmsNorm,
|
||||
|
||||
// Layer normalization
|
||||
attention_norm: RmsNorm, // Applied before attention
|
||||
post_attention_norm: RmsNorm, // Applied after attention
|
||||
ffn_norm: RmsNorm, // Applied before feedforward
|
||||
post_ffn_norm: RmsNorm, // Applied after feedforward
|
||||
|
||||
// Feed-forward network
|
||||
mlp: Mlp,
|
||||
|
||||
// Attention parameters
|
||||
n_head: usize, // Number of query heads
|
||||
n_kv_head: usize, // Number of key-value heads
|
||||
head_dim: usize, // Dimension of each head
|
||||
q_dim: usize, // Total dimension for queries
|
||||
|
||||
sliding_window_size: Option<usize>,
|
||||
|
||||
rotary_embedding: RotaryEmbedding,
|
||||
neg_inf: Tensor,
|
||||
|
||||
// Cache
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
|
||||
// Tracing
|
||||
span_attn: tracing::Span,
|
||||
span_mlp: tracing::Span,
|
||||
}
|
||||
|
||||
impl LayerWeights {
|
||||
fn mask(
|
||||
&self,
|
||||
b_sz: usize,
|
||||
seq_len: usize,
|
||||
index_pos: usize,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
|
||||
(0..seq_len)
|
||||
.flat_map(|i| {
|
||||
(0..seq_len).map(move |j| {
|
||||
if i < j || j + sliding_window_size < i {
|
||||
0u32
|
||||
} else {
|
||||
1u32
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
(0..seq_len)
|
||||
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
|
||||
.collect()
|
||||
};
|
||||
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||
let mask = if index_pos > 0 {
|
||||
let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
|
||||
.to_dtype(dtype)
|
||||
}
|
||||
|
||||
fn forward_attn(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
index_pos: usize,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let (b_sz, seq_len, _) = x.dims3()?;
|
||||
|
||||
let q = self.attention_wq.forward(x)?;
|
||||
let k = self.attention_wk.forward(x)?;
|
||||
let v = self.attention_wv.forward(x)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
|
||||
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
|
||||
|
||||
let (q, k) = self
|
||||
.rotary_embedding
|
||||
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k, v)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone())); // update cache
|
||||
|
||||
// Repeat KV for GQA
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
|
||||
// Scaled Dot-Product Attention
|
||||
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
if let Some(mask) = mask {
|
||||
let mask = mask.broadcast_as(attn_weights.shape())?;
|
||||
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
|
||||
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
|
||||
}
|
||||
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, seq_len, self.q_dim))?;
|
||||
|
||||
self.attention_wo.forward(&attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
embedding_length: usize,
|
||||
layers: Vec<LayerWeights>,
|
||||
norm: RmsNorm,
|
||||
output: QMatMul,
|
||||
span: tracing::Span,
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
|
||||
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
|
||||
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
|
||||
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
|
||||
|
||||
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
|
||||
.and_then(|m| Ok(m.to_u32()? as usize))
|
||||
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
|
||||
|
||||
let rope_freq_base = md_get("gemma3.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
|
||||
|
||||
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
|
||||
|
||||
// Unused in Llama.cpp so we aren't using it here.
|
||||
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
|
||||
|
||||
// Compute the dimensions for queries, keys, and values
|
||||
// These are the total dimensions when projected across all heads
|
||||
let q_dim = head_count * key_length;
|
||||
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
// Load token embeddings and output projection
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
let output = match ct.tensor(reader, "output.weight", device) {
|
||||
Ok(tensor) => tensor,
|
||||
Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
|
||||
};
|
||||
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
|
||||
let attention_q_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let attention_k_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let attention_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let post_attention_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(
|
||||
reader,
|
||||
&format!("{prefix}.post_attention_norm.weight"),
|
||||
device,
|
||||
)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let ffn_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let post_ffn_norm = RmsNorm::from_qtensor(
|
||||
ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
|
||||
let feed_forward_gate =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
let feed_forward_down =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
|
||||
let mlp = Mlp {
|
||||
feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
|
||||
feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
|
||||
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
|
||||
};
|
||||
|
||||
// Sliding window pattern hardcoded to 6 because it's not explicitly defined
|
||||
let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
|
||||
let sliding_window_size = is_sliding.then_some(sliding_window_size);
|
||||
let layer_rope_frequency = if is_sliding {
|
||||
rope_freq_base_sliding
|
||||
} else {
|
||||
rope_freq_base
|
||||
};
|
||||
|
||||
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
|
||||
|
||||
// Tracing spans
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
|
||||
layers.push(LayerWeights {
|
||||
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_q_norm,
|
||||
attention_k_norm,
|
||||
attention_norm,
|
||||
post_attention_norm,
|
||||
ffn_norm,
|
||||
post_ffn_norm,
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: key_length,
|
||||
q_dim,
|
||||
sliding_window_size,
|
||||
rotary_embedding,
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_mlp,
|
||||
})
|
||||
}
|
||||
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
embedding_length,
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = x.dims2()?;
|
||||
let _enter = self.span.enter();
|
||||
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
|
||||
|
||||
for layer in self.layers.iter_mut() {
|
||||
let attention_mask = if seq_len == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
|
||||
};
|
||||
|
||||
// Attention block
|
||||
let residual = &layer_in;
|
||||
let x = layer.attention_norm.forward(&layer_in)?;
|
||||
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
|
||||
let x = layer.post_attention_norm.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
|
||||
// Feed-forward block
|
||||
let _enter = layer.span_mlp.enter();
|
||||
let residual = &x;
|
||||
let x = layer.ffn_norm.forward(&x)?;
|
||||
let x = layer.mlp.forward(&x)?;
|
||||
let x = layer.post_ffn_norm.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
drop(_enter);
|
||||
|
||||
layer_in = x;
|
||||
}
|
||||
|
||||
let _enter = self.span_output.enter();
|
||||
|
||||
let x = layer_in.i((.., seq_len - 1, ..))?;
|
||||
let x = self.norm.forward(&x)?;
|
||||
let output = self.output.forward(&x)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
@ -68,6 +68,7 @@ impl ResnetBlock2D {
|
||||
padding: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
|
||||
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
||||
@ -83,6 +84,7 @@ impl ResnetBlock2D {
|
||||
padding: 0,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
Some(conv2d(
|
||||
in_channels,
|
||||
|
@ -248,12 +248,14 @@ impl AudioEncoder {
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
|
@ -244,12 +244,14 @@ impl AudioEncoder {
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
|
@ -54,3 +54,25 @@ fn sample_with_top_k() -> Result<()> {
|
||||
assert_eq!(token, 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sample_gumbel() -> Result<()> {
|
||||
let mut logits_process = LogitsProcessor::from_sampling(
|
||||
42,
|
||||
candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
|
||||
);
|
||||
let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
|
||||
let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
|
||||
let mut counts = vec![0f64; 4];
|
||||
let samples = 100000;
|
||||
for _ in 0..samples {
|
||||
let token = logits_process.sample(&logits)?;
|
||||
counts[token as usize] += 1f64 / samples as f64;
|
||||
}
|
||||
for i in 0..4 {
|
||||
if (counts[i] - sm[i]).abs() > 0.05 {
|
||||
panic!("pr mismatch {counts:?} {sm:?}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -98,6 +98,7 @@ impl ConvBlock {
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
cudnn_fwd_algo: None,
|
||||
};
|
||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
|
||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||
|
Reference in New Issue
Block a user