Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00

Compare commits: remove_wra...linear-tra (153 commits)
SHA1:
a35a935118
df6667ba88
a79286885c
74845a4dcd
aa76b783eb
25564357f7
634700d84a
e635f18eda
52414ba5c8
186c308d51
4f17290ce0
0902846f25
e2acbe1e72
4fe8a02f88
03a421f714
d38943aadc
51e51da896
6e33ff62d6
4b3bd79fbd
cc76c63202
ff876c2103
a27239f3d9
babee9f011
afb5e24a63
89d1fd03e5
310094310b
836ba3e090
091e781977
5cead227ef
ebd0315623
ad9d8fe400
5bc5716b85
ba37de94d4
6242a1470e
75e0448114
614f911e9e
e1e8127f15
fa98ca0c35
1a07ff8d17
f28558d0b7
6b98b66eb3
9ae1f6afee
1064b9b031
ffeafbfc43
b3ea96b62b
94a43faaca
62a9b03715
67834119fc
0ace420e66
a8d8f9f206
38ff693af0
ba2254556c
c950a5c6b1
16c33383eb
bedcef64dc
40c80bfbb2
07eb899729
c0a8ed19eb
4bf2ebf836
97d8712ba5
97181a77c0
50d8273ae4
7513a5e005
cb8dd5cd53
a0e47aba98
fb84ead8f7
3eb2bc6d07
68eab38de6
54ccf94472
4002968cf5
be256a6ba6
d2dea11ef6
3e89df938c
6a54ca115e
4f260ef025
0b97987b21
8435a99edd
ca479a873e
952eca6b54
f291065f6c
25a2086e8f
7c7e6ba201
1553b58fe5
b7814f66b4
ed58de7551
1735e4831e
209f06d7c3
6475bfadfe
89ba005962
4f92420132
ded197497c
84ad558e50
368f169c6a
8da6568c20
07a22fe606
834e1b197b
89fd988836
1235aa2536
f052ba76cb
46f2d9f0ac
81bfa46702
8b1d12bead
035372248e
2ce5f12513
97990f4afc
1a5416ec35
fa2b64d678
e40b150bbe
471855e2ee
d9f9c859af
c97d51243c
be9c26180c
944d70bd9a
18cc73954a
74a6a769dd
581b104f97
b50f932e7c
160ba09d30
5a26cba733
550a13a547
35b65fed88
b6f7dfb682
fe87778223
23827c49cd
e449ce53a2
b8a10425ad
5f20acf080
c8459d199d
1f26042693
43c7223292
52c5d8c087
6eeea1b04e
27174a82aa
5cc843550d
4a100875bf
a6bcdfb269
b02229ce92
410654525f
c60831aad4
4845d5cc64
fa08fb3126
2a8f28d687
e9c052bf94
dc416243a3
12d6dc018d
c34f932319
536c5e702e
001f9a59ce
9515e8ea6c
ad12e20f6b
e6584476c4
cb687b4897
439321745a
42 .github/workflows/book-cd.yml vendored Normal file

@@ -0,0 +1,42 @@
name: Deploy Rust book
on:
  # TODO put this back only when merging after this PR lands.
  pull_request:
  push:
    branches:
      - main

jobs:
  deploy:
    runs-on: ubuntu-latest
    permissions:
      contents: write # To push a branch
      pull-requests: write  # To create a PR from that branch
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - name: Install latest mdbook
        run: |
          tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
          url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
          mkdir mdbook
          curl -sSL $url | tar -xz --directory=./mdbook
          echo `pwd`/mdbook >> $GITHUB_PATH
      - name: Deploy GitHub Pages
        run: |
          # This assumes your book is in the root of your repository.
          # Just add a `cd` here if you need to change to another directory.
          cd candle-book
          mdbook build
          git worktree add gh-pages
          git config user.name "Deploy from CI"
          git config user.email ""
          cd gh-pages
          # Delete the ref to avoid keeping history.
          git update-ref -d refs/heads/gh-pages
          rm -rf *
          mv ../book/* .
          git add .
          git commit -m "Deploy $GITHUB_SHA to gh-pages"
          git push --force --set-upstream origin gh-pages
29 .github/workflows/book.yml vendored Normal file

@@ -0,0 +1,29 @@
name: CI
on:
  pull_request:

jobs:
  test:
    name: Test candle-book
    runs-on: ubuntu-latest
    permissions:
      contents: write # To push a branch
      pull-requests: write  # To create a PR from that branch
    steps:
      - uses: actions/checkout@master
      - name: Install Rust
        run: |
          rustup set profile minimal
          rustup toolchain install stable
          rustup default stable
      - name: Install latest mdbook
        run: |
          tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
          url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
          mkdir bin
          curl -sSL $url | tar -xz --directory=bin
          echo "$(pwd)/bin" >> $GITHUB_PATH
      - name: Run tests
        run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
8 .gitignore vendored

@@ -1,6 +1,7 @@
# Generated by Cargo
# will have compiled files and executables
debug/
data/
dist/
target/

@@ -23,6 +24,7 @@ flamegraph.svg
*.swp
trace-*.json

candle-wasm-example/*.wav
candle-wasm-example/*.safetensors
candle-wasm-example/package-lock.json
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
3 .gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "candle-examples/examples/flash-attn/cutlass"]
    path = candle-flash-attn/cutlass
    url = https://github.com/NVIDIA/cutlass.git
37 Cargo.toml

@@ -2,46 +2,47 @@
members = [
    "candle-core",
    "candle-examples",
    "candle-hub",
    "candle-nn",
    "candle-pyo3",
    "candle-transformers",
    "candle-wasm-example",
    "candle-wasm-examples/llama2-c",
    "candle-wasm-examples/whisper",
]
exclude = [
    "candle-flash-attn",
    "candle-kernels",
]

[workspace.package]
version = "0.1.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"

[workspace.dependencies]
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
# Re-enable this once 0.9.13 has been released as it would include the cublas-f16 changes
# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
futures = "0.3.28"
# TODO: Switch back to the official gemm implementation once the following are available.
# https://github.com/sarah-ek/gemm/pull/8.
# https://github.com/sarah-ek/gemm/pull/9.
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vec-plus-wasm-simd" }
half = { version = "2.3.1", features = ["num-traits"] }
indicatif = "0.17.5"
intel-mkl-src = { version = "0.8.1", features = ["mkl-dynamic-lp64-iomp"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.5", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
reqwest = "0.11.18"
safetensors = "0.3.1"
serde = { version = "1.0.166", features = ["derive"] }
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
sha256 = "=1.1.4"
thiserror = "1"
tokenizers = { version = "0.13.3", default-features = false }
tokio = "1.28.2"
tokio-test = "0.4.2"
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
74 README.md

@@ -1,59 +1,95 @@
# candle
ML framework for Rust
[](https://crates.io/crates/candle-core)
[](https://docs.rs/candle-core)

Candle is a minimalist ML framework for Rust with a focus on ease of use and
on performance (including GPU support). Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).

```rust
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;

let c = a.matmul(&b)?;
println!("{c}");
```

## Check out our examples

Check out our [examples](./candle-examples/examples/):

- [Whisper](./candle-examples/examples/whisper/)
- [Llama](./candle-examples/examples/llama/)
- [Bert](./candle-examples/examples/bert/) (Useful for sentence embeddings)
- [Falcon](./candle-examples/examples/falcon/)
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.

Run them using the following commands:
```
cargo run --example bert --release
cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
```

In order to use **CUDA** add `--features cuda` to the example command line.

There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).

For llama2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then browse to
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).

<!--- ANCHOR: features --->

## Features

- Simple syntax (looks and feels like PyTorch)
- CPU and Cuda backends, m1, f16, bf16 (and tentatively wasm)
- Simple syntax, looks and feels like PyTorch.
- CPU and Cuda backends, m1, f16, bf16.
- Enable serverless (CPU), small and fast deployments
- Model training
- Distributed computing (NCCL).
- Models out of the box (Llama, Whisper, Falcon, ...)
- Emphasis on enabling users to use custom ops/kernels
- WASM support, run your models in a browser.
- Model training.
- Distributed computing using NCCL.
- Models out of the box: Llama, Whisper, Falcon, StarCoder...
- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).

<!--- ANCHOR_END: features --->

## How to use?

<!--- ANCHOR: cheatsheet --->
Cheatsheet:

| | Using PyTorch | Using Candle |
|------------|------------------------------------------|------------------------------------------------------------------|
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(` |
| | | `&[[1f32, 2.], [3., 4.]],` |
| | | `&Device::Cpu)?` |
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
| Arithmetic | `a + b` | `&a + &b` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
| Saving | `torch.save({"A": A}, "model.bin")` | `tensor.save_safetensors("A", "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | TODO (see the examples for now) |
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |

<!--- ANCHOR_END: cheatsheet --->
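As a quick illustration of a few cheatsheet rows, here is a hedged sketch; it assumes `candle-core` exposes the `i` indexing method through an `IndexOp` trait, as the `tensor.i((.., ..4))?` row above suggests:

```rust
use candle_core::{Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[[1f32, 2., 3., 4.], [5., 6., 7., 8.]], &Device::Cpu)?;
    // PyTorch `tensor[:, :2]` becomes an `i` call on a tuple of ranges.
    let left = t.i((.., ..2))?;
    // PyTorch `tensor.view((4, 2))` becomes `reshape`.
    let reshaped = t.reshape((4, 2))?;
    println!("{left}\n{reshaped}");
    Ok(())
}
```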
## Structure
1 candle-book/.gitignore vendored Normal file

@@ -0,0 +1 @@
book

6 candle-book/book.toml Normal file

@@ -0,0 +1,6 @@
[book]
authors = ["Nicolas Patry"]
language = "en"
multilingual = false
src = "src"
title = "Candle Documentation"
6 candle-book/src/README.md Normal file

@@ -0,0 +1,6 @@
# Introduction

{{#include ../../README.md:features}}

This book will introduce, step by step, how to use `candle`.
27 candle-book/src/SUMMARY.md Normal file

@@ -0,0 +1,27 @@
# Summary

[Introduction](README.md)

# User Guide

- [Installation](guide/installation.md)
- [Hello World - MNIST](guide/hello_world.md)
- [PyTorch cheatsheet](guide/cheatsheet.md)

# Reference Guide

- [Running a model](inference/README.md)
  - [Using the hub](inference/hub.md)
  - [Serialization](inference/serialization.md)
  - [Advanced Cuda usage](inference/cuda/README.md)
    - [Writing a custom kernel](inference/cuda/writing.md)
    - [Porting a custom kernel](inference/cuda/porting.md)
- [Error management](error_manage.md)
- [Creating apps](apps/README.md)
  - [Creating a WASM app](apps/wasm.md)
  - [Creating a REST api webserver](apps/rest.md)
  - [Creating a desktop Tauri app](apps/dekstop.md)
- [Training](training/README.md)
  - [MNIST](training/mnist.md)
  - [Fine-tuning](training/finetuning.md)
- [Using MKL](advanced/mkl.md)
1 candle-book/src/advanced/mkl.md Normal file

@@ -0,0 +1 @@
# Using MKL

1 candle-book/src/apps/README.md Normal file

@@ -0,0 +1 @@
# Creating apps

1 candle-book/src/apps/dekstop.md Normal file

@@ -0,0 +1 @@
# Creating a desktop Tauri app

1 candle-book/src/apps/rest.md Normal file

@@ -0,0 +1 @@
# Creating a REST api webserver

1 candle-book/src/apps/wasm.md Normal file

@@ -0,0 +1 @@
# Creating a WASM app

1 candle-book/src/chapter_1.md Normal file

@@ -0,0 +1 @@
# Chapter 1

1 candle-book/src/error_manage.md Normal file

@@ -0,0 +1 @@
# Error management

3 candle-book/src/guide/cheatsheet.md Normal file

@@ -0,0 +1,3 @@
# Pytorch cheatsheet

{{#include ../../../README.md:cheatsheet}}
195 candle-book/src/guide/hello_world.md Normal file

@@ -0,0 +1,195 @@
# Hello world!

We will now create the hello world of the ML world: a model capable of solving the MNIST dataset.

Open `src/main.rs` and fill in this content:

```rust
# extern crate candle_core;
use candle_core::{DType, Device, Result, Tensor};

struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?;
        let x = x.relu()?;
        x.matmul(&self.second)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    let device = Device::Cpu;

    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
    let model = Model { first, second };

    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Everything should now run with:

```bash
cargo run --release
```

## Using a `Linear` layer.

Now that we have this, we might want to make things a bit more complex, for instance by adding a `bias` and creating the classical `Linear` layer. We can do that as follows:

```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
struct Linear {
    weight: Tensor,
    bias: Tensor,
}
impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?;
        x.broadcast_add(&self.bias)
    }
}

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}
```

This changes the model-running code into a new function:

```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
# struct Linear {
#     weight: Tensor,
#     bias: Tensor,
# }
# impl Linear {
#     fn forward(&self, x: &Tensor) -> Result<Tensor> {
#         let x = x.matmul(&self.weight)?;
#         x.broadcast_add(&self.bias)
#     }
# }
#
# struct Model {
#     first: Linear,
#     second: Linear,
# }
#
# impl Model {
#     fn forward(&self, image: &Tensor) -> Result<Tensor> {
#         let x = self.first.forward(image)?;
#         let x = x.relu()?;
#         self.second.forward(&x)
#     }
# }
fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    // Use Device::Cpu; to use the CPU.
    let device = Device::cuda_if_available(0)?;

    // Creating a dummy model
    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    let first = Linear { weight, bias };
    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
    let bias = Tensor::zeros((10,), DType::F32, &device)?;
    let second = Linear { weight, bias };
    let model = Model { first, second };

    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    // Inference on the model
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Now that this works, it is a great way to create your own layers.
But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).

## Using `candle_nn`.

For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
This `Linear` is coded with the PyTorch layout in mind, to better reuse existing models out there, so it uses the transpose of the weights and not the weights directly.

So instead we can simplify our example:

```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
```

And rewrite our example using it:

```rust
# extern crate candle_core;
# extern crate candle_nn;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::Linear;

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    let device = Device::Cpu;

    // This has changed (784, 100) -> (100, 784) !
    let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    let first = Linear::new(weight, Some(bias));
    let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
    let bias = Tensor::zeros((10,), DType::F32, &device)?;
    let second = Linear::new(weight, Some(bias));
    let model = Model { first, second };

    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
```

Feel free to modify this example to use `Conv2d` to create a classical convnet instead.

Now that we have the running dummy code we can get to more advanced topics:

- [For PyTorch users](./guide/cheatsheet.md)
- [Running existing models](./inference/README.md)
- [Training models](./training/README.md)
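As a hedged aside on the hello world model: `forward` returns raw logits, and the predicted digit is their argmax. A minimal sketch, assuming a `Tensor::argmax` method backed by the `ReduceOp::ArgMax` reduction that appears later in this diff:

```rust
use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for the (1, 10) logits produced by Model::forward.
    let logits = Tensor::zeros((1, 10), DType::F32, &device)?;
    // Index of the largest logit along the last dimension.
    let prediction = logits.argmax(D::Minus1)?;
    println!("predicted digit: {prediction}");
    Ok(())
}
```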
24 candle-book/src/guide/installation.md Normal file

@@ -0,0 +1,24 @@
# Installation

Start by creating a new app:

```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/LaurentMazare/candle.git candle
```

At this point, candle will be built **without** CUDA support.
To get CUDA support use the `cuda` feature:
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
```

You can check everything works properly:

```bash
cargo build
```

There is also an `mkl` feature, which can be interesting for faster inference on CPU: [Using mkl](./advanced/mkl.md)
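A minimal sanity-check sketch (assuming the crate is added as `candle` as above; `Device::cuda_if_available` is used the same way in the hello world chapter and falls back to the CPU when no GPU is present):

```rust
use candle::{Device, Result};

fn main() -> Result<()> {
    // Selects GPU 0 when built with the cuda feature and a device exists.
    let device = Device::cuda_if_available(0)?;
    println!("using device: {device:?}");
    Ok(())
}
```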
1 candle-book/src/inference/README.md Normal file

@@ -0,0 +1 @@
# Running a model

1 candle-book/src/inference/cuda/README.md Normal file

@@ -0,0 +1 @@
# Advanced Cuda usage

1 candle-book/src/inference/cuda/porting.md Normal file

@@ -0,0 +1 @@
# Porting a custom kernel

1 candle-book/src/inference/cuda/writing.md Normal file

@@ -0,0 +1 @@
# Writing a custom kernel

1 candle-book/src/inference/hub.md Normal file

@@ -0,0 +1 @@
# Using the hub

1 candle-book/src/inference/serialization.md Normal file

@@ -0,0 +1 @@
# Serialization

1 candle-book/src/training/README.md Normal file

@@ -0,0 +1 @@
# Training

1 candle-book/src/training/finetuning.md Normal file

@@ -0,0 +1 @@
# Fine-tuning

1 candle-book/src/training/mnist.md Normal file

@@ -0,0 +1 @@
# MNIST
@@ -1,18 +1,17 @@
[package]
name = "candle"
version = "0.1.0"
edition = "2021"

description = "Minimalist ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
name = "candle-core"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@@ -2,9 +2,14 @@
extern crate intel_mkl_src;

use anyhow::Result;
use candle::{Device, Tensor};
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
    let c = a.matmul(&b)?;
    println!("{a} {b} {c}");

    let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
    let t1 = Tensor::new(data, &Device::Cpu)?;
    let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
@@ -1,81 +0,0 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

// use candle::quantized::GgmlType;
use candle::{DType, Device, Result, Tensor};
// use clap::{Parser, Subcommand};

// fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
//     let dim = dim.to_index(xs.shape(), "softmax")?;
//     let max = xs.max_keepdim(dim)?;
//     let diff = xs.broadcast_sub(&max)?;
//     let num = diff.exp()?;
//     let den = num.sum_keepdim(dim)?;
//     num.broadcast_div(&den)
// }

trait Benchmark {
    type PreProcessData;
    type RunResult;

    fn preprocess() -> Result<Self::PreProcessData>;
    fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;

    const ITERS: usize;
}

struct Matmul;
impl Benchmark for Matmul {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        let lhs = Tensor::randn((1024, 1024), DType::F32, &Device::Cpu, 1.0, 0.0)?;
        let rhs = Tensor::randn((1024, 1024), DType::F32, &Device::Cpu, 1.0, 0.0)?;
        Ok((lhs, rhs))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        d.0.matmul(&d.1)
    }

    const ITERS: usize = 100;
}

// struct Softmax;
// impl Benchmark for Softmax {
//     type PreProcessData = Tensor;
//     type RunResult = Tensor;
//     fn preprocess() -> Result<Self::PreProcessData> {
//         // Typical whisper tiny size.
//         let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
//         Ok(x)
//     }
//
//     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
//         softmax(d, D::Minus1)
//     }
//
//     const ITERS: usize = 100;
// }

fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
    use std::hint::black_box;

    let iters = iters.unwrap_or(B::ITERS);
    let d = B::preprocess()?;
    let start = std::time::Instant::now();
    for _iter in 0..iters {
        let _res = black_box(B::run_one(black_box(&d))?);
    }
    println!("{:?}", start.elapsed() / iters as u32);
    Ok(())
}

fn main() -> Result<()> {
    run::<Matmul>(None)?;
    Ok(())
}
@@ -2,7 +2,7 @@
extern crate intel_mkl_src;

use anyhow::Result;
use candle::{Device, Tensor};
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
use std::str::FromStr;

use anyhow::Result;
use candle::{Device, Tensor};
use candle_core::{Device, Tensor};

fn cos_sin(n: usize, device: &Device) -> Result<Tensor> {
    let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect();
@@ -1,6 +1,7 @@
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};

pub(crate) trait BackendStorage: Sized {
pub trait BackendStorage: Sized {
    type Device: BackendDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self>;

@@ -16,16 +17,15 @@ pub(crate) trait BackendStorage: Sized {

    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;

    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;

    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;

    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;

    fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;

    fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
        -> Result<Self>;
    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;

    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;

@@ -37,7 +37,26 @@ pub(crate) trait BackendStorage: Sized {
        _params: &crate::conv::ParamsConv1D,
    ) -> Result<Self>;

    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self>;
    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self>;

    fn matmul(
        &self,

@@ -50,7 +69,7 @@ pub(crate) trait BackendStorage: Sized {
    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
}

pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone {
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
    type Storage: BackendStorage;

    // TODO: Make the usize generic and part of a generic DeviceLocation.
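On the user-facing side, these new trait methods back tensor indexing ops. A hedged sketch of `index_select`, the op whose gradient is wired up in the backprop changes below:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
    // u32 ids, matching the "ids should be u8 or u32" checks in the CUDA backend.
    let ids = Tensor::new(&[0u32, 2u32], &dev)?;
    // Select rows 0 and 2 along dimension 0.
    let rows = t.index_select(&ids, 0)?;
    println!("{rows}");
    Ok(())
}
```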
@@ -1,6 +1,20 @@
use crate::{op::Op, Error, Result, Tensor, TensorId};
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

// arg has been reduced to node via reduce_dims, expand it back to arg.
// This has to handle keepdims.
fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
    if arg.rank() == node.rank() {
        // keepdim = true
        node.broadcast_as(arg.shape())
    } else {
        // keepdim = false
        // first expand the reduced dims.
        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
    }
}

impl Tensor {
    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
    /// elements having dependencies on the latter ones, e.g. the first element if any is the

@@ -24,7 +38,10 @@ impl Tensor {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    Op::WhereCond(t1, t2, t3) => {
                    Op::IndexAdd(t1, t2, t3, _)
                    | Op::ScatterAdd(t1, t2, t3, _)
                    | Op::CustomOp3(t1, t2, t3, _)
                    | Op::WhereCond(t1, t2, t3) => {
                        let (tg, nodes) = walk(t1, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t2, nodes, already_seen);

@@ -38,11 +55,10 @@ impl Tensor {
                        kernel: rhs,
                        ..
                    }
                    | Op::Add(lhs, rhs)
                    | Op::Mul(lhs, rhs)
                    | Op::Sub(lhs, rhs)
                    | Op::Div(lhs, rhs)
                    | Op::Embedding(lhs, rhs)
                    | Op::CustomOp2(lhs, rhs, _)
                    | Op::Binary(lhs, rhs, _)
                    | Op::Gather(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;

@@ -65,24 +81,17 @@ impl Tensor {
                        }
                    }
                    Op::Reshape(node)
                    | Op::Copy(node)
                    | Op::Broadcast(node)
                    | Op::Sum(node, _)
                    | Op::Cmp(node, _)
                    | Op::Reduce(node, _, _)
                    | Op::ToDType(node)
                    | Op::ToDevice(node)
                    | Op::Transpose(node, _, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Softmax(node, _)
                    | Op::Sqr(node)
                    | Op::Sqrt(node)
                    | Op::Gelu(node)
                    | Op::Relu(node)
                    | Op::Unary(node, _)
                    | Op::Elu(node, _)
                    | Op::Exp(node)
                    | Op::Log(node)
                    | Op::Sin(node)
                    | Op::Cos(node)
                    | Op::Abs(node)
                    | Op::Neg(node) => {
                    | Op::CustomOp1(node, _) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes

@@ -116,19 +125,19 @@ impl Tensor {
            // this is out of scope.
            if let Some(op) = node.op() {
                match op {
                    Op::Add(lhs, rhs) => {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                    }
                    Op::Sub(lhs, rhs) => {
                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                    }
                    Op::Mul(lhs, rhs) => {
                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                        let lhs_grad = grad.mul(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

@@ -136,13 +145,13 @@ impl Tensor {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Div(lhs, rhs) => {
                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
                        let lhs_grad = grad.div(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                    }
                    Op::WhereCond(pred, t, f) => {
                        let zeros = grad.zeros_like()?;

@@ -153,9 +162,30 @@ impl Tensor {
                        let f_grad = pred.where_cond(&zeros, &grad)?;
                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                    }
                    Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
                    Op::Embedding(_lhs, _rhs) => {
                        return Err(Error::BackwardNotSupported { op: "embedding" })
                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                    Op::Gather(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                    }
                    Op::ScatterAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.gather(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.index_select(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexSelect(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                    }
                    Op::Matmul(lhs, rhs) => {
                        // Skipping checks, the op went ok, we can skip

@@ -195,41 +225,69 @@ impl Tensor {
                            }
                        }

                        let arg_grad = grad.sum(sum_dims.as_slice())?;
                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
                        for _i in 0..left_dims {
                            arg_grad = arg_grad.squeeze(0)?
                        }
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.broadcast_add(&arg_grad)?
                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Sum(arg, _sum_dims) => {
                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.broadcast_add(&grad)?
                        *sum_grad = sum_grad.add(&grad)?;
                    }
                    Op::Cmp(_args, _) => {}
                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::ToDType(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                    }
                    Op::Copy(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?
                    }
                    Op::Affine { arg, mul, .. } => {
                        let arg_grad = grad.affine(*mul, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Log(arg) => {
                    Op::Unary(arg, UnaryOp::Log) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                        *sum_grad = sum_grad.add(&(grad / arg)?)?
                    }
                    Op::Sin(arg) => {
                    Op::Unary(arg, UnaryOp::Sin) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                    }
                    Op::Cos(arg) => {
                    Op::Unary(arg, UnaryOp::Cos) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                    }
                    Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
                    Op::Exp(arg) => {
                    Op::Unary(arg, UnaryOp::Abs) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad / arg)?)?
                        let ones = arg.ones_like()?;
                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
                    }
                    Op::Neg(arg) => {
                    Op::Unary(arg, UnaryOp::Exp) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                    }
                    Op::Unary(arg, UnaryOp::Neg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }

@@ -259,24 +317,60 @@ impl Tensor {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Softmax(_arg, _) => {
                        return Err(Error::BackwardNotSupported { op: "softmax" })
                    }
                    Op::Reduce(_, ReduceOp::ArgMin, _) => {}
                    Op::Reduce(_, ReduceOp::ArgMax, _) => {}
                    Op::Reshape(arg) => {
                        let arg_grad = grad.reshape(arg.dims())?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                    Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
                    Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
                    Op::Sqr(arg) => {
                    Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
                    Op::Unary(arg, UnaryOp::Relu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                    }
                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                    Op::CustomOp1(arg, c) => {
                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?
                        }
                    }
                    Op::CustomOp2(arg1, arg2, c) => {
                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                    }
                    Op::CustomOp3(arg1, arg2, arg3, c) => {
                        let (arg_grad1, arg_grad2, arg_grad3) =
                            c.bwd(arg1, arg2, arg3, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                        if let Some(arg_grad3) = arg_grad3 {
                            let sum_grad = grads.or_insert(arg3)?;
                            *sum_grad = sum_grad.add(&arg_grad3)?
                        }
                    }
                    Op::Unary(arg, UnaryOp::Sqr) => {
                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Sqrt(arg) => {
                        let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
                    Op::Unary(arg, UnaryOp::Sqrt) => {
                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
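A one-line check of the `Sqrt` rewrite above: the old code divided the incoming gradient by `arg`, but since `node` holds the forward output, the correct rule divides by the square root itself,

$$\frac{\partial}{\partial x}\sqrt{x} = \frac{1}{2\sqrt{x}},$$

i.e. `arg_grad = grad / node * 0.5`, which is exactly `grad.div(node)?.affine(0.5, 0.)?`.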
File diff suppressed because it is too large
@ -1,9 +1,12 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{
|
||||
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
|
||||
ValidAsZeroBits,
|
||||
};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -32,9 +35,6 @@ pub enum CudaError {
|
||||
#[error("internal error '{0}'")]
|
||||
InternalError(&'static str),
|
||||
|
||||
#[error("internal error '{0}'")]
|
||||
WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
@ -100,7 +100,7 @@ impl std::ops::Deref for CudaDevice {
|
||||
}
|
||||
}
|
||||
|
||||
trait WrapErr<O> {
|
||||
pub trait WrapErr<O> {
|
||||
fn w(self) -> std::result::Result<O, crate::Error>;
|
||||
}
|
||||
|
||||
@ -170,7 +170,7 @@ impl CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||
if !self.has_func(module_name, module_name) {
|
||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||
// done once per kernel name.
|
||||
@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
|
||||
}
|
||||
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
@ -398,6 +402,82 @@ trait Map2 {
|
||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<()>;
|
||||
|
||||
fn map(
|
||||
&self,
|
||||
dst: &mut S,
|
||||
dst_s: &Shape,
|
||||
src: &S,
|
||||
src_l: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait Map1Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
wrap: W,
|
||||
) -> Result<S>;
|
||||
|
||||
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||
let out = match s {
|
||||
S::U8(s) => self.f(s, d, l, S::U8)?,
|
||||
S::U32(s) => self.f(s, d, l, S::U32)?,
|
||||
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
||||
S::F16(s) => self.f(s, d, l, S::F16)?,
|
||||
S::F32(s) => self.f(s, d, l, S::F32)?,
|
||||
S::F64(s) => self.f(s, d, l, S::F64)?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<S>;
|
||||
|
||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
||||
};
|
||||
Ok(out)
|
||||
@ -515,14 +595,15 @@ impl<'a> Map1 for Sum<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct FastSum<'a>(&'a [usize]);
|
||||
impl<'a> Map1 for FastSum<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
struct FastReduce<'a>(&'a [usize], ReduceOp);
|
||||
impl<'a> Map1Any for FastReduce<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
wrap: W,
|
||||
) -> Result<S> {
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
let src_el: usize = src_dims.iter().product();
|
||||
@ -557,16 +638,36 @@ impl<'a> Map1 for FastSum<'a> {
|
||||
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
|
||||
.w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<T>(dst_el).w()?;
|
||||
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
let (name, check_empty, return_index) = match self.1 {
|
||||
ReduceOp::Sum => ("fast_sum", false, false),
|
||||
ReduceOp::Min => ("fast_min", true, false),
|
||||
ReduceOp::Max => ("fast_max", true, false),
|
||||
ReduceOp::ArgMin => ("fast_argmin", true, true),
|
||||
ReduceOp::ArgMax => ("fast_argmax", true, true),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
|
||||
if return_index {
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
|
||||
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(S::U32(out))
|
||||
} else {
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(wrap(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<U: crate::op::UnaryOp> Map1 for U {
|
||||
impl<U: UnaryOpT> Map1 for U {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -589,46 +690,200 @@ impl<U: crate::op::UnaryOp> Map1 for U {
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
|
||||
impl<'a> Map1 for Embedding<'a> {
|
||||
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||
impl<'a> Map1 for IndexSelect<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
rhs: &CudaSlice<T>,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
rhs_l: &Layout,
|
||||
src_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let ids_l = &self.1;
|
||||
let ids = match &self.0.slice {
|
||||
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
|
||||
let (name, ids) = match &self.0.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "embedding ids should be u32",
|
||||
msg: "index_select ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
.w()?,
|
||||
};
|
||||
let ids = &ids;
|
||||
let shape = ids_l.shape();
|
||||
let (v_size, h_size) = rhs_l
|
||||
.shape()
|
||||
.r2()
|
||||
.map_err(|e| CudaError::WrappedError(Box::new(e)))
|
||||
.w()?;
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
|
||||
let ids_shape = ids_l.shape();
|
||||
let ids_dims = ids_shape.dims();
|
||||
let ids_el = ids_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
|
||||
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
|
||||
};
|
||||
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
||||
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
||||
let dim_size = src_l.dims()[self.2];
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
|
||||
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
||||
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
|
||||
let params = (
|
||||
ids_el,
|
||||
ids_dims.len(),
|
||||
&ds,
|
||||
ids,
|
||||
&src,
|
||||
&out,
|
||||
left_size,
|
||||
dim_size,
|
||||
right_size,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for Gather<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        src_l: &Layout,
    ) -> Result<CudaSlice<T>> {
        let ids = &self.0;
        let ids_l = &self.1;
        let dim = self.2;
        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
            Some(o12) => o12,
            None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let (name, ids) = match &ids.slice {
            CudaStorageSlice::U32(slice) => {
                ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
            }
            CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
            _ => Err(CudaError::UnexpectedDType {
                msg: "gather ids should be u8 or u32",
                expected: DType::U32,
                got: ids.dtype(),
            })?,
        };
        let el = ids_l.shape().elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let src = match src_l.contiguous_offsets() {
            Some((o1, o2)) => src.slice(o1..o2),
            None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let left_sz: usize = src_l.dims()[..dim].iter().product();
        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
        let src_dim_sz = src_l.dims()[dim];
        let ids_dim_sz = ids_l.dims()[dim];
        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (
            el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}
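For reference, the computation this kernel performs, written out as a plain CPU loop over the same (left, dim, right) flattening; all names here are illustrative rather than taken from the backend:

// Sketch: reference semantics of gather on contiguous data.
// out[l, i, r] = src[l, ids[l, i, r], r]
fn gather_ref(src: &[f32], ids: &[usize], left: usize, src_dim: usize, ids_dim: usize, right: usize) -> Vec<f32> {
    let mut out = vec![0f32; left * ids_dim * right];
    for l in 0..left {
        for i in 0..ids_dim {
            for r in 0..right {
                let pos = (l * ids_dim + i) * right + r;
                let idx = ids[pos];
                out[pos] = src[(l * src_dim + idx) * right + r];
            }
        }
    }
    out
}

fn main() {
    // 1-d gather: left = right = 1.
    let out = gather_ref(&[10., 20., 30.], &[2, 0, 0], 1, 3, 3, 1);
    assert_eq!(out, vec![30., 10., 10.]);
}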
struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for IndexAdd<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()> {
        let ids = &self.0;
        let ids_l = &self.1;
        let dim = self.2;
        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
            Some(o12) => o12,
            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
        };
        let (name, ids) = match &ids.slice {
            CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
            _ => Err(CudaError::UnexpectedDType {
                msg: "index-add ids should be u8 or u32",
                expected: DType::U32,
                got: ids.dtype(),
            })?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((o1, o2)) => src.slice(o1..o2),
            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
        };
        let left_sz: usize = src_l.dims()[..dim].iter().product();
        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
        let src_dim_sz = src_l.dims()[dim];
        let dst_dim_sz = dst_shape.dims()[dim];
        let ids_dim_sz = ids_l.dims()[0];
        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
        // SAFETY: Set later by running the kernel.
        let params = (
            ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(())
    }
}
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for ScatterAdd<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()> {
        let ids = &self.0;
        let ids_l = &self.1;
        let dim = self.2;
        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
            Some(o12) => o12,
            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
        };
        let (name, ids) = match &ids.slice {
            CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
            CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
            _ => Err(CudaError::UnexpectedDType {
                msg: "scatter-add ids should be u8 or u32",
                expected: DType::U32,
                got: ids.dtype(),
            })?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((o1, o2)) => src.slice(o1..o2),
            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
        };
        let left_sz: usize = src_l.dims()[..dim].iter().product();
        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
        let src_dim_sz = src_l.dims()[dim];
        let dst_dim_sz = dst_shape.dims()[dim];
        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
        // SAFETY: Set later by running the kernel.
        let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(())
    }
}
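Scatter-add is the accumulation dual of gather: each source element is added into the destination slot named by its id. A self-contained CPU sketch of what the kernel computes, with illustrative names:

// Sketch: scatter-add over the flattened (left, dim, right) view.
// dst has dst_dim slots along `dim`; src and ids share src_dim slots.
fn scatter_add_ref(
    dst: &mut [f32], src: &[f32], ids: &[usize],
    left: usize, src_dim: usize, dst_dim: usize, right: usize,
) {
    for l in 0..left {
        for i in 0..src_dim {
            for r in 0..right {
                let j = ids[(l * src_dim + i) * right + r];
                dst[(l * dst_dim + j) * right + r] += src[(l * src_dim + i) * right + r];
            }
        }
    }
}

fn main() {
    let mut dst = vec![0f32; 3];
    scatter_add_ref(&mut dst, &[1., 2., 3.], &[0, 0, 2], 1, 3, 3, 1);
    // Ids 0 and 0 accumulate into slot 0, id 2 into slot 2.
    assert_eq!(dst, vec![3., 0., 3.]);
}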
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -680,16 +935,22 @@ impl<'a> Map2 for WhereCond<'a> {
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        let ids_l = &self.1;
-       let ids = match &self.0.slice {
-           CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+       let (ids, name) = match &self.0.slice {
+           CudaStorageSlice::U8(slice) => {
+               let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+               (ptr, "where_u8")
+           }
+           CudaStorageSlice::U32(slice) => {
+               let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+               (ptr, "where_u32")
+           }
            _ => Err(CudaError::UnexpectedDType {
-               msg: "where conditions should be u32",
+               msg: "where conditions should be u8 or u32",
                expected: DType::U32,
                got: self.0.dtype(),
            })
            .w()?,
        };
        let ids = &ids;
        let shape = ids_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
@@ -699,7 +960,7 @@ impl<'a> Map2 for WhereCond<'a> {
        .w()?;
        let t = &t.slice(layout_t.start_offset()..);
        let f = &f.slice(layout_f.start_offset()..);
-       let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+       let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (el, dims.len(), &ds, ids, t, f, &out);
@@ -709,7 +970,7 @@ impl<'a> Map2 for WhereCond<'a> {
    }
}

-impl<U: crate::op::BinaryOp> Map2 for U {
+impl<U: crate::op::BinaryOpT> Map2 for U {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        lhs: &CudaSlice<T>,
@@ -737,6 +998,43 @@ impl<U: crate::op::BinaryOp> Map2 for U {
    }
}

struct Cmp(CmpOp);
impl Map2Any for Cmp {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        lhs: &CudaSlice<T>,
        lhs_l: &Layout,
        rhs: &CudaSlice<T>,
        rhs_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<S> {
        let shape = lhs_l.shape();
        let dims = shape.dims();
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let dims_and_strides = dev
            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
            .w()?;
        let lhs = &lhs.slice(lhs_l.start_offset()..);
        let rhs = &rhs.slice(rhs_l.start_offset()..);
        let name = match self.0 {
            CmpOp::Eq => "eq",
            CmpOp::Ne => "ne",
            CmpOp::Lt => "lt",
            CmpOp::Le => "le",
            CmpOp::Gt => "gt",
            CmpOp::Ge => "ge",
        };
        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::BINARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
        let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
        // SAFETY: ffi
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(S::U8(out))
    }
}

fn slice_src_and_dst<'a, T>(
    src: &'a CudaSlice<T>,
    src_l: &Layout,
@@ -762,6 +1060,50 @@ pub struct CudaStorage {
    device: CudaDevice,
}

pub trait CudaDType: Sized {
    fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
    fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
}

macro_rules! cuda_dtype {
    ($ty:ty, $dtype:ident) => {
        impl CudaDType for $ty {
            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
                match &s.slice {
                    CudaStorageSlice::$dtype(data) => Ok(&data),
                    _ => Err(crate::Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: s.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt()),
                }
            }

            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
                let slice = CudaStorageSlice::$dtype(slice);
                CudaStorage { slice, device }
            }
        }
    };
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
cuda_dtype!(f64, F64);

impl CudaStorage {
    pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
        T::wrap_cuda_slice(slice, device)
    }

    pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
        T::as_cuda_slice(self)
    }
}
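The CudaDType trait plus the cuda_dtype! macro let generic code move between the typed CudaSlice<T> world and the dtype-erased CudaStorage enum without a hand-written match per call site. The same dispatch pattern, reduced to a self-contained toy with illustrative names so it runs without a GPU:

// Sketch: the typed-accessor pattern cuda_dtype! generates, on a toy storage enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DType { U8, F32 }

enum StorageSlice { U8(Vec<u8>), F32(Vec<f32>) }

trait ToyDType: Sized {
    const DTYPE: DType;
    fn as_slice(s: &StorageSlice) -> Option<&Vec<Self>>;
}

macro_rules! toy_dtype {
    ($ty:ty, $dtype:ident) => {
        impl ToyDType for $ty {
            const DTYPE: DType = DType::$dtype;
            fn as_slice(s: &StorageSlice) -> Option<&Vec<Self>> {
                match s {
                    StorageSlice::$dtype(data) => Some(data),
                    _ => None,
                }
            }
        }
    };
}
toy_dtype!(u8, U8);
toy_dtype!(f32, F32);

fn main() {
    let s = StorageSlice::F32(vec![1.0, 2.0]);
    // Generic code recovers the typed slice without matching on every variant.
    assert!(f32::as_slice(&s).is_some());
    assert!(u8::as_slice(&s).is_none());
}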
fn gemm_config<T>(
    alpha: T,
    beta: T,
@@ -788,8 +1130,7 @@ fn gemm_config<T>(
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-       })
-       .w()?
+       })?
    };
    // The b tensor has dims batching, m, k (lhs)
    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -801,8 +1142,7 @@ fn gemm_config<T>(
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-       })
-       .w()?
+       })?
    };
    // The setup below was copied from:
    // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
@@ -827,8 +1167,7 @@ fn gemm_config<T>(
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-       })
-       .w()?,
+       })?,
    };
    let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
        [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
@@ -838,8 +1177,7 @@ fn gemm_config<T>(
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
-       })
-       .w()?,
+       })?,
    };

    Ok(StridedBatchedConfig {
@@ -876,7 +1214,6 @@ impl BackendStorage for CudaStorage {
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
-       use cudarc::driver::DevicePtr;
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
@@ -955,23 +1292,25 @@ impl BackendStorage for CudaStorage {
        Ok(Self { slice, device })
    }

-   fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+   fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        let device = self.device().clone();
-       let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
+       let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

-   fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
-       Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
+   fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+       let device = self.device().clone();
+       let slice = Cmp(op).map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
+       Ok(Self { slice, device })
    }

-   fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+   fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        let device = self.device().clone();
        let slice = U::V.map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

-   fn binary_impl<B: crate::op::BinaryOp>(
+   fn binary_impl<B: BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
@@ -1042,11 +1381,46 @@ impl BackendStorage for CudaStorage {
        Ok(Self { slice, device })
    }

-   fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+   fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        let device = self.device().clone();
-       let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
+       let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
        Ok(Self { slice, device })
    }
+   fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+       let device = self.device().clone();
+       let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
+       Ok(Self { slice, device })
+   }
+   fn scatter_add(
+       &self,
+       l: &Layout,
+       ids: &Self,
+       ids_l: &Layout,
+       src: &Self,
+       src_l: &Layout,
+       dim: usize,
+   ) -> Result<Self> {
+       let device = self.device().clone();
+       let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+       self.copy_strided_src(&mut acc, 0, l)?;
+       ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+       Ok(acc)
+   }
+   fn index_add(
+       &self,
+       l: &Layout,
+       ids: &Self,
+       ids_l: &Layout,
+       src: &Self,
+       src_l: &Layout,
+       dim: usize,
+   ) -> Result<Self> {
+       let device = self.device().clone();
+       let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+       self.copy_strided_src(&mut acc, 0, l)?;
+       IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+       Ok(acc)
+   }

    fn matmul(
        &self,
@@ -1110,7 +1484,7 @@ impl BackendStorage for CudaStorage {
            .w()?;
            CudaStorageSlice::F64(out)
        }
-       _ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
+       _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?,
    };
    let device = dev.clone();
    Ok(Self { slice, device })
@@ -1198,8 +1572,7 @@ impl BackendStorage for CudaStorage {
        }
        _ => Err(CudaError::InternalError(
            "dtype mismatch in copy_strided op",
-       ))
-       .w()?,
+       ))?,
    }
    Ok(())
}
@@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
    }

    fn to_cpu_storage(&self) -> CpuStorage {
-       let mut vec = Vec::new();
-       vec.reserve(N1 * N2 * N3);
+       let mut vec = Vec::with_capacity(N1 * N2 * N3);
        for i1 in 0..N1 {
            for i2 in 0..N2 {
                vec.extend(self[i1][i2])
@@ -117,12 +116,12 @@ impl Device {
        }
    }

-   pub(crate) fn rand_uniform(
+   pub(crate) fn rand_uniform_f64(
        &self,
-       shape: &Shape,
-       dtype: DType,
        lo: f64,
        up: f64,
+       shape: &Shape,
+       dtype: DType,
    ) -> Result<Storage> {
        match self {
            Device::Cpu => {
@@ -136,12 +135,21 @@ impl Device {
        }
    }

-   pub(crate) fn rand_normal(
+   pub(crate) fn rand_uniform<T: crate::FloatDType>(
        &self,
+       lo: T,
+       up: T,
+       shape: &Shape,
+   ) -> Result<Storage> {
+       self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
+   }
+
+   pub(crate) fn rand_normal_f64(
+       &self,
        mean: f64,
        std: f64,
        shape: &Shape,
        dtype: DType,
    ) -> Result<Storage> {
        match self {
            Device::Cpu => {
@@ -155,6 +163,15 @@ impl Device {
        }
    }

+   pub(crate) fn rand_normal<T: crate::FloatDType>(
+       &self,
+       mean: T,
+       std: T,
+       shape: &Shape,
+   ) -> Result<Storage> {
+       self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
+   }
+
    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
        match self {
            Device::Cpu => {
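The refactor funnels every float dtype through a single f64-based implementation, with the dtype carried at the type level as T::DTYPE instead of as a runtime argument. A stripped-down illustration of that funnel, using Into<f64> as a stand-in for the trait's to_f64 method:

// Sketch: one f64 implementation shared by all float dtypes.
fn rand_uniform_f64(lo: f64, up: f64) -> f64 {
    // Stand-in for the real storage-producing implementation.
    (lo + up) / 2.0
}

fn rand_uniform<T: Into<f64> + Copy>(lo: T, up: T) -> f64 {
    rand_uniform_f64(lo.into(), up.into())
}

fn main() {
    // f32 and f64 callers share the same code path.
    assert_eq!(rand_uniform(0f32, 2f32), 1.0);
    assert_eq!(rand_uniform(0f64, 2f64), 1.0);
}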
@@ -119,3 +119,33 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
+
+pub trait IntDType: WithDType {
+    fn is_true(&self) -> bool;
+    fn as_usize(&self) -> usize;
+}
+
+impl IntDType for u32 {
+    fn is_true(&self) -> bool {
+        *self != 0
+    }
+    fn as_usize(&self) -> usize {
+        *self as usize
+    }
+}
+
+impl IntDType for u8 {
+    fn is_true(&self) -> bool {
+        *self != 0
+    }
+    fn as_usize(&self) -> usize {
+        *self as usize
+    }
+}
+
+pub trait FloatDType: WithDType {}
+
+impl FloatDType for f16 {}
+impl FloatDType for bf16 {}
+impl FloatDType for f32 {}
+impl FloatDType for f64 {}
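IntDType is what lets mask- and index-handling code accept either u8 or u32 tensors through one generic path. A self-contained sketch with local copies of the trait, showing a where-cond style selection that works for both mask dtypes:

// Sketch: generic selection over any integer mask dtype.
trait IntDType: Copy {
    fn is_true(&self) -> bool;
    fn as_usize(&self) -> usize;
}
impl IntDType for u32 {
    fn is_true(&self) -> bool { *self != 0 }
    fn as_usize(&self) -> usize { *self as usize }
}
impl IntDType for u8 {
    fn is_true(&self) -> bool { *self != 0 }
    fn as_usize(&self) -> usize { *self as usize }
}

fn select<I: IntDType, T: Copy>(mask: &[I], t: &[T], f: &[T]) -> Vec<T> {
    mask.iter()
        .enumerate()
        .map(|(i, m)| if m.is_true() { t[i] } else { f[i] })
        .collect()
}

fn main() {
    assert_eq!(select(&[1u8, 0u8], &[1, 2], &[9, 9]), vec![1, 9]);
    assert_eq!(select(&[0u32, 3u32], &[1, 2], &[9, 9]), vec![9, 2]);
}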
@@ -1,4 +1,5 @@
#![allow(dead_code)]
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
@@ -40,11 +41,11 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

-   fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
+   fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

-   fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
+   fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

@@ -52,16 +53,11 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

-   fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
+   fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

-   fn binary_impl<B: crate::op::BinaryOp>(
-       &self,
-       _: &Self,
-       _: &Layout,
-       _: &Layout,
-   ) -> Result<Self> {
+   fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

@@ -79,7 +75,34 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

-   fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+   fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+       Err(Error::NotCompiledWithCudaSupport)
+   }
+   fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

+   fn scatter_add(
+       &self,
+       _: &Layout,
+       _: &Self,
+       _: &Layout,
+       _: &Self,
+       _: &Layout,
+       _: usize,
+   ) -> Result<Self> {
+       Err(Error::NotCompiledWithCudaSupport)
+   }
+
+   fn index_add(
+       &self,
+       _: &Layout,
+       _: &Self,
+       _: &Layout,
+       _: &Self,
+       _: &Layout,
+       _: usize,
+   ) -> Result<Self> {
+       Err(Error::NotCompiledWithCudaSupport)
+   }
+
@@ -79,6 +79,19 @@ pub enum Error {
        nth_shape: Shape,
    },

+   #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
+   ShapeMismatchSplit {
+       shape: Shape,
+       dim: usize,
+       n_parts: usize,
+   },
+
+   #[error("{op} can only be performed on a single dimension")]
+   OnlySingleDimension { op: &'static str, dims: Vec<usize> },
+
+   #[error("empty tensor for {op}")]
+   EmptyTensor { op: &'static str },
+
    // === Device Errors ===
    #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DeviceMismatchBinaryOp {
@@ -106,11 +119,11 @@ pub enum Error {
        msg: &'static str,
    },

-   #[error("{op} invalid index {index} with vocab {vocab_size}")]
+   #[error("{op} invalid index {index} with dim size {size}")]
    InvalidIndex {
        op: &'static str,
        index: usize,
-       vocab_size: usize,
+       size: usize,
    },

    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
@@ -168,6 +181,7 @@ pub enum Error {
    #[error("unsupported safetensor dtype {0:?}")]
    UnsupportedSafeTensorDtype(safetensors::Dtype),

+   /// Arbitrary errors wrapping.
    #[error(transparent)]
    Wrapped(Box<dyn std::error::Error + Send + Sync>),

@@ -176,6 +190,10 @@ pub enum Error {
        inner: Box<Self>,
        backtrace: Box<std::backtrace::Backtrace>,
    },
+
+   /// User generated error message, typically created via `bail!`.
+   #[error("{0}")]
+   Msg(String),
}

pub type Result<T> = std::result::Result<T, Error>;
@@ -197,3 +215,24 @@ impl Error {
        }
    }
}
+
+#[macro_export]
+macro_rules! bail {
+    ($msg:literal $(,)?) => {
+        return Err($crate::Error::Msg(format!($msg).into()).bt())
+    };
+    ($err:expr $(,)?) => {
+        return Err($crate::Error::Msg(format!($err).into()).bt())
+    };
+    ($fmt:expr, $($arg:tt)*) => {
+        return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
+    };
+}
+
+pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
+    match (r1, r2) {
+        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
+        (Err(e), _) => Err(e),
+        (_, Err(e)) => Err(e),
+    }
+}
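The bail! macro gives early returns with a formatted Msg error, and zip combines two fallible results, surfacing the first failure. A self-contained sketch of both in use; the Error type and the simplified (literal-only) macro arms here are local stand-ins, not the crate's definitions:

// Sketch: bail!/zip usage with a minimal local Error type.
#[derive(Debug)]
struct Error(String);
type Result<T> = std::result::Result<T, Error>;

macro_rules! bail {
    ($msg:literal $(,)?) => {
        return Err(Error(format!($msg)))
    };
    ($fmt:literal, $($arg:tt)*) => {
        return Err(Error(format!($fmt, $($arg)*)))
    };
}

fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
    match (r1, r2) {
        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
        (Err(e), _) => Err(e),
        (_, Err(e)) => Err(e),
    }
}

fn checked_div(a: usize, b: usize) -> Result<usize> {
    if b == 0 {
        bail!("cannot divide {} by zero", a)
    }
    Ok(a / b)
}

fn main() {
    assert!(checked_div(4, 0).is_err());
    // zip pairs both successes, or propagates the first error.
    assert!(zip(checked_div(4, 2), checked_div(9, 3)).is_ok());
}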
516 candle-core/src/ggml.rs Normal file
@@ -0,0 +1,516 @@
//! Support for the GGML file format.

use crate::{DType, Device, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::f16;

// Default to QK_K 256 rather than 64.
pub const QK_K: usize = 256;
pub const K_SCALE_SIZE: usize = 12;

pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;

#[repr(C)]
struct BlockQ4_0 {
    d: f16,
    qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
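The size assertion checks out by hand: a 2-byte f16 scale plus QK4_0 / 2 = 16 bytes of packed 4-bit values gives 18 bytes per block of 32 weights, i.e. 4.5 bits per weight. The same arithmetic as a standalone compile-time check:

// Worked size check for BlockQ4_0: 2 (f16 d) + 32 / 2 (packed nibbles) = 18 bytes.
const QK4_0: usize = 32;
const BLOCK_BYTES: usize = 2 + QK4_0 / 2;
const _: () = assert!(BLOCK_BYTES == 18);

fn main() {
    // 18 bytes * 8 bits / 32 weights = 4.5 bits per weight.
    println!("{} bits/weight", BLOCK_BYTES as f64 * 8.0 / QK4_0 as f64);
}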
#[repr(C)]
struct BlockQ4_1 {
    d: f16,
    m: f16,
    qs: [u8; QK4_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);

#[repr(C)]
struct BlockQ5_0 {
    d: f16,
    qh: [u8; 4],
    qs: [u8; QK5_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);

#[repr(C)]
struct BlockQ5_1 {
    d: f16,
    m: f16,
    qh: [u8; 4],
    qs: [u8; QK5_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);

#[repr(C)]
struct BlockQ8_0 {
    d: f16,
    qs: [u8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);

#[repr(C)]
struct BlockQ8_1 {
    d: f16,
    s: f16,
    qs: [u8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);

#[repr(C)]
struct BlockQ2K {
    scales: [u8; QK_K / 16],
    qs: [u8; QK_K / 4],
    d: f16,
    dmin: f16,
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());

#[repr(C)]
struct BlockQ3K {
    hmask: [u8; QK_K / 8],
    qs: [u8; QK_K / 4],
    scales: [u8; 12],
    d: f16,
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());

// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
struct BlockQ4K {
    d: f16,
    dmin: f16,
    scales: [u8; K_SCALE_SIZE],
    qs: [u8; QK_K / 2],
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());

#[repr(C)]
struct BlockQ5K {
    d: f16,
    dmin: f16,
    scales: [u8; K_SCALE_SIZE],
    qh: [u8; QK_K / 8],
    qs: [u8; QK_K / 2],
}
const _: () =
    assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());

#[repr(C)]
struct BlockQ6K {
    ql: [u8; QK_K / 2],
    qh: [u8; QK_K / 4],
    scales: [i8; QK_K / 16],
    d: f16,
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
    }
    let mut ys_index = 0;
    for x in xs {
        let d = x.d.to_f32();
        let min = x.dmin.to_f32();
        let q = &x.qs;

        let mut is = 0;
        for n in (0..QK_K).step_by(128) {
            // Step by 32 over q.
            let q = &q[n / 4..];
            let mut shift = 0;
            for _j in 0..4 {
                let sc = x.scales[is];
                is += 1;
                let dl = d * (sc & 0xF) as f32;
                let ml = min * (sc >> 4) as f32;
                for q in &q[..16] {
                    let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
                    ys[ys_index] = y;
                    ys_index += 1;
                }

                let sc = x.scales[is];
                is += 1;
                let dl = d * (sc & 0xF) as f32;
                let ml = min * (sc >> 4) as f32;
                for q in &q[16..32] {
                    let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
                    ys[ys_index] = y;
                    ys_index += 1;
                }

                shift += 2;
            }
        }
    }
    Ok(())
}

fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j < 4 {
        let d = q[j] & 63;
        let m = q[j + 4] & 63;
        (d, m)
    } else {
        let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        (d, m)
    }
}
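get_scale_min_k4 unpacks eight 6-bit (scale, min) pairs from 12 bytes: 8 pairs * 2 values * 6 bits = 96 bits = 12 bytes exactly. The first four pairs live in the low 6 bits of bytes 0..8; pairs 4..8 stitch their values from 4 low bits of bytes 8..12 plus the 2 spare high bits of the earlier bytes. A quick standalone check that every unpacked value fits in 6 bits (the function body mirrors the one above):

// Sketch: the 6-bit scale/min unpacking, testable on its own.
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j < 4 {
        (q[j] & 63, q[j + 4] & 63)
    } else {
        let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        (d, m)
    }
}

fn main() {
    assert_eq!(8 * 2 * 6 / 8, 12); // 12 packed bytes hold all eight pairs.
    let q = [0x7Fu8; 12];
    for j in 0..8 {
        let (d, m) = get_scale_min_k4(j, &q);
        assert!(d < 64 && m < 64); // Every value stays within 6 bits.
    }
}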
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
    }
    let mut ys_index = 0;
    for x in xs.iter() {
        let d = x.d.to_f32();
        let min = x.dmin.to_f32();
        let q = &x.qs;
        let mut is = 0;
        for j in (0..QK_K).step_by(64) {
            let q = &q[j / 2..j / 2 + 32];
            let (sc, m) = get_scale_min_k4(is, &x.scales);
            let d1 = d * sc as f32;
            let m1 = min * m as f32;
            let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
            let d2 = d * sc as f32;
            let m2 = min * m as f32;
            for q in q {
                let y = d1 * (q & 0xF) as f32 - m1;
                ys[ys_index] = y;
                ys_index += 1;
            }
            for q in q {
                let y = d2 * (q >> 4) as f32 - m2;
                ys[ys_index] = y;
                ys_index += 1;
            }
            is += 2;
        }
    }
    Ok(())
}

// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
    let k = ys.len();
    if k % QK_K != 0 {
        crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
    }
    for x in xs.iter() {
        let d = x.d.to_f32();
        let ql = &x.ql;
        let qh = &x.qh;
        let sc = &x.scales;
        for n in (0..QK_K).step_by(128) {
            let idx = n / 128;
            let ys = &mut ys[n..];
            let sc = &sc[8 * idx..];
            let ql = &ql[64 * idx..];
            let qh = &qh[32 * idx..];
            for l in 0..32 {
                let is = l / 16;
                let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
                let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
                let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
                let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
                ys[l] = d * sc[is] as f32 * q1 as f32;
                ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
                ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
                ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
            }
        }
    }
    Ok(())
}

// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
    Ggjt,
    Ggla,
    Ggmf,
    Ggml,
    Ggsn,
}

impl TryFrom<u32> for Magic {
    type Error = crate::Error;
    fn try_from(value: u32) -> Result<Self> {
        let magic = match value {
            0x67676a74 => Self::Ggjt,
            0x67676c61 => Self::Ggla,
            0x67676d66 => Self::Ggmf,
            0x67676d6c => Self::Ggml,
            0x6767736e => Self::Ggsn,
            _ => crate::bail!("unknown magic {value:08x}"),
        };
        Ok(magic)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
    GgmlUnversioned,
    GgmfV1,
    GgjtV1,
    GgjtV2,
    GgjtV3,
}

impl VersionedMagic {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let magic = reader.read_u32::<LittleEndian>()?;
        let magic = Magic::try_from(magic)?;
        if magic == Magic::Ggml {
            return Ok(Self::GgmlUnversioned);
        }
        let version = reader.read_u32::<LittleEndian>()?;
        let versioned_magic = match (magic, version) {
            (Magic::Ggmf, 1) => Self::GgmfV1,
            (Magic::Ggjt, 1) => Self::GgjtV1,
            (Magic::Ggjt, 2) => Self::GgjtV2,
            (Magic::Ggjt, 3) => Self::GgjtV3,
            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
        };
        Ok(versioned_magic)
    }

    fn align32(&self) -> bool {
        match self {
            Self::GgmlUnversioned | Self::GgmfV1 => false,
            Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
    pub n_vocab: u32,
    pub n_embd: u32,
    pub n_mult: u32,
    pub n_head: u32,
    pub n_layer: u32,
    pub n_rot: u32,
    pub ftype: u32,
}

impl HParams {
    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
        let n_vocab = reader.read_u32::<LittleEndian>()?;
        let n_embd = reader.read_u32::<LittleEndian>()?;
        let n_mult = reader.read_u32::<LittleEndian>()?;
        let n_head = reader.read_u32::<LittleEndian>()?;
        let n_layer = reader.read_u32::<LittleEndian>()?;
        let n_rot = reader.read_u32::<LittleEndian>()?;
        let ftype = reader.read_u32::<LittleEndian>()?;
        Ok(Self {
            n_vocab,
            n_embd,
            n_mult,
            n_head,
            n_layer,
            n_rot,
            ftype,
        })
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
    pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}

impl Vocab {
    fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
        let mut token_score_pairs = Vec::with_capacity(n_vocab);
        for _index in 0..n_vocab {
            let len = reader.read_u32::<LittleEndian>()? as usize;
            let mut word = vec![0u8; len];
            reader.read_exact(&mut word)?;
            let score = reader.read_f32::<LittleEndian>()?;
            token_score_pairs.push((word, score))
        }
        Ok(Self { token_score_pairs })
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
}

impl GgmlDType {
    fn from_u32(u: u32) -> Result<Self> {
        let dtype = match u {
            0 => Self::F32,
            1 => Self::F16,
            2 => Self::Q4_0,
            3 => Self::Q4_1,
            6 => Self::Q5_0,
            7 => Self::Q5_1,
            8 => Self::Q8_0,
            9 => Self::Q8_1,
            10 => Self::Q2K,
            11 => Self::Q3K,
            12 => Self::Q4K,
            13 => Self::Q5K,
            14 => Self::Q6K,
            _ => crate::bail!("unknown dtype for tensor {u}"),
        };
        Ok(dtype)
    }

    fn type_size(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 => 2,
            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
        }
    }

    fn blck_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 => 1,
            Self::Q4_0 => QK4_0,
            Self::Q4_1 => QK4_1,
            Self::Q5_0 => QK5_0,
            Self::Q5_1 => QK5_1,
            Self::Q8_0 => QK8_0,
            Self::Q8_1 => QK8_1,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
        }
    }
}
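type_size and blck_size together determine how many bytes a tensor occupies on disk: bytes = elem_count * type_size / blck_size. A worked example with the Q4K numbers implied by the struct sizes above (144-byte blocks covering 256 weights):

// Worked example of the on-disk size computation used when reading tensors.
fn tensor_bytes(elems: usize, type_size: usize, blck_size: usize) -> usize {
    elems * type_size / blck_size
}

fn main() {
    let (qk_k, q4k_block) = (256, 144); // BlockQ4K: 2 + 2 + 12 + 128 = 144 bytes.
    // A 4096 x 4096 Q4K tensor takes 9 MiB...
    assert_eq!(tensor_bytes(4096 * 4096, q4k_block, qk_k), 9_437_184);
    // ...versus 64 MiB in f32, roughly a 7.1x reduction.
    assert_eq!(tensor_bytes(4096 * 4096, 4, 1), 67_108_864);
}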
#[derive(Debug)]
pub struct Content {
    pub magic: VersionedMagic,
    pub hparams: HParams,
    pub vocab: Vocab,
    pub tensors: Vec<(String, Tensor)>,
}

fn read_one_tensor<R: std::io::Seek + std::io::Read>(
    reader: &mut R,
    magic: VersionedMagic,
    device: &Device,
) -> Result<(String, Tensor)> {
    let n_dims = reader.read_u32::<LittleEndian>()?;
    let name_len = reader.read_u32::<LittleEndian>()?;
    let dtype = reader.read_u32::<LittleEndian>()?;
    let dtype = GgmlDType::from_u32(dtype)?;
    let mut dims = vec![0u32; n_dims as usize];
    reader.read_u32_into::<LittleEndian>(&mut dims)?;
    let mut name = vec![0u8; name_len as usize];
    reader.read_exact(&mut name)?;
    let name = String::from_utf8_lossy(&name).into_owned();

    if magic.align32() {
        let pos = reader.stream_position()?;
        reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
    }
    let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
    let tensor_elems = dims.iter().product::<usize>();
    let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size();
    println!("{name} {dtype:?} {dims:?}");
    // TODO: Mmap version to avoid copying the data around?
    let mut raw_data = vec![0u8; size_in_bytes];
    reader.read_exact(&mut raw_data)?;
    let tensor = match dtype {
        GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?,
        GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?,
        GgmlDType::Q2K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) };
            dequantize_row_q2k(raw_data, &mut f32_data)?;
            // Maybe we should use bf16 instead?
            Tensor::from_vec(f32_data, dims, device)?
        }
        GgmlDType::Q4K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) };
            dequantize_row_q4k(raw_data, &mut f32_data)?;
            Tensor::from_vec(f32_data, dims, device)?
        }
        GgmlDType::Q6K => {
            let mut f32_data = vec![0f32; tensor_elems];
            let raw_data_ptr = raw_data.as_ptr();
            let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
            let raw_data =
                unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
            dequantize_row_q6k(raw_data, &mut f32_data)?;
            Tensor::from_vec(f32_data, dims, device)?
        }
        _ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
    };
    Ok((name, tensor))
}

impl Content {
    pub fn read<R: std::io::Seek + std::io::Read>(
        reader: &mut R,
        device: &Device,
    ) -> Result<Content> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
        reader.seek(std::io::SeekFrom::Start(0))?;
        let magic = VersionedMagic::read(reader)?;
        let hparams = HParams::read(reader)?;
        let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
        let mut tensors = vec![];

        while reader.stream_position()? != last_position {
            let (name, tensor) = read_one_tensor(reader, magic, device)?;
            tensors.push((name, tensor))
        }
        Ok(Self {
            magic,
            hparams,
            vocab,
            tensors,
        })
    }
}
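A usage sketch for the reader above; the file path is hypothetical and errors are boxed to keep the snippet short:

// Sketch: loading a ggml checkpoint with Content::read.
fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
    let mut file = std::fs::File::open("llama-model.bin")?;
    let content = candle_core::ggml::Content::read(&mut file, &candle_core::Device::Cpu)?;
    println!("{:?}, n_layer = {}", content.magic, content.hparams.n_layer);
    for (name, tensor) in content.tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}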
@@ -7,7 +7,7 @@ impl Tensor {
    /// Intended to be used by the trait `.i()`
    ///
    /// ```
-   /// # use candle::{Tensor, DType, Device, IndexOp};
+   /// # use candle_core::{Tensor, DType, Device, IndexOp};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    ///
    /// let c = a.i(0..1)?;
@@ -22,7 +22,7 @@ impl Tensor {
    /// let c = a.i((.., ..=2))?;
    /// assert_eq!(c.shape().dims(), &[2, 3]);
    ///
-   /// # Ok::<(), candle::Error>(())
+   /// # Ok::<(), candle_core::Error>(())
    /// ```
    fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
        let mut x = self.clone();
@@ -42,7 +42,7 @@ impl Tensor {
            Bound::Excluded(n) => *n,
            Bound::Unbounded => dims[i],
        };
-       let out = x.narrow(current_dim, start, stop - start)?;
+       let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
        current_dim += 1;
        out
    }
@@ -1,8 +1,8 @@
//! ML framework for Rust
//!
//! ```rust
-//! use candle::{Tensor, DType, Device};
-//! # use candle::Error;
+//! use candle_core::{Tensor, DType, Device};
+//! # use candle_core::Error;
//! # fn main() -> Result<(), Error> {
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
@@ -33,18 +33,19 @@
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates: [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).

-mod backend;
-mod backprop;
+pub mod backend;
+pub mod backprop;
mod conv;
mod convert;
-mod cpu_backend;
+pub mod cpu_backend;
#[cfg(feature = "cuda")]
-mod cuda_backend;
+pub mod cuda_backend;
mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
-mod error;
+pub mod error;
+pub mod ggml;
mod indexer;
pub mod layout;
#[cfg(feature = "mkl")]
@@ -52,7 +53,7 @@ mod mkl;
pub mod npy;
mod op;
pub mod safetensors;
-mod shape;
+pub mod shape;
mod storage;
mod strided_index;
mod tensor;
@@ -61,10 +62,11 @@ mod variable;

pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
-pub use dtype::{DType, WithDType};
+pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
+pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};
@@ -1,15 +1,74 @@
-use crate::Tensor;
+use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;

+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum CmpOp {
+    Eq,
+    Ne,
+    Le,
+    Ge,
+    Lt,
+    Gt,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceOp {
+    Sum,
+    Min,
+    Max,
+    ArgMin,
+    ArgMax,
+}
+
+impl ReduceOp {
+    pub(crate) fn name(&self) -> &'static str {
+        match self {
+            Self::ArgMax => "argmax",
+            Self::ArgMin => "argmin",
+            Self::Min => "min",
+            Self::Max => "max",
+            Self::Sum => "sum",
+        }
+    }
+}
+
+// These ops return the same type as their input type.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BinaryOp {
+    Add,
+    Mul,
+    Sub,
+    Div,
+}
+
+// Unary ops with no argument
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum UnaryOp {
+    Exp,
+    Log,
+    Sin,
+    Cos,
+    Abs,
+    Neg,
+    Sqr,
+    Sqrt,
+    Gelu,
+    Relu,
+}
+
#[derive(Clone)]
-pub(crate) enum Op {
-   Add(Tensor, Tensor),
-   Mul(Tensor, Tensor),
-   Sub(Tensor, Tensor),
-   Div(Tensor, Tensor),
+pub enum Op {
+   Binary(Tensor, Tensor, BinaryOp),
+   Unary(Tensor, UnaryOp),
+   Cmp(Tensor, CmpOp),
+   // The third argument is the reduced shape with `keepdim=true`.
+   Reduce(Tensor, ReduceOp, Vec<usize>),
    Matmul(Tensor, Tensor),
-   Embedding(Tensor, Tensor),
+   Gather(Tensor, Tensor, usize),
+   ScatterAdd(Tensor, Tensor, Tensor, usize),
+   IndexSelect(Tensor, Tensor, usize),
+   IndexAdd(Tensor, Tensor, Tensor, usize),
    WhereCond(Tensor, Tensor, Tensor),

    #[allow(dead_code)]
@@ -28,29 +87,126 @@ pub(crate) enum Op {
        mul: f64,
        add: f64,
    },
-   Sum(Tensor, Vec<usize>),
    ToDType(Tensor),
    Copy(Tensor),
    Broadcast(Tensor),
-   Exp(Tensor),
-   Log(Tensor),
-   Sin(Tensor),
-   Cos(Tensor),
-   Abs(Tensor),
    Narrow(Tensor, usize, usize, usize),
-   Neg(Tensor),
    Reshape(Tensor),
    Softmax(Tensor, usize),
-   Sqr(Tensor),
-   Sqrt(Tensor),
    ToDevice(Tensor),
    Transpose(Tensor, usize, usize),
-   Gelu(Tensor),
-   Relu(Tensor),
    Elu(Tensor, f64),
-   // TODO: Support for custom ops.
+   CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
+   CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
+   CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}

-pub(crate) trait UnaryOp {
+/// Unary ops that can be defined in user-land.
+pub trait CustomOp1: Send + Sync {
+    // Box<dyn> does not support const yet, so use a function to get the name.
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    /// This function takes as argument the argument `arg` used in the forward pass, the result
+    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
+    /// The function should return the gradient of the argument.
+    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
+pub trait CustomOp2: Send + Sync {
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(
+        &self,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+    ) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    fn bwd(
+        &self,
+        _arg1: &Tensor,
+        _arg2: &Tensor,
+        _res: &Tensor,
+        _grad_res: &Tensor,
+    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
+pub trait CustomOp3: Send + Sync {
+    fn name(&self) -> &'static str;
+
+    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+        s3: &CpuStorage,
+        l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)>;
+
+    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
+    /// offsets etc so the associated layout should be used to access it.
+    fn cuda_fwd(
+        &self,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+        _: &CudaStorage,
+        _: &Layout,
+    ) -> Result<(CudaStorage, Shape)> {
+        Err(crate::Error::Cuda(
+            format!("no cuda implementation for {}", self.name()).into(),
+        ))
+    }
+
+    fn bwd(
+        &self,
+        _arg1: &Tensor,
+        _arg2: &Tensor,
+        _arg3: &Tensor,
+        _res: &Tensor,
+        _grad_res: &Tensor,
+    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
+        Err(crate::Error::BackwardNotSupported { op: self.name() })
+    }
+}
+
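A sketch of a user-land op against the CustomOp1 trait as declared above: only the required cpu_fwd is implemented, it assumes contiguous storage (ignoring the layout's strides and offsets, which real code should honor), and the Negate name is purely illustrative. The Tensor-side method that applies such an op is added alongside this trait and is not shown in this hunk.

// Sketch: a minimal CPU-only CustomOp1 implementation.
use candle_core::{CpuStorage, CustomOp1, Layout, Result, Shape};

struct Negate;

impl CustomOp1 for Negate {
    fn name(&self) -> &'static str {
        "negate"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only contiguous f32 storage is handled in this sketch.
        let data = match storage {
            CpuStorage::F32(vs) => vs,
            _ => candle_core::bail!("negate: only f32 is supported in this sketch"),
        };
        let out: Vec<f32> = data.iter().map(|v| -v).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

fn main() {
    let _op = Negate;
    // Applying the op goes through a Tensor method introduced with this trait,
    // so this sketch stops at constructing it.
}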
+pub trait UnaryOpT {
    const NAME: &'static str;
    const KERNEL: &'static str;
    const V: Self;
@@ -73,7 +229,7 @@ pub(crate) trait UnaryOp {
    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}

-pub(crate) trait BinaryOp {
+pub trait BinaryOpT {
    const NAME: &'static str;
    const KERNEL: &'static str;
    const V: Self;
@@ -115,7 +271,7 @@ pub(crate) struct Relu;

macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
-       impl BinaryOp for $op {
+       impl BinaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("b", $name);
            const V: Self = $op;
@@ -169,7 +325,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);

macro_rules! unary_op {
    ($op: ident, $name: literal, $a: ident, $e: expr) => {
-       impl UnaryOp for $op {
+       impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
@@ -201,7 +357,7 @@ macro_rules! unary_op {
    };

    ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
-       impl UnaryOp for $op {
+       impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            const KERNEL: &'static str = concat!("u", $name);
            const V: Self = $op;
@@ -259,7 +415,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
-impl UnaryOp for Gelu {
+impl UnaryOpT for Gelu {
    const NAME: &'static str = "gelu";
    const V: Self = Gelu;
    #[inline(always)]
@@ -325,7 +481,7 @@ impl UnaryOp for Gelu {
    }
}

-impl UnaryOp for Relu {
+impl UnaryOpT for Relu {
    const NAME: &'static str = "relu";
    const KERNEL: &'static str = "urelu";
    const V: Self = Relu;
@@ -354,3 +510,63 @@ impl UnaryOp for Relu {
        v
    }
}
+
+/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies
+/// are properly checked when creating a new value.
+#[derive(Clone)]
+pub struct BackpropOp(Option<Op>);
+
+impl BackpropOp {
+    pub(crate) fn none() -> Self {
+        BackpropOp(None)
+    }
+
+    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
+        let op = if arg.track_op() {
+            Some(f(arg.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() {
+            Some(f(arg1.clone(), arg2.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new3(
+        arg1: &Tensor,
+        arg2: &Tensor,
+        arg3: &Tensor,
+        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
+    ) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
+            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
+            Some(f(args))
+        } else {
+            None
+        };
+        Self(op)
+    }
+}
+
+impl std::ops::Deref for BackpropOp {
+    type Target = Option<Op>;
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
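The point of BackpropOp's constructors is the gate on track_op(): a new tensor only records its producing Op when at least one input participates in gradient tracking, so inference graphs stay free of backprop bookkeeping. The same gating, reduced to a standalone toy (a usize stands in for a tensor and the names are illustrative):

// Sketch: the conditional op-recording that BackpropOp::new1 implements.
#[derive(Clone, Debug)]
enum Op {
    Neg(usize),
}

struct BackpropOp(Option<Op>);

impl BackpropOp {
    fn new1(arg_id: usize, tracks: bool, f: impl Fn(usize) -> Op) -> Self {
        // Mirrors the real new1: skip recording when no input needs gradients.
        if tracks {
            BackpropOp(Some(f(arg_id)))
        } else {
            BackpropOp(None)
        }
    }
}

fn main() {
    let recorded = BackpropOp::new1(0, true, Op::Neg);
    let skipped = BackpropOp::new1(1, false, Op::Neg);
    assert!(recorded.0.is_some());
    assert!(skipped.0.is_none());
}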
@@ -1,7 +1,9 @@
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
-pub use safetensors::tensor::SafeTensors;
+use safetensors::tensor::SafeTensors;
use std::borrow::Cow;
+use std::collections::HashMap;
+use std::path::Path;

impl From<DType> for st::Dtype {
    fn from(value: DType) -> Self {
@@ -52,6 +54,27 @@ impl st::View for Tensor {
    }
}

+impl st::View for &Tensor {
+    fn dtype(&self) -> st::Dtype {
+        (*self).dtype().into()
+    }
+    fn shape(&self) -> &[usize] {
+        self.dims()
+    }
+
+    fn data(&self) -> Cow<[u8]> {
+        // This copies data from GPU to CPU.
+        // TODO: Avoid the unwrap here.
+        Cow::Owned(convert_back(self).unwrap())
+    }
+
+    fn data_len(&self) -> usize {
+        let n: usize = self.dims().iter().product();
+        let bytes_per_element = (*self).dtype().size_in_bytes();
+        n * bytes_per_element
+    }
+}
+
impl Tensor {
    pub fn save_safetensors<P: AsRef<std::path::Path>>(
        &self,
@ -63,15 +86,15 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
let v = view.data();
|
||||
fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
|
||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||
let elem_count = v.len() / size_in_bytes;
|
||||
if (v.as_ptr() as usize) % size_in_bytes == 0 {
|
||||
let elem_count = data.len() / size_in_bytes;
|
||||
if (data.as_ptr() as usize) % size_in_bytes == 0 {
|
||||
// SAFETY This is safe because we just checked that this
|
||||
// was correctly aligned.
|
||||
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
|
||||
Tensor::from_slice(data, view.shape(), device)
|
||||
let data: &[T] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
|
||||
Tensor::from_slice(data, shape, device)
|
||||
} else {
|
||||
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
|
||||
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
|
||||
@ -81,13 +104,57 @@ fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<
|
||||
// We're downgrading the `c` pointer from T to u8, which removes alignment
|
||||
// constraints.
|
||||
unsafe {
|
||||
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
|
||||
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
|
||||
c.set_len(elem_count)
|
||||
}
|
||||
Tensor::from_slice(&c, view.shape(), device)
|
||||
Tensor::from_slice(&c, shape, device)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    data: &[u8],
    shape: &[usize],
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    let size_in_bytes = std::mem::size_of::<T>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY This is safe because we just checked that this
        // was correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(c, shape, device)
    }
}

fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}

fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}

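As an aside, the aligned/unaligned branch above is easy to exercise on its own; the following is a minimal standalone sketch using only std, with f32 standing in for the generic T (names are illustrative, not part of the diff; it assumes the byte length is an exact multiple of the element size and that the bytes are native-endian):

// Sketch of the alignment-checked reinterpretation trick from convert_slice.
fn bytes_to_f32_vec(data: &[u8]) -> Vec<f32> {
    let size_in_bytes = std::mem::size_of::<f32>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // Aligned: reinterpret the bytes in place, then copy them out.
        let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, elem_count) };
        s.to_vec()
    } else {
        // Unaligned: byte-copy into a freshly allocated (hence aligned) Vec.
        let mut c: Vec<f32> = Vec::with_capacity(elem_count);
        unsafe {
            std::ptr::copy_nonoverlapping(
                data.as_ptr(),
                c.as_mut_ptr() as *mut u8,
                elem_count * size_in_bytes,
            );
            c.set_len(elem_count)
        }
        c
    }
}
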
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let length = vs.len() * size_in_bytes;
@@ -112,19 +179,55 @@ impl<'a> Load for st::TensorView<'a> {
    }
}

-pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
impl Tensor {
    pub fn from_raw_buffer(
        data: &[u8],
        dtype: DType,
        shape: &[usize],
        device: &Device,
    ) -> Result<Self> {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
            DType::F64 => convert_slice::<f64>(data, shape, device),
        }
    }
}
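
A hypothetical call site for the from_raw_buffer entry point added above, assuming the candle_core crate name used elsewhere in this compare and a byte buffer of little-endian f32 values:

use candle_core::{DType, Device, Tensor};

fn tensor_from_bytes() -> candle_core::Result<Tensor> {
    // 4 f32 values -> 16 bytes -> a (2, 2) f32 tensor on the CPU.
    let bytes: Vec<u8> = [1f32, 2., 3., 4.].iter().flat_map(|v| v.to_le_bytes()).collect();
    Tensor::from_raw_buffer(&bytes, DType::F32, &[2, 2], &Device::Cpu)
}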

fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    match view.dtype() {
        st::Dtype::U8 => convert_::<u8>(view, device),
-        st::Dtype::U32 => convert_::<u8>(view, device),
        st::Dtype::U16 => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        st::Dtype::U32 => convert_::<u32>(view, device),
        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
        st::Dtype::F16 => convert_::<half::f16>(view, device),
        st::Dtype::F32 => convert_::<f32>(view, device),
        st::Dtype::F64 => convert_::<f64>(view, device),
        st::Dtype::I32 => {
            let conv = |x| {
                u32::try_from(x)
                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
            };
            convert_with_cast_::<i32, u32, _>(view, device, conv)
        }
        st::Dtype::I64 => {
            let conv = |x| {
                u32::try_from(x)
                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
            };
            convert_with_cast_::<i64, u32, _>(view, device, conv)
        }
        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}

-pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
    // TODO: This makes an unnecessary copy when the tensor is on the cpu.
    let tensor = tensor.flatten_all()?;
    match tensor.dtype() {
@@ -137,6 +240,19 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
    }
}

pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
    let data = std::fs::read(filename.as_ref())?;
    let st = safetensors::SafeTensors::deserialize(&data)?;
    st.tensors()
        .into_iter()
        .map(|(name, view)| Ok((name, view.load(device)?)))
        .collect()
}

pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> {
    Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}

pub struct MmapedFile(memmap2::Mmap);

impl MmapedFile {
@@ -173,11 +289,15 @@ mod tests {
    }

    #[test]
-    fn save_multiple_tensors() {
    fn save_load_multiple_tensors() {
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
-        st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
        save(&map, "multi.safetensors").unwrap();

        let weights = load("multi.safetensors", &Device::Cpu).unwrap();
        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
        let bytes = std::fs::read("multi.safetensors").unwrap();
        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
        std::fs::remove_file("multi.safetensors").unwrap();
@@ -41,6 +41,12 @@ impl From<usize> for Shape {
    }
}

impl From<(usize,)> for Shape {
    fn from(d1: (usize,)) -> Self {
        Self(vec![d1.0])
    }
}

impl From<(usize, usize)> for Shape {
    fn from(d12: (usize, usize)) -> Self {
        Self(vec![d12.0, d12.1])
@@ -87,6 +93,12 @@ macro_rules! extract_dims {
            }
        }
    }
    impl crate::Tensor {
        pub fn $fn_name(&self) -> Result<$out_type> {
            self.shape().$fn_name()
        }
    }

    impl std::convert::TryInto<$out_type> for Shape {
        type Error = crate::Error;
        fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
@@ -328,23 +340,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
    }
}

-extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
-extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
-extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
-    r3,
    dims3,
    3,
    |d: &[usize]| (d[0], d[1], d[2]),
    (usize, usize, usize)
);
extract_dims!(
-    r4,
    dims4,
    4,
    |d: &[usize]| (d[0], d[1], d[2], d[3]),
    (usize, usize, usize, usize)
);
extract_dims!(
-    r5,
    dims5,
    5,
    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
    (usize, usize, usize, usize, usize)
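
The rN accessors become dimsN and, via the extended macro, are now callable directly on a tensor; a small sketch of the renamed call sites (mirroring the test updates further down):

use candle_core::{DType, Device, Tensor};

fn dims_demo() -> candle_core::Result<()> {
    let t = Tensor::zeros((5, 2), DType::F32, &Device::Cpu)?;
    let (d1, d2) = t.dims2()?; // previously: t.shape().r2()?
    assert_eq!((d1, d2), (5, 2));
    Ok(())
}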
@@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
-use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@@ -80,26 +81,48 @@ impl Storage {
        }
    }

-    pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
-        match self {
-            Storage::Cpu(storage) => {
-                let storage = storage.sum(layout, s)?;
    pub(crate) fn cmp(
        &self,
        op: CmpOp,
        rhs: &Self,
        lhs_layout: &Layout,
        rhs_layout: &Layout,
    ) -> Result<Self> {
        self.same_device(rhs, "cmp")?;
        self.same_dtype(rhs, "cmp")?;
        match (self, rhs) {
            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Cpu(storage))
            }
-            Self::Cuda(storage) => {
-                let storage = storage.sum(layout, s)?;
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: lhs.device().location(),
                    rhs: rhs.device().location(),
                    op: "cmp",
                }
                .bt())
            }
        }
    }

-    // This assumes a contiguous layout and no offset.
-    pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
        match self {
-            Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
-            Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
            Storage::Cpu(storage) => {
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Cuda(storage))
            }
        }
-        Ok(())
    }

    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
@@ -115,8 +138,65 @@ impl Storage {
        }
    }

-    pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
-        // TODO: Different code path for the contiguous case?
    pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
        match self {
            Self::Cpu(storage) => {
                let (storage, shape) = c.cpu_fwd(storage, l)?;
                Ok((Self::Cpu(storage), shape))
            }
            Self::Cuda(storage) => {
                let (storage, shape) = c.cuda_fwd(storage, l)?;
                Ok((Self::Cuda(storage), shape))
            }
        }
    }

    pub(crate) fn custom_op2(
        &self,
        l1: &Layout,
        t2: &Self,
        l2: &Layout,
        c: &dyn CustomOp2,
    ) -> Result<(Self, Shape)> {
        self.same_device(t2, c.name())?;
        match (self, t2) {
            (Self::Cpu(s1), Self::Cpu(s2)) => {
                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
                Ok((Self::Cpu(s), shape))
            }
            (Self::Cuda(s1), Self::Cuda(s2)) => {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                Ok((Self::Cuda(s), shape))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn custom_op3(
        &self,
        l1: &Layout,
        t2: &Self,
        l2: &Layout,
        t3: &Self,
        l3: &Layout,
        c: &dyn CustomOp3,
    ) -> Result<(Self, Shape)> {
        self.same_device(t2, c.name())?;
        self.same_device(t3, c.name())?;
        match (self, t2, t3) {
            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Cpu(s), shape))
            }
            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Cuda(s), shape))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.unary_impl::<B>(layout)?;
@@ -129,7 +209,7 @@ impl Storage {
        }
    }

-    pub(crate) fn binary_impl<B: op::BinaryOp>(
    pub(crate) fn binary_impl<B: op::BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_layout: &Layout,
@@ -215,21 +295,96 @@ impl Storage {
        }
    }

-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        self.same_device(rhs, "embedding")?;
    pub(crate) fn gather(
        &self,
        l: &Layout,
        indexes: &Self,
        indexes_l: &Layout,
        d: usize,
    ) -> Result<Self> {
        self.same_device(indexes, "index-add")?;
        match (self, indexes) {
            (Self::Cpu(s), Self::Cpu(indexes)) => {
                let storage = s.gather(l, indexes, indexes_l, d)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(s), Self::Cuda(indexes)) => {
                let storage = s.gather(l, indexes, indexes_l, d)?;
                Ok(Self::Cuda(storage))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn scatter_add(
        &self,
        l: &Layout,
        indexes: &Self,
        indexes_l: &Layout,
        source: &Self,
        source_l: &Layout,
        d: usize,
    ) -> Result<Self> {
        self.same_device(indexes, "scatter-add")?;
        self.same_device(source, "scatter-add")?;
        match (self, indexes, source) {
            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cuda(storage))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn index_add(
        &self,
        l: &Layout,
        indexes: &Self,
        indexes_l: &Layout,
        source: &Self,
        source_l: &Layout,
        d: usize,
    ) -> Result<Self> {
        self.same_device(indexes, "index-add")?;
        self.same_device(source, "index-add")?;
        match (self, indexes, source) {
            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cuda(storage))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn index_select(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
        d: usize,
    ) -> Result<Self> {
        self.same_device(rhs, "index-select")?;
        match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
-                op: "embedding",
                op: "index-select",
            }
            .bt()),
        }
[File diff suppressed because it is too large]
@@ -34,25 +34,50 @@ impl Var {
        Ok(Self(inner))
    }

-    pub fn rand<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        lo: f64,
-        up: f64,
-    ) -> Result<Self> {
-        let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
    pub fn from_tensor(t: &Tensor) -> Result<Self> {
        let inner = t.make_var()?;
        Ok(Self(inner))
    }

-    pub fn randn<S: Into<Shape>>(
    pub fn rand_f64<S: Into<Shape>>(
        lo: f64,
        up: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn randn_f64<S: Into<Shape>>(
        mean: f64,
        std: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
-        let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
        lo: T,
        up: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
        Ok(Self(inner))
    }

    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
        mean: T,
        std: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
        Ok(Self(inner))
    }

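A usage sketch for the reworked constructors (not from the diff itself): bounds, mean and std now come first, the dtype of rand/randn is inferred from the float parameter type, and the *_f64 variants keep an explicit dtype argument:

use candle_core::{Device, Var};

fn var_demo() -> candle_core::Result<()> {
    let w = Var::rand(-1f32, 1f32, (2, 3), &Device::Cpu)?; // uniform between lo and up
    let b = Var::randn(0f32, 0.02f32, 2, &Device::Cpu)?; // gaussian with the given mean/std
    let _ = (w, b);
    Ok(())
}
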
candle-core/tests/custom_op_tests.rs (new file, 116 lines)
@@ -0,0 +1,116 @@
use candle_core::backend::BackendStorage;
use candle_core::cpu_backend;
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};

mod test_utils;
use test_utils::to_vec1_round;

fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
    if v.is_sign_positive() {
        v
    } else {
        let alpha = T::from(alpha).unwrap_or(T::nan());
        (v.exp() - T::one()) * alpha
    }
}

struct Elu {
    alpha: f64,
}

impl CustomOp1 for Elu {
    fn name(&self) -> &'static str {
        "elu"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        let storage = candle_core::map_dtype!(
            "elu",
            s,
            |s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)),
            (BF16, F16, F32, F64)
        );
        Ok((storage, l.shape().clone()))
    }
}

#[test]
fn custom_op1_no_backward() -> Result<()> {
    let cpu = &Device::Cpu;
    let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
    let t = (t - 5.)?;
    let elu_t = t.custom_op1(Elu { alpha: 1. })?;
    assert_eq!(
        to_vec1_round(&elu_t, 4)?,
        &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    );
    Ok(())
}

// Define a similar struct as Elu but with backward support.
fn bwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
    if v.is_sign_positive() {
        T::one()
    } else {
        let alpha = T::from(alpha).unwrap_or(T::nan());
        v.exp() * alpha
    }
}

struct EluBackward {
    alpha: f64,
}

impl CustomOp1 for EluBackward {
    fn name(&self) -> &'static str {
        "elu-bwd"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        let storage = candle_core::map_dtype!(
            "elu-bwd",
            s,
            |s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)),
            (BF16, F16, F32, F64)
        );
        Ok((storage, l.shape().clone()))
    }
}

struct EluWithBackward(Elu);

impl EluWithBackward {
    fn new(alpha: f64) -> Self {
        Self(Elu { alpha })
    }
}

impl CustomOp1 for EluWithBackward {
    fn name(&self) -> &'static str {
        "elu"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        self.0.cpu_fwd(s, l)
    }

    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        let alpha = self.0.alpha;
        let bwd = arg.custom_op1(EluBackward { alpha })?;
        Ok(Some(grad_res.mul(&bwd)?))
    }
}

#[test]
fn custom_op1_with_backward() -> Result<()> {
    let cpu = &Device::Cpu;
    let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
    let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
    assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);

    let grads = elu_t.backward()?;
    let grad_x = grads.get(&t).unwrap();
    assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]);

    Ok(())
}
@@ -1,5 +1,5 @@
use anyhow::Result;
-use candle::{DType, Device::Cpu, Tensor};
use candle_core::{DType, Device::Cpu, Tensor};

#[test]
fn display_scalar() -> Result<()> {
@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
-use candle::{Device, Shape, Var};
use candle_core::{Device, Shape, Tensor, Var};
mod test_utils;

fn simple_grad(device: &Device) -> Result<()> {
@@ -79,7 +97,97 @@ fn grad_descent(device: &Device) -> Result<()> {
    Ok(())
}

fn unary_grad(device: &Device) -> Result<()> {
    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
    let x = x.as_tensor();
    let y = (x.log()? + 1.)?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
    let y = x.exp()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        y.to_vec1::<f32>()?,
        [20.085537, 2.7182817, 54.59815, 1.1618342]
    );
    assert_eq!(
        grad_x.to_vec1::<f32>()?,
        [20.085537, 2.7182817, 54.59815, 1.1618342]
    );
    let y = x.exp()?.sqr()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        y.to_vec1::<f32>()?,
        [403.4288, 7.3890557, 2980.9578, 1.3498588]
    );
    // exp(x)^2 = exp(2*x)
    assert_eq!(
        grad_x.to_vec1::<f32>()?,
        [806.8576, 14.778111, 5961.9155, 2.6997175]
    );
    let y = x.sin()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [0.1411, 0.8415, -0.7568, 0.1494],
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [-0.99, 0.5403, -0.6536, 0.9888],
    );
    let y = x.cos()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [-0.99, 0.5403, -0.6536, 0.9888],
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [-0.1411, -0.8415, 0.7568, -0.1494],
    );
    let y = x.sqr()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [9.0, 1.0, 16.0, 0.0225]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, 8.0, 0.3]);
    let y = x.sqr()?.sqrt()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
    let y = x.neg()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [-3.0, -1.0, -4.0, -0.15]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [-1.0, -1.0, -1.0, -1.0]);
    let y = x.affine(0.2, 1.)?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [1.6, 1.2, 1.8, 1.03]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [0.2, 0.2, 0.2, 0.2]);
    let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
    assert_eq!(
        grad_x.to_vec1::<f32>()?,
        [-0.11111111, -1.0, -0.0625, -44.444443],
    );
    let y = x.broadcast_div(&Tensor::new(0.5f32, device)?)?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
    Ok(())
}
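
For reference, the expected values hard-coded in unary_grad follow the usual derivative identities (stated here for convenience, not part of the diff):

\frac{d}{dx}(\log x + 1) = \frac{1}{x}, \qquad
\frac{d}{dx} e^x = e^x, \qquad
\frac{d}{dx} (e^x)^2 = 2e^{2x}, \qquad
\frac{d}{dx} \sin x = \cos x, \qquad
\frac{d}{dx} \sqrt{x^2} = \operatorname{sign}(x), \qquad
\frac{d}{dx} \frac{1}{x} = -\frac{1}{x^2}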

test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);

@@ -1,5 +1,5 @@
use anyhow::Result;
-use candle::{Device, IndexOp, Tensor};
use candle_core::{Device, IndexOp, Tensor};

mod test_utils;

@@ -58,6 +58,19 @@ fn range_index() -> Result<()> {
    let result = tensor.i(..=1)?;
    assert_eq!(result.dims(), &[2, 3]);
    assert_eq!(result.to_vec2::<u32>()?, &[[0, 1, 2], [3, 4, 5]]);

    // Empty range
    let result = tensor.i(1..1)?;
    assert_eq!(result.dims(), &[0, 3]);
    let empty: [[u32; 3]; 0] = [];
    assert_eq!(result.to_vec2::<u32>()?, &empty);

    // Similar to PyTorch, allow empty ranges when the computed length is negative.
    #[allow(clippy::reversed_empty_ranges)]
    let result = tensor.i(1..0)?;
    assert_eq!(result.dims(), &[0, 3]);
    let empty: [[u32; 3]; 0] = [];
    assert_eq!(result.to_vec2::<u32>()?, &empty);
    Ok(())
}

@@ -1,5 +1,6 @@
mod test_utils;
use candle::{Device, IndexOp, Result, Tensor};
use candle_core as candle;

fn contiguous(device: &Device) -> Result<()> {
    let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;
@@ -1,10 +1,9 @@
mod test_utils;
-use candle::{DType, Device, IndexOp, Result, Tensor};
-use test_utils::to_vec3_round;
use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn zeros(device: &Device) -> Result<()> {
    let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
-    let (dim1, dim2) = tensor.shape().r2()?;
    let (dim1, dim2) = tensor.dims2()?;
    assert_eq!(dim1, 5);
    assert_eq!(dim2, 2);
    Ok(())
@@ -12,7 +11,7 @@ fn zeros(device: &Device) -> Result<()> {

fn add_mul(device: &Device) -> Result<()> {
    let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
-    let dim1 = tensor.shape().r1()?;
    let dim1 = tensor.dims1()?;
    assert_eq!(dim1, 3);
    let content: Vec<f32> = tensor.to_vec1()?;
    assert_eq!(content, [3., 1., 4.]);
@@ -28,7 +27,7 @@ fn add_mul(device: &Device) -> Result<()> {
fn tensor_2d(device: &Device) -> Result<()> {
    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
    let tensor = Tensor::new(data, device)?;
-    let dims = tensor.shape().r2()?;
    let dims = tensor.dims2()?;
    assert_eq!(dims, (2, 5));
    let content: Vec<Vec<f32>> = tensor.to_vec2()?;
    assert_eq!(content, data);
@@ -41,7 +40,7 @@ fn binary_op(device: &Device) -> Result<()> {
    let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
    let tensor2 = Tensor::new(data2, device)?;
    let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
-    let dims = tensor.shape().r2()?;
    let dims = tensor.dims2()?;
    assert_eq!(dims, (2, 5));
    let content: Vec<Vec<f32>> = tensor.to_vec2()?;
    assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
@@ -56,7 +55,7 @@ fn binary_op(device: &Device) -> Result<()> {
fn transpose(device: &Device) -> Result<()> {
    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
    let tensor = Tensor::new(data, device)?.t()?;
-    let dims = tensor.shape().r2()?;
    let dims = tensor.dims2()?;
    assert_eq!(dims, (5, 2));
    assert_eq!(
        tensor.to_vec2::<f32>()?,
@@ -68,42 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
    Ok(())
}

-fn softmax(device: &Device) -> Result<()> {
-    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
-    let tensor = Tensor::new(data, device)?;
-    let t0 = tensor.log()?.softmax(0)?;
-    let t1 = tensor.log()?.softmax(1)?;
-    let t2 = tensor.log()?.softmax(2)?;
-    assert_eq!(
-        to_vec3_round(t0, 4)?,
-        &[
-            // 3/5, 1/2, 4/11
-            [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
-            // 2/5, 1/2, 7/11
-            [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
-        ]
-    );
-    assert_eq!(
-        to_vec3_round(t1, 4)?,
-        &[
-            // 3/4, 1/6, 4/13
-            [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
-            // 2/10, 1/3, 7/15
-            [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
-        ]
-    );
-    assert_eq!(
-        to_vec3_round(t2, 4)?,
-        &[
-            // (3, 1, 4) / 8, (1, 5, 9) / 15
-            [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
-            // (2, 1, 7) / 10, (8, 2, 8) / 18
-            [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
-        ]
-    );
-    Ok(())
-}
-
fn sum(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
@@ -201,6 +164,278 @@ fn sum(device: &Device) -> Result<()> {
    Ok(())
}

fn min(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.min_keepdim(2)?.to_vec3::<u32>()?,
        &[[[1], [1]], [[1], [2]]]
    );
    assert_eq!(
        tensor.min_keepdim(0)?.to_vec3::<u32>()?,
        &[[[2, 1, 4], [1, 2, 8]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(
        tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(
        tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
        &[[200]]
    );
    assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .min_keepdim(0)?
                .min_keepdim(2)?
                .min_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[200]]]
        );
        assert_eq!(
            tensor.min_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [200, 201, 202, 203],
                [204, 205, 206, 207],
                [208, 209, 210, 211],
                [212, 213, 214, 215],
                [216, 217, 218, 219]
            ]]
        );
    }
    Ok(())
}
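
The _keepdim reductions above collapse the reduced dimension to size 1 instead of dropping it, which is why the expected values keep their nesting depth; a tiny sketch of the semantics the test relies on (illustrative, mirroring the API used in the test):

use candle_core::{Device, Tensor};

fn keepdim_demo() -> candle_core::Result<()> {
    let t = Tensor::new(&[[3u32, 1, 4], [1, 5, 9]], &Device::Cpu)?; // shape (2, 3)
    let m = t.min_keepdim(1)?; // shape (2, 1) rather than (2,)
    assert_eq!(m.to_vec2::<u32>()?, &[[1], [1]]);
    Ok(())
}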

fn max(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.max_keepdim(2)?.to_vec3::<u32>()?,
        &[[[4], [9]], [[7], [8]]]
    );
    assert_eq!(
        tensor.max_keepdim(0)?.to_vec3::<u32>()?,
        &[[[3, 1, 7], [8, 5, 9]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(
        tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(
        tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
        &[[3999]]
    );
    assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .max_keepdim(0)?
                .max_keepdim(2)?
                .max_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[3999]]]
        );
        assert_eq!(
            tensor.max_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [3980, 3981, 3982, 3983],
                [3984, 3985, 3986, 3987],
                [3988, 3989, 3990, 3991],
                [3992, 3993, 3994, 3995],
                [3996, 3997, 3998, 3999]
            ]]
        );
    }
    Ok(())
}

fn argmin(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.argmin_keepdim(2)?.to_vec3::<u32>()?,
        &[[[1], [0]], [[1], [1]]]
    );
    assert_eq!(
        tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
        &[[[1, 0, 0], [0, 1, 1]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor
            .argmin_keepdim(0)?
            .argmin_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmin_keepdim(1)?
            .argmin_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor
            .argmin_keepdim(0)?
            .argmin_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmin_keepdim(1)?
            .argmin_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .argmin_keepdim(0)?
                .argmin_keepdim(2)?
                .argmin_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[0]]]
        );
        assert_eq!(
            tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0],
            ]]
        );
    }
    Ok(())
}

fn argmax(device: &Device) -> Result<()> {
    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        tensor.argmax_keepdim(2)?.to_vec3::<u32>()?,
        &[[[2], [2]], [[2], [0]]]
    );
    assert_eq!(
        tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
        &[[[0, 0, 1], [1, 0, 0]]],
    );
    let data: Vec<u32> = (200..4000u32).collect();
    let tensor = Tensor::new(data.as_slice(), device)?;
    assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]);
    let tensor = tensor.reshape((1900, 2))?;
    assert_eq!(
        tensor
            .argmax_keepdim(0)?
            .argmax_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmax_keepdim(1)?
            .argmax_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);

    // Make the tensor non contiguous.
    let tensor = tensor.t()?.contiguous()?.t()?;
    assert_eq!(
        tensor
            .argmax_keepdim(0)?
            .argmax_keepdim(1)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(
        tensor
            .argmax_keepdim(1)?
            .argmax_keepdim(0)?
            .to_vec2::<u32>()?,
        &[[0]]
    );
    assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);

    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
        assert_eq!(
            tensor
                .argmax_keepdim(0)?
                .argmax_keepdim(2)?
                .argmax_keepdim(1)?
                .to_vec3::<u32>()?,
            &[[[0]]]
        );
        assert_eq!(
            tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
            &[[
                [189, 189, 189, 189],
                [189, 189, 189, 189],
                [189, 189, 189, 189],
                [189, 189, 189, 189],
                [189, 189, 189, 189],
            ]]
        );
    }
    Ok(())
}

fn narrow(device: &Device) -> Result<()> {
    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
    let tensor = Tensor::new(data, device)?;
@@ -270,7 +505,11 @@ fn cat(device: &Device) -> Result<()> {
            [2.0, 7.0, 1.0, 8.0, 2.0]
        ]
    );
-    // TODO: This is not the expected answer, to be fixed!
    // PyTorch equivalent:
    // import torch
    // t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]])
    // t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]])
    // torch.cat([t1.t(), t2.t()], dim=1).t()
    assert_eq!(
        Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
            .t()?
@@ -282,7 +521,6 @@ fn cat(device: &Device) -> Result<()> {
            [2.0, 7.0, 1.0, 8.0, 2.0]
        ]
    );
-    // TODO: This is not the expected answer, to be fixed!
    assert_eq!(
        Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
        [
@@ -296,8 +534,167 @@ fn cat(device: &Device) -> Result<()> {
fn embeddings(device: &Device) -> Result<()> {
    let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
    let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
-    let hs = Tensor::embedding(&ids, &t)?;
    let hs = t.embedding(&ids)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
    let hs = t.index_select(&ids, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
    Ok(())
}

fn cmp(device: &Device) -> Result<()> {
    let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
    let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
    assert_eq!(t1.eq(&t2)?.to_vec2::<u8>()?, &[[0, 0], [0, 1], [1, 0]]);
    assert_eq!(t1.ne(&t2)?.to_vec2::<u8>()?, &[[1, 1], [1, 0], [0, 1]]);
    assert_eq!(t1.le(&t2)?.to_vec2::<u8>()?, &[[1, 0], [1, 1], [1, 1]]);
    assert_eq!(t1.lt(&t2)?.to_vec2::<u8>()?, &[[1, 0], [1, 0], [0, 1]]);
    assert_eq!(t1.gt(&t2)?.to_vec2::<u8>()?, &[[0, 1], [0, 0], [0, 0]]);
    assert_eq!(t1.ge(&t2)?.to_vec2::<u8>()?, &[[0, 1], [0, 1], [1, 0]]);
    Ok(())
}

fn index_select(device: &Device) -> Result<()> {
    let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );
    let hs = t.index_select(&ids, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [0.0, 2.0, 1.0],
            [3.0, 5.0, 4.0],
            [6.0, 8.0, 7.0],
            [9.0, 11.0, 10.0]
        ]
    );
    let hs = t.index_select(&ids, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
    );
    Ok(())
}

fn index_add(device: &Device) -> Result<()> {
    let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?;
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );
    let init = Tensor::ones((4, 2), DType::F32, device)?;
    let hs = init.index_add(&ids, &t, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[[1.0, 4.0], [4.0, 10.0], [7.0, 16.0], [10.0, 22.0]],
    );
    let init = Tensor::zeros((4, 2), DType::F32, device)?;
    let ids = Tensor::new(&[1u32, 0u32, 0u32], device)?;
    let hs = init.index_add(&ids, &t, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[[3.0, 0.0], [9.0, 3.0], [15.0, 6.0], [21.0, 9.0]],
    );

    let init = Tensor::zeros((6, 3), DType::F32, device)?;
    let ids = Tensor::new(&[5u32, 0u32, 1u32, 0u32], device)?;
    let hs = init.index_add(&ids, &t, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [12.0, 14.0, 16.0],
            [6.0, 7.0, 8.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 1.0, 2.0]
        ]
    );
    Ok(())
}

fn scatter_add(device: &Device) -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );
    let ids = Tensor::new(&[[0u32, 1, 2], [3, 4, 0], [3, 3, 1], [2, 0, 4]], device)?;
    let init = Tensor::ones((4, 5), DType::F32, device)?;
    let hs = init.scatter_add(&ids, &t, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [1.0, 2.0, 3.0, 1.0, 1.0],
            [6.0, 1.0, 1.0, 4.0, 5.0],
            [1.0, 9.0, 1.0, 14.0, 1.0],
            [11.0, 1.0, 10.0, 1.0, 12.0]
        ]
    );

    let init = Tensor::ones((6, 3), DType::F32, device)?;
    let hs = init.scatter_add(&ids, &t, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [1.0, 11.0, 6.0],
            [1.0, 2.0, 9.0],
            [10.0, 1.0, 3.0],
            [10.0, 8.0, 1.0],
            [1.0, 5.0, 12.0],
            [1.0, 1.0, 1.0]
        ]
    );
    Ok(())
}

fn gather(device: &Device) -> Result<()> {
    let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );
    let hs = t.gather(&ids, 1)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0], [5.0], [7.0], [9.0]]);
    let ids = Tensor::new(
        &[[0u32, 0u32], [2u32, 0u32], [1u32, 1u32], [0u32, 2u32]],
        device,
    )?;
    let hs = t.gather(&ids, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[[0.0, 0.0], [5.0, 3.0], [7.0, 7.0], [9.0, 11.0]]
    );
    let ids = Tensor::new(&[[0u32, 2u32, 0u32]], device)?;
    let hs = t.gather(&ids, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0]]);
    let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
    let hs = t.gather(&ids, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
    Ok(())
}
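
The rule these gather assertions encode, written out in plain Rust for the dim-1 case (a hypothetical helper for illustration, not part of the diff; dim 0 swaps which index comes from ids):

// out[i][j] = t[i][ids[i][j]]
fn gather_dim1(t: &[Vec<f32>], ids: &[Vec<usize>]) -> Vec<Vec<f32>> {
    ids.iter()
        .enumerate()
        .map(|(i, row)| row.iter().map(|&j| t[i][j]).collect())
        .collect()
}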

@@ -458,9 +855,17 @@ test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(softmax, softmax_cpu, softmax_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
@@ -1,6 +1,6 @@
#![allow(dead_code)]

-use candle::{Result, Tensor};
use candle_core::{Result, Tensor};

#[macro_export]
macro_rules! test_device {
@@ -20,6 +20,23 @@ macro_rules! test_device {
    };
}

pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
    let b = 10f32.powi(digits);
    let t = t.to_vec1::<f32>()?;
    let t = t.iter().map(|t| f32::round(t * b) / b).collect();
    Ok(t)
}

pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
    let b = 10f32.powi(digits);
    let t = t.to_vec2::<f32>()?;
    let t = t
        .iter()
        .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
        .collect();
    Ok(t)
}

pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
    let b = 10f32.powi(digits);
    let t = t.to_vec3::<f32>()?;
@@ -1,28 +1,33 @@
[package]
name = "candle-examples"
-version = "0.1.0"
-edition = "2021"
-
-description = "Examples for the candle ML framework."
-repository = "https://github.com/LaurentMazare/candle"
-keywords = ["blas", "tensor", "machine-learning"]
-categories = ["science"]
-license = "MIT/Apache-2.0"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
-candle = { path = "../candle-core" }
-candle-nn = { path = "../candle-nn" }
-candle-transformers = { path = "../candle-transformers" }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
-candle-hub = { path = "../candle-hub" }
byteorder = { workspace = true }
clap = { workspace = true }
hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
@@ -30,7 +35,16 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]

[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
candle-examples/README.md (new file, 1 line)
@@ -0,0 +1 @@
# candle-examples

candle-examples/build.rs (new file, 238 lines)
@@ -0,0 +1,238 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::io::Write;
use std::path::PathBuf;

struct KernelDirectories {
    kernel_dir: &'static str,
    rust_target: &'static str,
    include_dirs: &'static [&'static str],
}

const DIRS: [KernelDirectories; 1] = [KernelDirectories {
    kernel_dir: "examples/custom-ops/kernels/",
    rust_target: "examples/custom-ops/cuda_kernels.rs",
    include_dirs: &[],
}];

impl KernelDirectories {
    fn maybe_build_ptx(
        &self,
        cu_file: &std::path::Path,
        ptx_file: &std::path::Path,
        compute_cap: usize,
    ) -> Result<()> {
        let should_compile = if ptx_file.exists() {
            let ptx_modified = ptx_file.metadata()?.modified()?;
            let cu_modified = cu_file.metadata()?.modified()?;
            cu_modified.duration_since(ptx_modified).is_ok()
        } else {
            true
        };
        if should_compile {
            #[cfg(feature = "cuda")]
            {
                let mut command = std::process::Command::new("nvcc");
                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
                let include_dirs: Vec<String> =
                    self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
                command
                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
                    .arg("--ptx")
                    .args(["--default-stream", "per-thread"])
                    .args(["--output-directory", out_dir.to_str().unwrap()])
                    .arg(format!("-I/{}", self.kernel_dir))
                    .args(include_dirs)
                    .arg(cu_file);
                let output = command
                    .spawn()
                    .context("failed spawning nvcc")?
                    .wait_with_output()?;
                if !output.status.success() {
                    anyhow::bail!(
                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                        String::from_utf8_lossy(&output.stdout),
                        String::from_utf8_lossy(&output.stderr)
                    )
                }
            }
            #[cfg(not(feature = "cuda"))]
            std::fs::OpenOptions::new()
                .create(true)
                .write(true)
                .open(ptx_file)?;
        }
        Ok(())
    }
    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
        println!("cargo:rerun-if-changed={}", self.kernel_dir);
        let kernel_dir = PathBuf::from(self.kernel_dir);
        let out_dir = out_dir.join(self.kernel_dir);
        if !out_dir.exists() {
            std::fs::create_dir_all(&out_dir)?;
        }
        let mut cu_files = vec![];
        let mut cuh_files = vec![];
        for file in std::fs::read_dir(kernel_dir)?.flatten() {
            let file = file.path();
            match file.extension().and_then(|v| v.to_str()) {
                Some("cu") => cu_files.push(file),
                Some("cuh") => cuh_files.push(file),
                _ => {}
            }
        }

        let mut ptx_paths = vec![];
        for cu_file in cu_files.iter() {
            let file_stem = cu_file
                .file_stem()
                .with_context(|| format!("no stem {cu_file:?}"))?;
            let file_stem = file_stem.to_string_lossy().into_owned();
            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
            ptx_paths.push(ptx_file);
        }

        let regenerate_rs_file = true;
        if regenerate_rs_file {
            let mut file = std::fs::File::create(self.rust_target)?;
            for ptx_path in ptx_paths {
                let name = ptx_path
                    .file_stem()
                    .context("empty stem")?
                    .to_string_lossy();
                file.write_all(b"#[rustfmt::skip]\n")?;
                let const_definition = format!(
                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
                    name.to_uppercase().replace('.', "_"),
                    self.kernel_dir,
                );
                file.write_all(const_definition.as_bytes())?;
                file.write_all(b"\n")?;
            }
        }
        Ok(())
    }
}

fn main() -> Result<()> {
    println!("cargo:rerun-if-changed=build.rs");

    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
    let out_dir = PathBuf::from(out_dir);
    #[cfg(feature = "cuda")]
    set_cuda_include_dir()?;
    #[cfg(feature = "cuda")]
    let compute_cap = compute_cap()?;
    #[cfg(not(feature = "cuda"))]
    let compute_cap = 0;
    for d in DIRS {
        d.process(&out_dir, compute_cap)?
    }
    Ok(())
}

fn set_cuda_include_dir() -> Result<()> {
    // NOTE: copied from cudarc build.rs.
    let env_vars = [
        "CUDA_PATH",
        "CUDA_ROOT",
        "CUDA_TOOLKIT_ROOT_DIR",
        "CUDNN_LIB",
    ];
    let env_vars = env_vars
        .into_iter()
        .map(std::env::var)
        .filter_map(Result::ok)
        .map(Into::<PathBuf>::into);

    let roots = [
        "/usr",
        "/usr/local/cuda",
        "/opt/cuda",
        "/usr/lib/cuda",
        "C:/Program Files/NVIDIA GPU Computing Toolkit",
        "C:/CUDA",
    ];
    let roots = roots.into_iter().map(Into::<PathBuf>::into);
    let root = env_vars
        .chain(roots)
        .find(|path| path.join("include").join("cuda.h").is_file())
        .context("cannot find include/cuda.h")?;
    println!(
        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
        root.join("include").display()
    );
    Ok(())
}

#[allow(unused)]
fn compute_cap() -> Result<usize> {
    // Grab compute code from nvidia-smi
    let mut compute_cap = {
        let out = std::process::Command::new("nvidia-smi")
            .arg("--query-gpu=compute_cap")
            .arg("--format=csv")
            .output()
            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
        let mut lines = out.lines();
        assert_eq!(
            lines.next().context("missing line in stdout")?,
            "compute_cap"
        );
        let cap = lines
            .next()
            .context("missing line in stdout")?
            .replace('.', "");
        cap.parse::<usize>()
            .with_context(|| format!("cannot parse as int {cap}"))?
    };

    // Grab available GPU codes from nvcc and select the highest one
    let max_nvcc_code = {
        let out = std::process::Command::new("nvcc")
            .arg("--list-gpu-code")
            .output()
            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
        let out = std::str::from_utf8(&out.stdout).unwrap();

        let out = out.lines().collect::<Vec<&str>>();
        let mut codes = Vec::with_capacity(out.len());
        for code in out {
            let code = code.split('_').collect::<Vec<&str>>();
            if !code.is_empty() && code.contains(&"sm") {
                if let Ok(num) = code[1].parse::<usize>() {
                    codes.push(num);
                }
            }
        }
        codes.sort();
        if !codes.contains(&compute_cap) {
            anyhow::bail!(
                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
            );
        }
        *codes.last().unwrap()
    };

    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
    // then choose the highest gpu code in nvcc
    if compute_cap > max_nvcc_code {
        println!(
            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
        );
        compute_cap = max_nvcc_code;
    }

    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");

    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
        compute_cap = compute_cap_str
            .parse::<usize>()
            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
    }
    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
    Ok(compute_cap)
}
@@ -4,9 +4,9 @@ mod model;

use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
-use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};

@@ -69,10 +69,11 @@ impl Args {
            )
        } else {
            let api = Api::new()?;
            let api = api.repo(repo);
            (
-                api.get(&repo, "config.json")?,
-                api.get(&repo, "tokenizer.json")?,
-                api.get(&repo, "model.safetensors")?,
                api.get("config.json")?,
                api.get("tokenizer.json")?,
                api.get("model.safetensors")?,
            )
        };
        let config = std::fs::read_to_string(config_filename)?;
@@ -161,7 +162,7 @@ fn main() -> Result<()> {
    let embeddings = model.forward(&token_ids, &token_type_ids)?;
    println!("generated embeddings {:?}", embeddings.shape());
    // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-    let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
    println!("pooled embeddings {:?}", embeddings.shape());
    let mut similarities = vec![];
@@ -87,7 +87,7 @@ impl LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
         let x = x.to_dtype(internal_dtype)?;
         let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
@@ -262,7 +262,7 @@ impl BertEmbeddings {

     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (_bsize, seq_len) = input_ids.shape().r2()?;
+        let (_bsize, seq_len) = input_ids.dims2()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (&input_embeddings + token_type_embeddings)?;
@@ -333,7 +333,7 @@ impl BertSelfAttention {
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
-            attention_scores.softmax(candle::D::Minus1)?
+            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
         };
         let attention_probs = self.dropout.forward(&attention_probs)?;
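Taken together, the bert hunks apply one mechanical migration; a before/after sketch with an illustrative tensor `t`:

// before: let probs = t.softmax(candle::D::Minus1)?;
// after:  let probs = candle_nn::ops::softmax(&t, candle::D::Minus1)?;
// before: let (a, b, c) = t.shape().r3()?;
// after:  let (a, b, c) = t.dims3()?;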
candle-examples/examples/bigcode/main.rs (new file, 156 lines)
@@ -0,0 +1,156 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use clap::Parser;

mod model;
use model::{Config, GPTBigCode};

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: GPTBigCode,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
}

impl TextGeneration {
    fn new(
        model: GPTBigCode,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp);
        Self {
            model,
            tokenizer,
            logits_processor,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        print!("{prompt}");
        std::io::stdout().flush()?;
        let mut tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        let mut new_tokens = vec![];
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
                (1, tokens.len().saturating_sub(1))
            } else {
                (tokens.len(), 0)
            };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, past_len)?;
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            new_tokens.push(next_token);
            let token = self
                .tokenizer
                .decode(vec![next_token], true)
                .map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "{sample_len} tokens generated ({:.3} token/s)",
            sample_len as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "bigcode/starcoderbase-1b")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    weight_file: Option<String>,
}

fn main() -> Result<()> {
    let args = Args::parse();

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let filenames = match args.weight_file {
        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
        None => ["model.safetensors"]
            .iter()
            .map(|f| repo.get(f))
            .collect::<std::result::Result<Vec<_>, _>>()?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let weights = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
        .collect::<Result<Vec<_>>>()?;
    let weights = weights
        .iter()
        .map(|f| Ok(f.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
    let config = Config::starcoder_1b();
    let model = GPTBigCode::load(vb, config)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
candle-examples/examples/bigcode/model.rs (new file, 359 lines)
@@ -0,0 +1,359 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};

fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = if bias {
        Some(vb.get(size2, "bias")?)
    } else {
        None
    };
    Ok(Linear::new(weight, bias))
}

fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, hidden_size))
}

fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get(size, "weight")?;
    let bias = vb.get(size, "bias")?;
    Ok(LayerNorm::new(weight, bias, eps))
}

fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (t, t), device)?;
    Ok(mask)
}
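A small sanity-check sketch for `make_causal_mask` (hypothetical test, not part of the commit): position i may attend to every j <= i, giving a lower-triangular matrix of ones.

#[test]
fn causal_mask_is_lower_triangular() -> Result<()> {
    let mask = make_causal_mask(3, &Device::Cpu)?;
    assert_eq!(
        mask.to_vec2::<u8>()?,
        vec![vec![1, 0, 0], vec![1, 1, 0], vec![1, 1, 1]]
    );
    Ok(())
}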
#[derive(Debug)]
pub struct Config {
    pub vocab_size: usize,
    // max_position_embeddings aka n_positions
    pub max_position_embeddings: usize,
    // num_hidden_layers aka n_layer
    pub num_hidden_layers: usize,
    // hidden_size aka n_embd
    pub hidden_size: usize,
    pub layer_norm_epsilon: f64,
    pub n_inner: Option<usize>,
    // num_attention_heads aka n_head
    pub num_attention_heads: usize,
    pub multi_query: bool,
    pub use_cache: bool,
}

impl Config {
    #[allow(dead_code)]
    pub fn starcoder_1b() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 24,
            hidden_size: 2048,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(8192),
            num_attention_heads: 16,
            multi_query: true,
            use_cache: true,
        }
    }

    #[allow(dead_code)]
    pub fn starcoder_3b() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 36,
            hidden_size: 2816,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(11264),
            num_attention_heads: 22,
            multi_query: true,
            use_cache: true,
        }
    }

    #[allow(dead_code)]
    pub fn starcoder_7b() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 42,
            hidden_size: 4096,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(16384),
            num_attention_heads: 32,
            multi_query: true,
            use_cache: true,
        }
    }

    #[allow(dead_code)]
    pub fn starcoder() -> Self {
        Self {
            vocab_size: 49152,
            max_position_embeddings: 8192,
            num_hidden_layers: 40,
            hidden_size: 6144,
            layer_norm_epsilon: 1e-5,
            n_inner: Some(24576),
            num_attention_heads: 48,
            multi_query: true,
            use_cache: true,
        }
    }
}
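Worked dimensions implied by `starcoder_1b` above, as an orientation sketch (arithmetic only):

// head_dim = hidden_size / num_attention_heads = 2048 / 16 = 128.
// With multi_query = true there is a single kv head, so kv_dim = 1 * 128 = 128
// and the fused c_attn projection below has width
// hidden_size + 2 * kv_dim = 2048 + 256 = 2304.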
struct Attention {
    c_attn: Linear,
    c_proj: Linear,
    kv_cache: Option<Tensor>,
    use_cache: bool,
    embed_dim: usize,
    kv_dim: usize,
    num_heads: usize,
    head_dim: usize,
    multi_query: bool,
}

impl Attention {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let head_dim = hidden_size / cfg.num_attention_heads;
        let kv_heads = if cfg.multi_query {
            1
        } else {
            cfg.num_attention_heads
        };
        let kv_dim = kv_heads * head_dim;
        let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?;
        let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
        Ok(Self {
            c_proj,
            c_attn,
            embed_dim: hidden_size,
            kv_cache: None,
            use_cache: cfg.use_cache,
            kv_dim,
            head_dim,
            num_heads: cfg.num_attention_heads,
            multi_query: cfg.multi_query,
        })
    }

    fn attn(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        if query.dtype() != DType::F32 {
            // If we start supporting f16 models, we may need the upcasting scaling bits.
            // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133
            candle::bail!("upcasting is not supported {:?}", query.dtype())
        }
        let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
        let initial_query_shape = query.shape();
        let key_len = key.dim(D::Minus1)?;
        let (query, key, attn_shape, attn_view) = if self.multi_query {
            let (b_sz, query_len, _) = query.dims3()?;
            let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
            let attn_shape = (b_sz, query_len, self.num_heads, key_len);
            let attn_view = (b_sz, query_len * self.num_heads, key_len);
            (query, key.clone(), attn_shape, attn_view)
        } else {
            let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
            let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
            let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
            let attn_shape = (b_sz, self.num_heads, query_len, key_len);
            let attn_view = (b_sz * self.num_heads, query_len, key_len);
            (query, key, attn_shape, attn_view)
        };

        let attn_weights =
            (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
        let attention_mask = attention_mask.broadcast_as(attn_shape)?;
        let mask_value =
            Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
        let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
        let value = value.contiguous()?;
        let attn_output = if self.multi_query {
            attn_weights
                .reshape(attn_view)?
                .matmul(&value)?
                .reshape(initial_query_shape)?
        } else {
            attn_weights.matmul(&value)?
        };
        Ok(attn_output)
    }

    fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let qkv = self.c_attn.forward(hidden_states)?;
        let (query, key_value) = if self.multi_query {
            let query = qkv.i((.., .., ..self.embed_dim))?;
            let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?;
            (query, key_value)
        } else {
            let mut dims = qkv.dims().to_vec();
            dims.pop();
            dims.push(self.embed_dim);
            dims.push(self.head_dim * 3);
            let qkv = qkv.reshape(dims)?.transpose(1, 2)?;
            let query = qkv.i((.., .., .., ..self.head_dim))?;
            let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?;
            (query, key_value)
        };
        let mut key_value = key_value;
        if self.use_cache {
            if let Some(kv_cache) = &self.kv_cache {
                // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
                // arbitrarily large sizes.
                key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?;
            }
            self.kv_cache = Some(key_value.clone())
        }

        let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
        let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
        let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
        let attn_output = if self.multi_query {
            attn_output
        } else {
            attn_output
                .transpose(1, 2)?
                .reshape(hidden_states.shape())?
        };
        let attn_output = self.c_proj.forward(&attn_output)?;
        Ok(attn_output)
    }
}
struct Mlp {
    c_fc: Linear,
    c_proj: Linear,
}

impl Mlp {
    fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?;
        let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?;
        Ok(Self { c_fc, c_proj })
    }

    fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?;
        let hidden_states = self.c_proj.forward(&hidden_states)?;
        Ok(hidden_states)
    }
}

// TODO: Add cross-attention?
struct Block {
    ln_1: LayerNorm,
    attn: Attention,
    ln_2: LayerNorm,
    mlp: Mlp,
}

impl Block {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size);
        let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?;
        let attn = Attention::load(vb.pp("attn"), cfg)?;
        let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?;
        let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?;
        Ok(Self {
            ln_1,
            attn,
            ln_2,
            mlp,
        })
    }

    fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let residual = hidden_states;
        let hidden_states = self.ln_1.forward(hidden_states)?;
        let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (&attn_outputs + residual)?;
        let residual = &hidden_states;
        let hidden_states = self.ln_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        let hidden_states = (&hidden_states + residual)?;
        Ok(hidden_states)
    }
}

pub struct GPTBigCode {
    wte: Embedding,
    wpe: Embedding,
    blocks: Vec<Block>,
    ln_f: LayerNorm,
    lm_head: Linear,
    bias: Tensor,
    config: Config,
}

impl GPTBigCode {
    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let vb_t = vb.pp("transformer");
        let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
        let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
        let blocks = (0..cfg.num_hidden_layers)
            .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
            .collect::<Result<Vec<_>>>()?;
        let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
        let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
        let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
        Ok(Self {
            wte,
            wpe,
            blocks,
            lm_head,
            ln_f,
            bias,
            config: cfg,
        })
    }

    pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> {
        let dev = input_ids.device();
        let (b_sz, seq_len) = input_ids.dims2()?;

        let key_len = past_len + seq_len;
        let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?;
        // MQA models: (batch_size, query_length, n_heads, key_length)
        // MHA models: (batch_size, n_heads, query_length, key_length)
        let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
        let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;

        let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;
        let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
        let input_embeds = self.wte.forward(input_ids)?;
        let position_embeds = self.wpe.forward(&position_ids)?;

        let mut hidden_states = (&input_embeds + &position_embeds)?;
        for block in self.blocks.iter_mut() {
            hidden_states = block.forward(&hidden_states, &attention_mask)?;
        }
        let hidden_states = self.ln_f.forward(&hidden_states)?;
        let hidden_states = hidden_states
            .reshape((b_sz, seq_len, self.config.hidden_size))?
            .narrow(1, seq_len - 1, 1)?;
        let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?;
        Ok(logits)
    }
}
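The forward contract above, summarized as a sketch (shapes as used in main.rs):

// input_ids: (batch, seq_len); past_len = number of positions already cached.
// Mask rows past_len..past_len + seq_len are selected from the precomputed
// bias, and since forward narrows to the last position before lm_head, the
// returned logits are (batch, vocab_size) for the final token only.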
candle-examples/examples/custom-ops/cuda_kernels.rs (new file, 2 lines)
@@ -0,0 +1,2 @@
#[rustfmt::skip]
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu (new file, 35 lines)
@@ -0,0 +1,35 @@
#include <stdint.h>
#include "reduction_utils.cuh"

template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
                const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
                const float epsilon, const uint32_t num_tokens,
                const uint32_t hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
  }
}
extern "C" __global__ void rms_f32(
    float *__restrict__ out, // [num_tokens, hidden_size]
    const float *__restrict__ input, // [num_tokens, hidden_size]
    const float epsilon, const uint32_t num_tokens,
    const uint32_t hidden_size) {
  rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}
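A scalar Rust reference for what `rms_norm_kernel` computes per token row, as a sketch (it mirrors the CPU path in main.rs further below):

// out[i] = x[i] / sqrt(mean(x^2) + eps), computed independently per row.
fn rms_norm_row(x: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let s = 1f32 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * s).collect()
}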
candle-examples/examples/custom-ops/kernels/reduction_utils.cuh (new file, 46 lines)
@@ -0,0 +1,46 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

template <typename T> __inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template <typename T> __inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
  // blockDim.x is not divided by 32
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}
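A sketch of the butterfly pattern `warpReduceSum` relies on, modeled in Rust over a 32-lane array (illustrative only):

// Each step sums lane i with lane i ^ mask; after masks 16, 8, 4, 2, 1 all
// 32 lanes hold the full sum, matching __shfl_xor_sync semantics.
fn shuffle_xor_reduce(mut lanes: [f32; 32]) -> [f32; 32] {
    let mut mask = 16usize;
    while mask > 0 {
        let prev = lanes;
        for i in 0..32 {
            lanes[i] = prev[i] + prev[i ^ mask];
        }
        mask >>= 1;
    }
    lanes
}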
candle-examples/examples/custom-ops/main.rs (new file, 95 lines)
@@ -0,0 +1,95 @@
// This example illustrates how to implement custom operations. These operations can provide their
// own forward pass (CPU and GPU versions) as well as their backward pass.
//
// In this example we add the RMS normalization operation and implement it for f32.
#![allow(dead_code)]
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod cuda_kernels;

use clap::Parser;

use candle::backend::BackendStorage;
use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

struct LayerNorm {
    eps: f32,
}

impl CustomOp1 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let (dim1, dim2) = layout.shape().dims2()?;
        let slice = storage.as_slice::<f32>()?;
        let src = match layout.contiguous_offsets() {
            None => candle::bail!("input has to be contiguous"),
            Some((o1, o2)) => &slice[o1..o2],
        };
        let mut dst = Vec::with_capacity(dim1 * dim2);
        for idx1 in 0..dim1 {
            let src = &src[idx1 * dim2..(idx1 + 1) * dim2];
            let variance = src.iter().map(|x| x * x).sum::<f32>();
            let s_variance = 1f32 / (variance / dim2 as f32 + self.eps).sqrt();
            dst.extend(src.iter().map(|x| x * s_variance))
        }
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, layout.shape().clone()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::{cudarc, WrapErr};
        use cudarc::driver::{LaunchAsync, LaunchConfig};
        let (d1, d2) = layout.shape().dims2()?;
        let d1 = d1 as u32;
        let d2 = d2 as u32;
        let dev = storage.device().clone();
        let slice = storage.as_cuda_slice::<f32>()?;
        let slice = match layout.contiguous_offsets() {
            None => candle::bail!("input has to be contiguous"),
            Some((o1, o2)) => slice.slice(o1..o2),
        };
        let elem_count = layout.shape().elem_count();
        let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
        let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
        let params = (&dst, &slice, self.eps, d1, d2);
        let cfg = LaunchConfig {
            grid_dim: (d1, 1, 1),
            block_dim: (d2, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe { func.launch(cfg, params) }.w()?;

        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
        Ok((dst, layout.shape().clone()))
    }
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
    println!("{t}");
    let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
    println!("{t}");
    Ok(())
}
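A note on the launch shape chosen in `cuda_fwd` above, as a sketch:

// grid_dim = (d1, 1, 1): one block per row, matching the blockIdx.x indexing
// in rms_norm_kernel; block_dim = (d2, 1, 1): one thread per element, which
// assumes d2 stays within the device limit (typically 1024 threads/block).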
@@ -5,10 +5,10 @@ extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
-use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;

 mod model;
@@ -123,14 +123,18 @@ fn main() -> Result<()> {
     let device = candle_examples::device(args.cpu)?;
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = repo.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = repo.get(rfilename)?;
         filenames.push(filename);
     }
     println!("retrieved the files in {:?}", start.elapsed());
@@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
         key: &Tensor,
         past_kv_len: usize,
     ) -> Result<(Tensor, Tensor)> {
-        let (_batch, seq_len, _head_dim) = query.shape().r3()?;
+        let (_batch, seq_len, _head_dim) = query.dims3()?;
         let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
         let cos = cos.narrow(0, past_kv_len, seq_len)?;
         let sin = sin.narrow(0, past_kv_len, seq_len)?;
@@ -245,7 +245,7 @@ impl FalconAttention {
     }

     fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
-        let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
+        let (b_sz, seq_len, _) = fused_qkv.dims3()?;
         if !self.multi_query {
             let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
             let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
@@ -267,7 +267,7 @@ impl FalconAttention {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
-        let (b_sz, seq_len, _, _) = query.shape().r4()?;
+        let (b_sz, seq_len, _, _) = query.dims4()?;
         let query = query
             .transpose(1, 2)?
             .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
@@ -309,11 +309,13 @@ impl FalconAttention {

         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = attention_scores
-            .broadcast_add(&mask.squeeze(1)?)?
-            .to_dtype(DType::F32)?
-            .softmax(D::Minus1)?
-            .to_dtype(x.dtype())?;
+        let attention_scores = candle_nn::ops::softmax(
+            &attention_scores
+                .broadcast_add(&mask.squeeze(1)?)?
+                .to_dtype(DType::F32)?,
+            D::Minus1,
+        )?
+        .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
@@ -422,7 +424,7 @@ pub struct Falcon {

 fn make_causal_mask(t: usize) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
-        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
         .collect();
     let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
     Ok(mask)
@@ -465,7 +467,7 @@ impl Falcon {
     }

     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.shape().r2()?;
+        let (b_sz, seq_len) = input_ids.dims2()?;
         let mut hidden_state = self.word_embeddings.forward(input_ids)?;
         let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
             Some((k, _)) => k.dim(1)?,
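The falcon softmax change above follows the same upcast-for-stability shape used elsewhere in this diff; schematically:

// scores.to_dtype(F32) -> candle_nn::ops::softmax(.., D::Minus1) ->
// .to_dtype(x.dtype()): the reduction runs in f32 even for f16/bf16 models.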
@@ -15,10 +15,10 @@ extern crate intel_mkl_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;

-use candle::{DType, Device, Tensor, D};
-use candle_hub::{api::sync::Api, Repo, RepoType};
+use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
+use hf_hub::api::sync::Api;

 mod model;
 use model::{Config, Llama};
@@ -76,23 +76,6 @@ Whate'er it bodes, henceforward will I bear
 Upon my target three fair-shining suns.
 ";

-fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
-    let n_elem = config.n_embd / config.n_head;
-    let theta: Vec<_> = (0..n_elem)
-        .step_by(2)
-        .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
-        .collect();
-    let theta = Tensor::new(theta.as_slice(), device)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
-        .to_dtype(DType::F32)?
-        .reshape((MAX_SEQ_LEN, 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
-    let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
-    let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
-    Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?)
-}
-
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -127,17 +110,44 @@ struct Args {
     /// Use f32 computations rather than f16.
     #[arg(long)]
     use_f32: bool,

+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    v1: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }

 fn main() -> Result<()> {
     use tokenizers::Tokenizer;
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;

     let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let device = candle_examples::device(args.cpu)?;
-    let config = Config::config_7b();
-    let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
+    let config = if args.v1 {
+        Config::config_7b_v1(args.use_flash_attn)
+    } else {
+        Config::config_7b_v2(args.use_flash_attn)
+    };
+    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
             let vb = VarBuilder::from_npz(filename, dtype, &device)?;
@@ -146,15 +156,22 @@ fn main() -> Result<()> {
         }
         None => {
             let api = Api::new()?;
-            let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
-            println!("loading the model weights");
-            let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+            let model_id = args.model_id.unwrap_or_else(|| {
+                if args.v1 {
+                    "Narsil/amall-7b".to_string()
+                } else {
+                    "meta-llama/Llama-2-7b-hf".to_string()
+                }
+            });
+            println!("loading the model weights from {model_id}");
+            let api = api.model(model_id);
+            let tokenizer_filename = api.get("tokenizer.json")?;
             let mut filenames = vec![];
             for rfilename in [
                 "model-00001-of-00002.safetensors",
                 "model-00002-of-00002.safetensors",
             ] {
-                let filename = api.get(&repo, rfilename)?;
+                let filename = api.get(rfilename)?;
                 filenames.push(filename);
             }
@@ -180,8 +197,6 @@ fn main() -> Result<()> {
         .get_ids()
         .to_vec();

-    println!("pre-computing the positional embeddings");
-    let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
     let mut new_tokens = vec![];
@@ -196,12 +211,7 @@ fn main() -> Result<()> {
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let freqs_cis = if cache.use_kv_cache {
-            freqs_cis.narrow(1, index_pos, ctxt.len())?
-        } else {
-            freqs_cis.clone()
-        };
-        let logits = llama.forward(&input, &freqs_cis)?;
+        let logits = llama.forward(&input, index_pos)?;
         let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
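A sketch of the `index_pos` bookkeeping the loop above now relies on instead of slicing `freqs_cis`:

// With the kv cache enabled, only new tokens are fed each step and index_pos
// records how many positions are already cached:
//   step 0: ctxt = full prompt,  index_pos = 0
//   step n: ctxt = [last_token], index_pos = prompt_len + n - 1
// so the rotary embeddings stay aligned with absolute positions.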
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -12,10 +12,13 @@ pub struct Config {
     pub n_layer: usize,
     pub n_head: usize,
     pub n_embd: usize,
+    pub n_key_value_head: usize,
+    pub use_flash_attn: bool,
+    pub rms_norm_eps: f64,
 }

 impl Config {
-    pub fn config_7b() -> Self {
+    pub fn config_7b_v1(use_flash_attn: bool) -> Self {
         Self {
             hidden_size: 4096,
             intermediate_size: 11008,
@@ -23,8 +26,40 @@ impl Config {
             n_layer: 32,
             n_head: 32,
             n_embd: 4096,
+            n_key_value_head: 32,
+            use_flash_attn,
+            rms_norm_eps: 1e-6,
+        }
+    }
+
+    pub fn config_7b_v2(use_flash_attn: bool) -> Self {
+        Self {
+            hidden_size: 4096,
+            intermediate_size: 11008,
+            vocab_size: 32000,
+            n_layer: 32,
+            n_head: 32,
+            n_embd: 4096,
+            n_key_value_head: 32,
+            use_flash_attn,
+            rms_norm_eps: 1e-5,
         }
     }
 }

+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
 #[derive(Clone)]
@@ -33,17 +68,37 @@ pub struct Cache {
     pub use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+    cos: Tensor,
+    sin: Tensor,
     device: Device,
 }

 impl Cache {
-    pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
-        Self {
+    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
+        // precompute freqs_cis
+        let n_elem = config.n_embd / config.n_head;
+        let theta: Vec<_> = (0..n_elem)
+            .step_by(2)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
+            .collect();
+        let theta = Tensor::new(theta.as_slice(), device)?;
+        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+            .to_dtype(DType::F32)?
+            .reshape((MAX_SEQ_LEN, 1))?
+            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+        // This is different from the paper, see:
+        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
+        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
+        Ok(Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
             use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
             device: device.clone(),
-        }
+            cos,
+            sin,
+        })
     }

     fn mask(&self, t: usize) -> Result<Tensor> {
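The frequencies precomputed in `Cache::new` above, written out as a sketch:

// theta_i = 10000^(-2i / n_elem) for i in 0..n_elem/2; the angle at position
// m is m * theta_i. The cat duplicates the angles along the last dim so that
// cos/sin line up with the rotate-half layout used by apply_rotary_emb below.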
@@ -51,9 +106,8 @@ impl Cache {
         if let Some(mask) = masks.get(&t) {
             Ok(mask.clone())
         } else {
-            // TODO: If we support bool or u8 tensors, this would be better.
             let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
@@ -67,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }

 fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
 }

 fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@@ -78,27 +133,27 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {

 struct RmsNorm {
     scale: Tensor,
+    eps: f64,
+    span: tracing::Span,
 }

 impl RmsNorm {
-    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = vb.get(size, "weight")?;
-        Ok(Self::new(scale))
-    }
-
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
+        Ok(Self { scale, eps, span })
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+        let size = self.scale.dims1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
@@ -110,63 +165,69 @@ impl RmsNorm {
 }

 struct CausalSelfAttention {
-    c_attn: Linear,
-    c_proj: Linear,
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
     n_head: usize,
+    n_key_value_head: usize,
+    head_dim: usize,
     cache: Cache,
+    use_flash_attn: bool,
+    span: tracing::Span,
+    span_rot: tracing::Span,
 }

+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 impl CausalSelfAttention {
-    fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
-        Self {
-            c_attn,
-            c_proj,
-            n_head,
-            cache: cache.clone(),
-        }
-    }
-
-    fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let mut dims = x.dims().to_vec();
-        let fcis_dims = freqs_cis.dims();
-        let freqs_cis = if dims[2] < fcis_dims[1] {
-            freqs_cis.narrow(1, 0, dims[2])?
-        } else {
-            freqs_cis.clone()
-        };
-        let v = dims.pop().unwrap();
-        dims.push(v / 2);
-        dims.push(2);
-        let x = x.reshape(dims)?;
-        let re_x = x.narrow(D::Minus1, 0, 1)?;
-        let im_x = x.narrow(D::Minus1, 1, 1)?;
-        let re_f = freqs_cis
-            .narrow(D::Minus1, 0, 1)?
-            .broadcast_as(re_x.shape())?;
-        let im_f = freqs_cis
-            .narrow(D::Minus1, 1, 1)?
-            .broadcast_as(im_x.shape())?;
-        let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
-        let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
-        let rope = rope.flatten_from(D::Minus2)?;
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let _enter = self.span_rot.enter();
+        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
+        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
+        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
+        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
+        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
+        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
+        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
         Ok(rope)
     }

-    fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
-        let x_dtype = x.dtype();
-        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
-        let qkv = self.c_attn.forward(x)?;
-        let qkv = qkv.to_dtype(DType::F32)?;
-        let q = qkv.narrow(D::Minus1, 0, n_embd)?;
-        let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
-        let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
-        let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
-        let q = self.apply_rotary_emb(&q, freqs_cis)?;
-        let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
+    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_sz, seq_len, n_embd) = x.dims3()?;
+        let q = self.q_proj.forward(x)?;
+        let k = self.k_proj.forward(x)?;
+        let v = self.v_proj.forward(x)?;
+
+        let q = q
+            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
+            .transpose(1, 2)?;
+        let mut v = v
+            .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
+            .transpose(1, 2)?;
+
+        let q = self.apply_rotary_emb(&q, index_pos)?;
+        let mut k = self.apply_rotary_emb(&k, index_pos)?;

         if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
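A self-contained sketch of the rotate-half rotary embedding that the new `apply_rotary_emb` applies per row (illustrative helper, not part of the commit):

// rope(x) = x * cos + rotate_half(x) * sin, with rotate_half((x1, x2)) = (-x2, x1).
fn rope_row(x: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
    let h = x.len() / 2;
    (0..x.len())
        .map(|i| {
            let rot = if i < h { -x[i + h] } else { x[i - h] };
            x[i] * cos[i] + rot * sin[i]
        })
        .collect()
}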
@@ -189,39 +250,70 @@ impl CausalSelfAttention {
             cache[block_idx] = Some((k.clone(), v.clone()))
         }

-        let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(D::Minus1)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
+        let k = self.repeat_kv(k)?;
+        let v = self.repeat_kv(v)?;
+
+        let y = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = q.transpose(1, 2)?;
+            let k = k.transpose(1, 2)?;
+            let v = v.transpose(1, 2)?;
+            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
+        } else {
+            let in_dtype = q.dtype();
+            let q = q.to_dtype(DType::F32)?;
+            let k = k.to_dtype(DType::F32)?;
+            let v = v.to_dtype(DType::F32)?;
+            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+            // Convert to contiguous as matmul doesn't support strided vs for now.
+            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
+        };
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = y.to_dtype(x_dtype)?;
-        let y = self.c_proj.forward(&y)?;
+        let y = self.o_proj.forward(&y)?;
         Ok(y)
     }

+    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+        let n_rep = self.n_head / self.n_key_value_head;
+        if n_rep == 1 {
+            Ok(x)
+        } else {
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+            let x = x
+                .unsqueeze(2)?
+                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
+            Ok(x)
+        }
+    }
+
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attn");
+        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let size_in = cfg.hidden_size;
-        let size = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
-        let q_proj = vb.get((size_in, size), "q_proj.weight")?;
-        let k_proj = vb.get((size_in, size), "k_proj.weight")?;
-        let v_proj = vb.get((size_in, size), "v_proj.weight")?;
-        // Invert the transformation from:
-        // https://github.com/huggingface/transformers/blob/2642d8d04b14c18199ebe7b35f976da02df61752/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101
-        let n_head = cfg.n_head;
-        let q_proj = q_proj
-            .reshape((n_head, 2, size / n_head / 2, size_in))?
-            .transpose(1, 2)?
-            .reshape((size_in, size))?;
-        let k_proj = k_proj
-            .reshape((n_head, 2, size / n_head / 2, size_in))?
-            .transpose(1, 2)?
-            .reshape((size_in, size))?;
-        let attn_weight = Tensor::cat(&[q_proj, k_proj, v_proj], 0)?;
-        let c_attn = Linear::new(attn_weight, None);
-        let o_proj = linear(size, size_in, vb.pp("o_proj"))?;
-        Ok(Self::new(c_attn, o_proj, cfg.n_head, cache))
+        let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
+        let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
+        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
+        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
+        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
+        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            n_head: cfg.n_head,
+            n_key_value_head: cfg.n_key_value_head,
+            head_dim: cfg.hidden_size / cfg.n_head,
+            cache: cache.clone(),
+            use_flash_attn: cfg.use_flash_attn,
+            span,
+            span_rot,
+        })
     }
 }
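A worked example for `repeat_kv` above, with hypothetical grouped-query values (the 7b configs in this file use n_key_value_head = n_head, so there it is a no-op):

// n_head = 32, n_key_value_head = 8 => n_rep = 4: a (b, 8, t, d) key tensor
// is unsqueezed and expanded to (b, 8, 4, t, d) so each cached kv head is
// shared by four query heads.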
@@ -236,29 +328,29 @@ struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
     c_proj: Linear,
+    span: tracing::Span,
 }

 impl Mlp {
-    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }

     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
         let h_size = cfg.hidden_size;
         let i_size = cfg.intermediate_size;
         let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
         let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
         let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+            span,
+        })
     }
 }
@@ -267,39 +359,37 @@ struct Block {
     attn: CausalSelfAttention,
     rms_2: RmsNorm,
     mlp: Mlp,
+    span: tracing::Span,
 }

 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
-        Self {
-            rms_1,
-            attn,
-            rms_2,
-            mlp,
-        }
-    }
-
-    fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
-        let x = (self
-            .attn
-            .forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
-            + x)?;
-        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
+    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let residual = x;
+        let x = self.rms_1.forward(x)?;
+        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        let residual = &x;
+        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm =
-            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
-        Ok(Self::new(
-            input_layernorm,
+        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let rms_2 = RmsNorm::load(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            rms_1,
             attn,
-            post_attention_layernorm,
+            rms_2,
             mlp,
-        ))
+            span,
+        })
     }
 }
@@ -311,20 +401,11 @@ pub struct Llama {
 }

 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
-        Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-        }
-    }
-
-    pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.shape().r2()?;
+    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, freqs_cis, block_idx)?;
+            x = block.forward(&x, index_pos, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
         let x = x.i((.., seq_len - 1, ..))?;
@@ -335,11 +416,16 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();

-        Ok(Self::new(wte, blocks, norm, lm_head))
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+        })
     }
 }
candle-examples/examples/llama2-c/main.rs (new file, 289 lines)
@@ -0,0 +1,289 @@
// https://github.com/karpathy/llama2.c

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};

use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;

use model::{Config, Llama};
use weights::TransformerWeights;

#[derive(Parser, Debug, Clone)]
struct InferenceCmd {
    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    #[arg(long, default_value = "")]
    prompt: String,

    /// Config file in binary or safetensors format.
    #[arg(long)]
    config: Option<String>,

    #[arg(long, default_value = "karpathy/tinyllamas")]
    model_id: String,

    /// The model to be used when getting it from the hub. Possible
    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
    /// https://huggingface.co/karpathy/tinyllamas/tree/main
    #[arg(long, default_value = "stories15M.bin")]
    which_model: String,
}

#[derive(Parser, Debug, Clone)]
struct EvaluationCmd {
    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
    /// script from llama2.c https://github.com/karpathy/llama2.c
    #[arg(long)]
    pretokenized_dir: Option<String>,

    #[arg(long, default_value_t = 32)]
    batch_size: usize,

    /// Config file in binary format.
    #[arg(long)]
    config: Option<String>,

    #[arg(long, default_value = "karpathy/tinyllamas")]
    model_id: String,

    /// The model to be used when getting it from the hub. Possible
    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
    /// https://huggingface.co/karpathy/tinyllamas/tree/main
    #[arg(long, default_value = "stories15M.bin")]
    which_model: String,
}

#[derive(Parser, Debug, Clone)]
pub struct TrainingCmd {
    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
    /// script from llama2.c https://github.com/karpathy/llama2.c
    #[arg(long)]
    pretokenized_dir: String,

    #[arg(long, default_value_t = 32)]
    batch_size: usize,

    #[arg(long, default_value_t = 0.001)]
    learning_rate: f64,
}

#[derive(Subcommand, Debug, Clone)]
enum Task {
    Inference(InferenceCmd),
    Eval(EvaluationCmd),
    Train(TrainingCmd),
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// The task to be performed, inference, training or evaluation.
    #[command(subcommand)]
    task: Option<Task>,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Tokenizer config file.
    #[arg(long)]
    tokenizer: Option<String>,
}

impl Args {
    fn tokenizer(&self) -> Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(E::msg)
    }
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    match &args.task {
        None => {
            let cmd = InferenceCmd {
                temperature: None,
                prompt: "".to_string(),
                config: None,
                model_id: "karpathy/tinyllamas".to_string(),
                which_model: "stories15M.bin".to_string(),
            };
            run_inference(&cmd, &args)?
        }
        Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
        Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
        Some(Task::Train(cmd)) => training::run(cmd, &args)?,
    }
    Ok(())
}

fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    use std::io::BufRead;

    let config_path = match &args.config {
        Some(config) => std::path::PathBuf::from(config),
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            println!("loading the model weights from {}", args.model_id);
            let api = api.model(args.model_id.clone());
            api.get(&args.which_model)?
        }
    };

    let tokenizer = common_args.tokenizer()?;

    let device = candle_examples::device(common_args.cpu)?;
    let mut file = std::fs::File::open(config_path)?;
    let config = Config::from_reader(&mut file)?;
    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
    let vb = weights.var_builder(&config, &device)?;
    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;

    let tokens = match &args.pretokenized_dir {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
            println!("loading the evaluation dataset from {}", model_id);
            let api = api.dataset(model_id.to_string());
            let dataset_path = api.get("TinyStories-valid.txt")?;
            let file = std::fs::File::open(dataset_path)?;
let file = std::io::BufReader::new(file);
|
||||
let mut tokens = vec![];
|
||||
for line in file.lines() {
|
||||
let line = line?.replace("<|endoftext|>", "<s>");
|
||||
let line = tokenizer.encode(line, false).map_err(E::msg)?;
|
||||
tokens.push(line.get_ids().to_vec())
|
||||
}
|
||||
tokens.concat()
|
||||
}
|
||||
Some(pretokenized_dir) => {
|
||||
// Use shard 0 for the test split, similar to llama2.c
|
||||
// https://github.com/karpathy/llama2.c/blob/ce05cc28cf1e3560b873bb21837638a434520a67/tinystories.py#L121
|
||||
let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin");
|
||||
let bytes = std::fs::read(path)?;
|
||||
// Tokens are encoded as u16.
|
||||
let mut tokens = vec![0u16; bytes.len() / 2];
|
||||
std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
|
||||
tokens.into_iter().map(|u| u as u32).collect::<Vec<u32>>()
|
||||
}
|
||||
};
|
||||
println!("dataset loaded and encoded: {} tokens", tokens.len());
|
||||
|
||||
let seq_len = model.config.seq_len;
|
||||
let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
|
||||
if start_idx + seq_len + 1 > tokens.len() {
|
||||
None
|
||||
} else {
|
||||
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
|
||||
let inputs = Tensor::new(&tokens[..seq_len], &device);
|
||||
let targets = Tensor::new(&tokens[1..], &device);
|
||||
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
|
||||
}
|
||||
});
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
for inp_tgt in batch_iter {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
println!("{}", loss.to_vec0::<f32>()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let config_path = match &args.config {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
println!("loading the model weights from {}", args.model_id);
|
||||
let api = api.model(args.model_id.clone());
|
||||
api.get(&args.which_model)?
|
||||
}
|
||||
};
|
||||
|
||||
let tokenizer = common_args.tokenizer()?;
|
||||
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
|
||||
let is_safetensors = config_path
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (vb, config) = if is_safetensors {
|
||||
let config = Config::tiny();
|
||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||
(vb, config)
|
||||
} else {
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("{config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
(vb, config)
|
||||
};
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
|
||||
let mut index_pos = 0;
|
||||
|
||||
print!("{}", args.prompt);
|
||||
let mut tokens = tokenizer
|
||||
.encode(args.prompt.clone(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0.. {
|
||||
if tokens.len() >= model.config.seq_len {
|
||||
break;
|
||||
}
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||
// heuristics as it seems to work well enough for this example. See the following for more
|
||||
// details:
|
||||
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
print!("{text}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{} tokens generated ({:.2} token/s)\n",
|
||||
tokens.len(),
|
||||
tokens.len() as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
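The evaluation loop above leans on candle_nn::dataset::Batcher to turn an iterator of Result<(input, target)> pairs into an iterator of stacked batches. A minimal sketch of that pattern in isolation, with toy data standing in for the tokenized dataset (the shapes printed are an assumption about how new_r2 stacks pairs along a fresh batch dimension):

    use candle::{Device, Result, Tensor};

    // Hedged sketch: batch four length-3 (input, target) windows two at a time.
    fn batcher_demo() -> Result<()> {
        let device = Device::Cpu;
        let iter = (0u32..4).map(|i| {
            let inputs = Tensor::new(&[i, i + 1, i + 2], &device)?;
            let targets = Tensor::new(&[i + 1, i + 2, i + 3], &device)?;
            Ok((inputs, targets))
        });
        // With batch_size(2), each yielded item should pair two (2, 3) tensors.
        let batches = candle_nn::dataset::Batcher::new_r2(iter).batch_size(2);
        for batch in batches {
            let (inp, tgt) = batch?;
            println!("{:?} {:?}", inp.shape(), tgt.shape());
        }
        Ok(())
    }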
340
candle-examples/examples/llama2-c/model.rs
Normal file
@ -0,0 +1,340 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Debug, Clone)]
pub struct Config {
    pub dim: usize,        // transformer dimension
    pub hidden_dim: usize, // for ffn layers
    pub n_layers: usize,   // number of layers
    pub n_heads: usize,    // number of query heads
    pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
    pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
    pub seq_len: usize,    // max sequence length
    pub norm_eps: f64,
}

impl Config {
    pub fn tiny() -> Self {
        Self {
            dim: 288,
            hidden_dim: 768,
            n_layers: 6,
            n_heads: 6,
            n_kv_heads: 6,
            vocab_size: 32000,
            seq_len: 256,
            norm_eps: 1e-5,
        }
    }
}

#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    device: Device,
}

impl Cache {
    pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let n_elem = cfg.dim / cfg.n_heads;
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), vb.device())?;
        let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
            .to_dtype(DType::F32)?
            .reshape((cfg.seq_len, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        let precomputed_cos = idx_theta.cos()?;
        let precomputed_sin = idx_theta.sin()?;

        let freq_cis_real = vb
            .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
            .unwrap_or(precomputed_cos);
        let freq_cis_imag = vb
            .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
            .unwrap_or(precomputed_sin);
        let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
        let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
        Ok(Self {
            masks: Arc::new(Mutex::new(HashMap::new())),
            use_kv_cache,
            kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
            cos,
            sin,
            device: vb.device().clone(),
        })
    }

    fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
            masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
}

fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

struct RmsNorm {
    scale: Tensor,
    eps: f64,
}

impl RmsNorm {
    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
        Ok(Self { scale, eps })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
        let size = self.scale.dims1()?;
        let scale = self
            .scale
            .to_dtype(DType::F32)?
            .broadcast_as((b_sz, seq_len, size))?;
        let x = (scale * x_normed)?;
        Ok(x)
    }
}

struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_head: usize,
    n_key_value_head: usize,
    head_dim: usize,
    cache: Cache,
}

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
        let cos = cos.unsqueeze(1)?;
        let sin = sin.unsqueeze(1)?;
        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
        let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
        let x0 = x.narrow(D::Minus1, 0, 1)?;
        let x1 = x.narrow(D::Minus1, 1, 1)?;
        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
        Ok(rope)
    }

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
            }
            cache[block_idx] = Some((k.clone(), v.clone()))
        }

        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;

        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.o_proj.forward(&y)?;
        Ok(y)
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_head / self.n_key_value_head;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(3)?
                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
            Ok(x)
        }
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let size_in = cfg.dim;
        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_head: cfg.n_heads,
            n_key_value_head: cfg.n_kv_heads,
            head_dim: cfg.dim / cfg.n_heads,
            cache: cache.clone(),
        })
    }
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
    c_proj: Linear,
}

impl Mlp {
    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
        Self {
            c_fc1,
            c_fc2,
            c_proj,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let h_size = cfg.dim;
        let i_size = cfg.hidden_dim;
        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
        Ok(Self::new(c_fc1, c_fc2, c_proj))
    }
}

struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
        Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        }
    }

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
        Ok(Self::new(
            input_layernorm,
            attn,
            post_attention_layernorm,
            mlp,
        ))
    }
}

pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
    pub config: Config,
}

impl Llama {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, _seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
            .collect();
        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            config: cfg,
        })
    }
}
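apply_rotary_emb above treats each consecutive pair of head dimensions as a 2-vector and rotates it by a position-dependent angle: for a pair (x0, x1) at position p with frequency theta, the output is (x0*cos(p*theta) - x1*sin(p*theta), x0*sin(p*theta) + x1*cos(p*theta)). A minimal scalar sketch of that same rotation, stripped of the tensor plumbing (the function name is hypothetical, the frequency schedule matches Cache::new above):

    // Hedged sketch: the per-pair rotation performed by apply_rotary_emb,
    // written out on plain f32 values for one (position, frequency index).
    fn rope_pair(x0: f32, x1: f32, pos: usize, i: usize, n_elem: usize) -> (f32, f32) {
        // Same schedule as Cache::new: theta_i = 10000^(-i / n_elem).
        let theta = 1f32 / 10000f32.powf(i as f32 / n_elem as f32);
        let angle = pos as f32 * theta;
        let (sin, cos) = angle.sin_cos();
        (x0 * cos - x1 * sin, x0 * sin + x1 * cos)
    }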
175
candle-examples/examples/llama2-c/training.rs
Normal file
@ -0,0 +1,175 @@
#![allow(dead_code)]
#![allow(unused)]
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result, Tensor};

pub struct Dataset {
    valid_tokens: Vec<memmap2::Mmap>,
    train_tokens: Vec<memmap2::Mmap>,
}

fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
    let file = std::fs::File::open(p)?;
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
    Ok(mmap)
}

impl Dataset {
    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
        let dir = dir.as_ref();
        let mut bin_files = vec![];
        for file in std::fs::read_dir(dir)?.flatten() {
            let file = file.path();
            if let Some(extension) = file.extension() {
                if extension == "bin" {
                    bin_files.push(file)
                }
            }
        }
        if bin_files.len() < 2 {
            candle::bail!("found less than two bin files in {:?}", dir)
        }
        bin_files.sort();
        let valid_tokens = mmap_file(&bin_files[0])?;
        let train_tokens = bin_files[1..]
            .iter()
            .map(mmap_file)
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            valid_tokens: vec![valid_tokens],
            train_tokens,
        })
    }
}

struct DatasetRandomIter<'a> {
    all_tokens: &'a [memmap2::Mmap],
    tokens: Vec<&'a memmap2::Mmap>,
    current_tokens: &'a memmap2::Mmap,
    indexes_in_bytes: Vec<usize>,
    seq_len: usize,
    device: Device,
}

impl<'a> DatasetRandomIter<'a> {
    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let all_tokens = if valid {
            &ds.valid_tokens
        } else {
            &ds.train_tokens
        };
        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
        tokens.shuffle(&mut thread_rng());
        let current_tokens = tokens.pop().unwrap();
        let seq_len_in_bytes = seq_len * 2;
        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
            .step_by(seq_len_in_bytes)
            .collect::<Vec<_>>();
        indexes_in_bytes.shuffle(&mut thread_rng());
        Self {
            all_tokens,
            tokens,
            current_tokens,
            indexes_in_bytes,
            seq_len,
            device,
        }
    }
}

impl<'a> Iterator for DatasetRandomIter<'a> {
    type Item = Result<(Tensor, Tensor)>;

    fn next(&mut self) -> Option<Self::Item> {
        use byteorder::{LittleEndian, ReadBytesExt};
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let seq_len = self.seq_len;
        if self.indexes_in_bytes.is_empty() {
            if self.tokens.is_empty() {
                self.tokens = self.all_tokens.iter().collect();
                self.tokens.shuffle(&mut thread_rng());
            }
            self.current_tokens = self.tokens.pop().unwrap();
            let seq_len_in_bytes = self.seq_len * 2;
            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
                .step_by(seq_len_in_bytes)
                .collect::<Vec<_>>();
            self.indexes_in_bytes.shuffle(&mut thread_rng());
        }
        let start_idx = self.indexes_in_bytes.pop().unwrap();
        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
        let mut tokens = vec![0u16; bytes.len() / 2];
        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
            return Some(Err(err.into()));
        }
        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
        let targets = Tensor::new(&tokens[1..], &self.device);
        Some(candle::error::zip(inputs, targets))
    }
}

fn valid_loss(
    dataset: &Dataset,
    model: &Llama,
    args: &crate::TrainingCmd,
    device: &Device,
) -> Result<f64> {
    let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
    let mut sum_ce = 0f64;
    let mut cnt = 0usize;
    for inp_tgt in batch_iter.take(50) {
        let (inp, tgt) = inp_tgt?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        sum_ce += loss.to_vec0::<f32>()? as f64;
        cnt += 1;
    }
    Ok(sum_ce / cnt as f64)
}

pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    let device = candle_examples::device(common_args.cpu)?;
    let dataset = Dataset::new(&args.pretokenized_dir)?;
    println!(
        "loaded dataset, train: {} files, valid: {} files",
        dataset.train_tokens.len(),
        dataset.valid_tokens.len()
    );
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let config = Config::tiny();
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);

    let cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;
    let params = candle_nn::ParamsAdamW {
        lr: args.learning_rate,
        ..Default::default()
    };
    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
    for (batch_index, batch) in batch_iter.enumerate() {
        let (inp, tgt) = batch?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        opt.backward_step(&loss)?;

        if batch_index > 0 && batch_index % 100 == 0 {
            // TODO: Add a way to deactivate the backprop graph tracking when computing the
            // validation loss.
            let loss = valid_loss(&dataset, &model, args, &device)?;
            println!("{batch_index} {loss}");
        }
        if batch_index > 0 && batch_index % 1000 == 0 {
            varmap.save("checkpoint.safetensors")?
        }
    }
    Ok(())
}
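DatasetRandomIter indexes the memory-mapped shards in bytes: each token is a little-endian u16, so a window of seq_len inputs plus the one-token-shifted targets needs 2 * (seq_len + 1) bytes starting at a multiple of 2 * seq_len. A small sketch of that decoding step on its own (the decode_window helper is hypothetical; `bytes` stands in for a slice of one mmapped shard):

    use byteorder::{LittleEndian, ReadBytesExt};

    // Hedged sketch: decode one training window from raw shard bytes the way
    // the iterator above does.
    fn decode_window(bytes: &[u8], seq_len: usize) -> std::io::Result<(Vec<u32>, Vec<u32>)> {
        assert_eq!(bytes.len(), 2 * (seq_len + 1), "u16 tokens: inputs + shifted targets");
        let mut tokens = vec![0u16; seq_len + 1];
        std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
        let tokens: Vec<u32> = tokens.into_iter().map(u32::from).collect();
        // Inputs are tokens[..seq_len]; targets are the same window shifted by one.
        Ok((tokens[..seq_len].to_vec(), tokens[1..].to_vec()))
    }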
161
candle-examples/examples/llama2-c/weights.rs
Normal file
@ -0,0 +1,161 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle_nn::VarBuilder;

use crate::model::Config;

pub struct TransformerWeights {
    // token embedding table
    token_embedding_table: Tensor, // (vocab_size, dim)
    // weights for rmsnorms
    rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
    rms_ffn_weight: Tensor, // (layer, dim)
    // weights for matmuls
    wq: Tensor, // (layer, dim, dim)
    wk: Tensor, // (layer, dim, dim)
    wv: Tensor, // (layer, dim, dim)
    wo: Tensor, // (layer, dim, dim)
    // weights for ffn
    w1: Tensor, // (layer, hidden_dim, dim)
    w2: Tensor, // (layer, dim, hidden_dim)
    w3: Tensor, // (layer, hidden_dim, dim)
    // final rmsnorm
    rms_final_weight: Tensor, // (dim,)
    // freq_cis for RoPE relatively positional embeddings
    freq_cis_real: Tensor, // (seq_len, head_size/2)
    freq_cis_imag: Tensor, // (seq_len, head_size/2)
}

fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
    let mut buf = [0u8; 4];
    r.read_exact(&mut buf)?;
    Ok(i32::from_le_bytes(buf))
}

fn read_tensor<R: std::io::Read, S: Into<Shape>>(
    r: &mut R,
    shape: S,
    dev: &Device,
) -> Result<Tensor> {
    let shape = shape.into();
    let mut data_t = vec![0f32; shape.elem_count()];
    r.read_f32_into::<LittleEndian>(&mut data_t)?;
    let tensor = Tensor::from_vec(data_t, shape, dev)?;
    Ok(tensor)
}

impl Config {
    pub fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
        let dim = read_i32(r)? as usize;
        let hidden_dim = read_i32(r)? as usize;
        let n_layers = read_i32(r)? as usize;
        let n_heads = read_i32(r)? as usize;
        let n_kv_heads = read_i32(r)? as usize;
        let vocab_size = read_i32(r)? as usize;
        let seq_len = read_i32(r)? as usize;
        Ok(Self {
            dim,
            hidden_dim,
            n_layers,
            n_heads,
            n_kv_heads,
            vocab_size,
            seq_len,
            norm_eps: 1e-5,
        })
    }

    pub fn head_size(&self) -> usize {
        self.dim / self.n_heads
    }
}

impl TransformerWeights {
    pub fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
        let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
        let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
        let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
        let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
        let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
        let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
        let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
        let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
        let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
        let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
        let rms_final_weight = read_tensor(r, c.dim, dev)?;
        let head_size = c.head_size();
        let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
        let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
        Ok(Self {
            token_embedding_table,
            rms_att_weight,
            wq,
            wk,
            wv,
            wo,
            rms_ffn_weight,
            w1,
            w2,
            w3,
            rms_final_weight,
            freq_cis_real,
            freq_cis_imag,
        })
    }

    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
        let mut ws = std::collections::HashMap::new();
        let mut insert = |name: &str, t: Tensor| {
            ws.insert(name.to_string(), t);
        };
        insert("rot.freq_cis_real", self.freq_cis_real.clone());
        insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
        insert(
            "model.embed_tokens.weight",
            self.token_embedding_table.clone(),
        );
        insert("lm_head.weight", self.token_embedding_table.clone());
        insert("model.norm.weight", self.rms_final_weight.clone());
        for layer in 0..cfg.n_layers {
            ws.insert(
                format!("model.layers.{layer}.self_attn.q_proj.weight"),
                self.wq.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.self_attn.k_proj.weight"),
                self.wk.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.self_attn.v_proj.weight"),
                self.wv.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.self_attn.o_proj.weight"),
                self.wo.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.mlp.gate_proj.weight"),
                self.w1.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.mlp.down_proj.weight"),
                self.w2.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.mlp.up_proj.weight"),
                self.w3.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.input_layernorm.weight"),
                self.rms_att_weight.i(layer)?,
            );
            ws.insert(
                format!("model.layers.{layer}.post_attention_layernorm.weight"),
                self.rms_ffn_weight.i(layer)?,
            );
        }
        let vb = VarBuilder::from_tensors(ws, DType::F32, device);
        Ok(vb)
    }
}
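Config::from_reader expects the llama2.c header layout: seven little-endian i32 fields (dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len) followed by the f32 tensor data read by TransformerWeights::from_reader. A small test-style sketch that round-trips just the header through an in-memory buffer (the header_roundtrip helper is hypothetical, values taken from Config::tiny above):

    // Hedged sketch: serialize a fake header and read it back with Config::from_reader.
    fn header_roundtrip() -> anyhow::Result<()> {
        let fields: [i32; 7] = [288, 768, 6, 6, 6, 32000, 256]; // Config::tiny values
        let mut buf = Vec::new();
        for f in fields {
            buf.extend_from_slice(&f.to_le_bytes());
        }
        let config = crate::model::Config::from_reader(&mut std::io::Cursor::new(buf))?;
        assert_eq!(config.dim, 288);
        assert_eq!(config.seq_len, 256);
        Ok(())
    }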
251
candle-examples/examples/llama_multiprocess/main.rs
Normal file
@ -0,0 +1,251 @@
// An implementation of LLaMA https://github.com/facebookresearch/llama
//
// This is based on nanoGPT in a similar way to:
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::api::sync::Api;
use std::io::Write;
use std::rc::Rc;

mod model;
use model::{Config, Llama};

const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?

RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!

EDWARD:
Dazzle mine eyes, or do I see three suns?

RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.

EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(long)]
    num_shards: usize,

    #[arg(long)]
    rank: Option<usize>,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, default_value_t = 100)]
    sample_len: usize,

    /// Disable the key-value cache.
    #[arg(long)]
    no_kv_cache: bool,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    #[arg(long)]
    model_id: Option<String>,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;

    let args = Args::parse();

    let config = Config::config_7b();
    let dtype = DType::F16;

    let api = Api::new()?;

    let model_id = args
        .model_id
        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
    println!("loading the model weights from {model_id}");
    let api = api.model(model_id);
    let tokenizer_filename = api.get("tokenizer.json")?;
    let mut filenames = vec![];
    for rfilename in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ] {
        let filename = api.get(rfilename)?;
        filenames.push(filename);
    }

    if args.rank.is_none() {
        let children: Vec<_> = (0..args.num_shards)
            .map(|rank| {
                let mut args: std::collections::VecDeque<_> = std::env::args().collect();
                args.push_back("--rank".to_string());
                args.push_back(format!("{rank}"));
                let name = args.pop_front().unwrap();
                std::process::Command::new(name).args(args).spawn().unwrap()
            })
            .collect();
        for mut child in children {
            child.wait().unwrap();
        }
        return Ok(());
    }

    let i = args.rank.unwrap();
    let num_shards = args.num_shards;
    let rank = i;
    // Primitive IPC
    let id = if rank == 0 {
        let id = Id::new().unwrap();
        std::fs::File::create("nccl_id.txt.tmp")?
            .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
            .unwrap();
        std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
        id
    } else {
        let path = std::path::PathBuf::from("nccl_id.txt");
        while !path.exists() {
            std::thread::sleep(std::time::Duration::from_secs(1));
        }
        let data = std::fs::read("nccl_id.txt")?;
        let internal: [i8; 128] = data
            .into_iter()
            .map(|i| i as i8)
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();
        let id: Id = Id::uninit(internal);
        id
    };
    let device = CudaDevice::new(i)?;
    let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
    if rank == 0 {
        std::fs::remove_file("nccl_id.txt")?;
    }
    println!("Rank {rank:?} spawned");

    let device = Device::new_cuda(i)?;
    let cache = model::Cache::new(&config, &device)?;

    println!("building the model");
    let handles = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
        .collect::<Result<Vec<_>>>()?;
    let tensors: Vec<_> = handles
        .iter()
        .map(|h| Ok(h.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
    let llama = Llama::load(vb, &cache, &config, comm)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();

    println!("starting the inference loop");
    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
    let mut new_tokens = vec![];
    let start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    for index in 0..args.sample_len {
        let start_gen = std::time::Instant::now();
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, index_pos)?;
        let logits = logits.squeeze(0)?;
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        new_tokens.push(next_token);
        if rank == 0 {
            println!("> {:?}", start_gen.elapsed());
            println!(
                "{} token: {} '{}'",
                index + 1,
                next_token,
                tokenizer.decode(vec![next_token], true).map_err(E::msg)?
            );
        }
    }
    let dt = start_gen.elapsed();
    if rank == 0 {
        println!(
            "{} tokens generated ({} token/s)\n----\n{}\n----",
            args.sample_len,
            args.sample_len as f64 / dt.as_secs_f64(),
            tokenizer.decode(new_tokens, true).map_err(E::msg)?
        );
    }
    Ok(())
}
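Rank 0 above publishes the NCCL unique id through the filesystem: it writes to nccl_id.txt.tmp and then renames, so ranks polling for nccl_id.txt can never observe a half-written file (a rename within one directory is atomic on POSIX filesystems). The same pattern in isolation, as a hedged sketch with hypothetical file names:

    // Hedged sketch: atomic publish of a small blob via write-to-temp + rename.
    fn publish_atomically(payload: &[u8]) -> std::io::Result<()> {
        use std::io::Write;
        std::fs::File::create("shared_blob.tmp")?.write_all(payload)?;
        // Readers polling for "shared_blob" either see nothing or the full payload.
        std::fs::rename("shared_blob.tmp", "shared_blob")
    }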
459
candle-examples/examples/llama_multiprocess/model.rs
Normal file
@ -0,0 +1,459 @@
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::{Embedding, Linear, VarBuilder};
|
||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||
use half::f16;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
struct TensorParallelColumnLinear {
|
||||
linear: Linear,
|
||||
}
|
||||
|
||||
impl TensorParallelColumnLinear {
|
||||
fn new(linear: Linear) -> Self {
|
||||
Self { linear }
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.linear.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct TensorParallelRowLinear {
|
||||
linear: Linear,
|
||||
comm: Rc<Comm>,
|
||||
}
|
||||
|
||||
struct AllReduce {
|
||||
comm: Rc<Comm>,
|
||||
}
|
||||
|
||||
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
|
||||
/// But for this example purposes, this will work
|
||||
unsafe impl Sync for AllReduce {}
|
||||
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
|
||||
/// But for this example purposes, this will work
|
||||
unsafe impl Send for AllReduce {}
|
||||
|
||||
impl CustomOp1 for AllReduce {
|
||||
fn name(&self) -> &'static str {
|
||||
"allreduce"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
todo!("implement allreduce for cpu is not necessary for single node");
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s: &candle::CudaStorage,
|
||||
l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::WrapErr;
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dev = s.device().clone();
|
||||
let s = s.as_cuda_slice::<f16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
|
||||
x.custom_op1(AllReduce { comm: comm.clone() })
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
|
||||
Self { linear, comm }
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.linear.forward(x)?;
|
||||
all_reduce_sum(&x, &self.comm)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorParallelColumnLinear {
|
||||
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weight = vb.get_sharded("weight", 0, rank, size)?;
|
||||
Ok(Self::new(Linear::new(weight, None)))
|
||||
}
|
||||
|
||||
fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weights: Vec<_> = prefixes
|
||||
.iter()
|
||||
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
|
||||
.collect();
|
||||
let weight = Tensor::cat(&weights, 0)?;
|
||||
Ok(Self::new(Linear::new(weight, None)))
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weight = vb.get_sharded("weight", 1, rank, size)?;
|
||||
Ok(Self::new(Linear::new(weight, None), comm))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub n_layer: usize,
|
||||
pub n_head: usize,
|
||||
pub n_embd: usize,
|
||||
pub n_key_value_head: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn config_7b() -> Self {
|
||||
Self {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
n_layer: 32,
|
||||
n_head: 32,
|
||||
n_embd: 4096,
|
||||
n_key_value_head: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
#[allow(clippy::type_complexity)]
|
||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new(config: &Config, device: &Device) -> Result<Self> {
|
||||
// precompute freqs_cis
|
||||
let n_elem = config.n_embd / config.n_head;
|
||||
let theta: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
|
||||
Ok(Self {
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||
cos,
|
||||
sin,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
scale: Tensor,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let scale = vb.get(size, "weight")?;
|
||||
Ok(Self::new(scale))
|
||||
}
|
||||
|
||||
fn new(scale: Tensor) -> Self {
|
||||
Self { scale }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let in_dtype = x.dtype();
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||
let size = self.scale.shape().dims1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((b_sz, seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
let x = x.to_dtype(in_dtype)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct CausalSelfAttention {
|
||||
qkv_proj: TensorParallelColumnLinear,
|
||||
o_proj: TensorParallelRowLinear,
|
||||
n_head: usize,
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let (b_sz, seq_len, _) = x.shape().dims3()?;
|
||||
|
||||
let qkv = self.qkv_proj.forward(x)?;
|
||||
let n_embd = self.n_head * self.head_dim;
|
||||
|
||||
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
|
||||
let k = qkv.i((
|
||||
..,
|
||||
..,
|
||||
self.n_head * self.head_dim
|
||||
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
|
||||
))?;
|
||||
let v = qkv.i((
|
||||
..,
|
||||
..,
|
||||
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
|
||||
))?;
|
||||
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let mut cache = self.cache.kvs.lock().unwrap();
|
||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
|
||||
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
|
||||
let k_seq_len = k.dims()[1];
|
||||
if k_seq_len > MAX_SEQ_LEN {
|
||||
k = k
|
||||
.narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||
.contiguous()?
|
||||
}
|
||||
let v_seq_len = v.dims()[1];
|
||||
if v_seq_len > 2 * MAX_SEQ_LEN {
|
||||
v = v
|
||||
.narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||
.contiguous()?
|
||||
}
|
||||
}
|
||||
cache[block_idx] = Some((k.clone(), v.clone()));
|
||||
|
||||
let k = self.repeat_kv(k)?;
|
||||
let v = self.repeat_kv(v)?;
|
||||
let q = q.transpose(1, 2)?;
|
||||
let k = k.transpose(1, 2)?;
|
||||
let v = v.transpose(1, 2)?;
|
||||
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||
.transpose(1, 2)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.n_head / self.n_key_value_head;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let qkv_proj = TensorParallelColumnLinear::load_multi(
|
||||
vb.clone(),
|
||||
&["q_proj", "k_proj", "v_proj"],
|
||||
comm.clone(),
|
||||
)?;
|
||||
let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?;
|
||||
Ok(Self {
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
n_head: cfg.n_head / comm.world_size(),
|
||||
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
|
||||
head_dim: cfg.hidden_size / cfg.n_head,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct Mlp {
|
||||
c_fc1: TensorParallelColumnLinear,
|
||||
c_fc2: TensorParallelColumnLinear,
|
||||
c_proj: TensorParallelRowLinear,
|
}

impl Mlp {
    fn new(
        c_fc1: TensorParallelColumnLinear,
        c_fc2: TensorParallelColumnLinear,
        c_proj: TensorParallelRowLinear,
    ) -> Self {
        Self {
            c_fc1,
            c_fc2,
            c_proj,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
        let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
        let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
        let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
        Ok(Self::new(c_fc1, c_fc2, c_proj))
    }
}

struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
        Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        }
    }

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
        Ok(Self::new(
            input_layernorm,
            attn,
            post_attention_layernorm,
            mlp,
        ))
    }
}

pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
}

impl Llama {
    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
        Self {
            wte,
            blocks,
            ln_f,
            lm_head,
        }
    }

    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len) = x.shape().dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let x = x.i((.., seq_len - 1, ..))?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layer)
            .map(|i| {
                Block::load(
                    vb.pp(&format!("model.layers.{i}")),
                    cache,
                    cfg,
                    comm.clone(),
                )
                .unwrap()
            })
            .collect();

        Ok(Self::new(wte, blocks, norm, lm_head))
    }
}
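Llama::forward embeds the token ids, runs every block with index_pos as the offset into the KV cache, and keeps only the logits for the last position. A minimal greedy-decoding sketch on top of it — the helper name, the prompt handling, and the cache-incrementing convention are assumptions for illustration; only the calling convention of Llama::forward above is taken from the code:

use candle::{Device, Result, Tensor, D};

// Hypothetical helper, not repo code: greedy decoding with the `Llama` defined above.
fn generate(model: &Llama, prompt: &[u32], steps: usize, device: &Device) -> Result<Vec<u32>> {
    let mut tokens = prompt.to_vec();
    let mut index_pos = 0;
    for _ in 0..steps {
        // First step feeds the whole prompt, later steps only the newest token;
        // the KV cache covers everything before `index_pos`.
        let ctx = if index_pos == 0 { &tokens[..] } else { &tokens[tokens.len() - 1..] };
        let input = Tensor::new(ctx, device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos)?; // (batch, vocab), f32
        index_pos += ctx.len();
        let next = logits.argmax(D::Minus1)?.squeeze(0)?.to_scalar::<u32>()?;
        tokens.push(next);
    }
    Ok(tokens)
}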
163
candle-examples/examples/mnist-training/main.rs
Normal file
@ -0,0 +1,163 @@
// This should reach 91.5% accuracy.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use clap::{Parser, ValueEnum};

use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;

fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
    let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
    let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
    Ok(Linear::new(ws, Some(bs)))
}

trait Model: Sized {
    fn new(vs: VarBuilder) -> Result<Self>;
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

struct LinearModel {
    linear: Linear,
}

impl Model for LinearModel {
    fn new(vs: VarBuilder) -> Result<Self> {
        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
        Ok(Self { linear })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear.forward(xs)
    }
}

struct Mlp {
    ln1: Linear,
    ln2: Linear,
}

impl Model for Mlp {
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
        Ok(Self { ln1, ln2 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        self.ln2.forward(&xs)
    }
}

struct TrainingArgs {
    learning_rate: f64,
    load: Option<String>,
    save: Option<String>,
    epochs: usize,
}

fn training_loop<M: Model>(
    m: candle_nn::vision::Dataset,
    args: &TrainingArgs,
) -> anyhow::Result<()> {
    let dev = candle::Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    let mut varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = M::new(vs.clone())?;

    if let Some(load) = &args.load {
        println!("loading weights from {load}");
        varmap.load(load)?
    }

    let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
    for epoch in 1..args.epochs {
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        let loss = loss::nll(&log_sm, &train_labels)?;
        sgd.backward_step(&loss)?;

        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            100. * test_accuracy
        );
    }
    if let Some(save) = &args.save {
        println!("saving trained weights in {save}");
        varmap.save(save)?
    }
    Ok(())
}

#[derive(ValueEnum, Clone)]
enum WhichModel {
    Linear,
    Mlp,
}

#[derive(Parser)]
struct Args {
    #[clap(value_enum, default_value_t = WhichModel::Linear)]
    model: WhichModel,

    #[arg(long)]
    learning_rate: Option<f64>,

    #[arg(long, default_value_t = 200)]
    epochs: usize,

    /// The file where to save the trained weights, in safetensors format.
    #[arg(long)]
    save: Option<String>,

    /// The file to load the trained weights from, in safetensors format.
    #[arg(long)]
    load: Option<String>,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Load the dataset
    let m = candle_nn::vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.shape());
    println!("train-labels: {:?}", m.train_labels.shape());
    println!("test-images: {:?}", m.test_images.shape());
    println!("test-labels: {:?}", m.test_labels.shape());

    let default_learning_rate = match args.model {
        WhichModel::Linear => 1.,
        WhichModel::Mlp => 0.05,
    };
    let training_args = TrainingArgs {
        epochs: args.epochs,
        learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
        load: args.load,
        save: args.save,
    };
    match args.model {
        WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
        WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
    }
}
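The loss in the loop above is cross-entropy written as negative log-likelihood over a log-softmax. A minimal self-contained sketch of that identity on standalone toy tensors (the values are made up, only the two candle_nn calls are taken from the example):

use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Two samples, three classes; labels are class indices.
    let logits = Tensor::new(&[[2.0f32, 0.5, 0.1], [0.2, 1.5, 0.3]], &dev)?;
    let labels = Tensor::new(&[0u32, 1], &dev)?;
    let log_sm = candle_nn::ops::log_softmax(&logits, D::Minus1)?;
    let loss = candle_nn::loss::nll(&log_sm, &labels)?;
    println!("loss: {}", loss.to_scalar::<f32>()?);
    Ok(())
}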
@ -142,7 +142,7 @@ impl EncodecEuclideanCodebook
     }

     fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
-        let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+        let quantize = self.embed.embedding(embed_ind)?;
         Ok(quantize)
     }
 }
@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding
     }

     fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
+        let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
         if seq_len > self.weights.dim(0)? {
             self.weights = get_embedding(seq_len, self.embedding_dim)?
         }
@ -170,7 +170,7 @@ impl MusicgenAttention
         kv_states: Option<&Tensor>,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
-        let (b_sz, tgt_len, _) = xs.shape().r3()?;
+        let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scaling)?;

         let kv_states = kv_states.unwrap_or(xs);
@ -187,7 +187,7 @@ impl MusicgenAttention
         let attn_weights = attn_weights
             .reshape((b_sz, self.num_heads, tgt_len, src_len))?
             .broadcast_add(attention_mask)?;
-        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
         // TODO: layer_head_mask?
         let attn_output = attn_weights
             .matmul(&value_states)?
@ -308,7 +308,7 @@ impl MusicgenDecoder

     fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
         let dev = input_ids.device();
-        let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
+        let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
         let b_sz = b_sz_times_codebooks / self.num_codebooks;
         let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
         let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
@ -352,7 +352,7 @@ impl MusicgenForCausalLM
     }

     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.shape().r2()?;
+        let (b_sz, seq_len) = input_ids.dims2()?;
         let hidden_states = self.decoder.forward(input_ids)?;
         let lm_logits = self
             .lm_heads
@ -223,7 +223,7 @@ impl T5Attention
             .transpose(1, 2)?;
         let scores = q.matmul(&k.t()?)?;
         // TODO: position_bias_masked
-        let attn_weights = scores.softmax(D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok(attn_output)
@ -338,7 +338,7 @@ impl T5Stack

     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+        let (_b_sz, _seq_len) = input_embeds.dims2()?;

         let mut hidden_states = self.dropout.forward(&input_embeds)?;
         for block in self.block.iter() {
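These hunks (and the whisper ones further below) migrate softmax from a tensor method to the candle_nn::ops free function. A minimal sketch of the new call on standalone toy data (values are made up):

use candle::{Device, Result, Tensor, D};
use candle_nn::ops::softmax;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let logits = Tensor::new(&[[1.0f32, 2.0, 3.0]], &dev)?;
    let probs = softmax(&logits, D::Minus1)?; // each row now sums to 1
    println!("{probs}");
    Ok(())
}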
@ -18,10 +18,8 @@ fn fft<T: Float>(inp: &[T]) -> Vec<T> {
     }
     let mut out = vec![zero; n * 2];

-    let mut even = vec![];
-    even.reserve(n / 2);
-    let mut odd = vec![];
-    odd.reserve(n / 2);
+    let mut even = Vec::with_capacity(n / 2);
+    let mut odd = Vec::with_capacity(n / 2);

     for (i, &inp) in inp.iter().enumerate() {
         if i % 2 == 0 {
@ -11,9 +11,9 @@ extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
 use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_hub::{api::sync::Api, Repo, RepoType};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;

@ -120,19 +120,17 @@ impl Decoder
             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
-                no_speech_prob = logits
-                    .get(0)?
-                    .softmax(0)?
+                no_speech_prob = softmax(&logits.get(0)?, 0)?
                     .get(NO_SPEECH_TOKEN as usize)?
                     .to_scalar::<f32>()? as f64;
             }

-            let (seq_len, _) = logits.shape().r2()?;
+            let (seq_len, _) = logits.dims2()?;
             let logits = logits
                 .get(seq_len - 1)?
                 .broadcast_add(&self.suppress_tokens)?;
             let next_token = if t > 0f64 {
-                let prs = (&logits / t)?.softmax(0)?;
+                let prs = softmax(&(&logits / t)?, 0)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
                 let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
                 distr.sample(&mut self.rng) as u32
@ -146,8 +144,7 @@ impl Decoder
                     .unwrap()
             };
             tokens.push(next_token);
-            let prob = logits
-                .softmax(candle::D::Minus1)?
+            let prob = softmax(&logits, candle::D::Minus1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
             if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
@ -195,7 +192,7 @@ impl Decoder
     }

     fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
-        let (_, _, content_frames) = mel.shape().r3()?;
+        let (_, _, content_frames) = mel.dims3()?;
         let mut seek = 0;
         let mut segments = vec![];
         while seek < content_frames {
@ -282,28 +279,23 @@ fn main() -> Result<()> {
             std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
         )
     } else {
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
+        let dataset = api.dataset("Narsil/candle-examples".to_string());
+        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
         let sample = if let Some(input) = args.input {
             if let Some(sample) = input.strip_prefix("sample:") {
-                api.get(
-                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                    &format!("samples_{sample}.wav"),
-                )?
+                dataset.get(&format!("samples_{sample}.wav"))?
             } else {
                 std::path::PathBuf::from(input)
             }
         } else {
             println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-            api.get(
-                &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                "samples_jfk.wav",
-            )?
+            dataset.get("samples_jfk.wav")?
         };
         (
-            api.get(&repo, "config.json")?,
-            api.get(&repo, "tokenizer.json")?,
-            api.get(&repo, "model.safetensors")?,
+            repo.get("config.json")?,
+            repo.get("tokenizer.json")?,
+            repo.get("model.safetensors")?,
             sample,
         )
     };
@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
 use serde::Deserialize;

 // The names in comments correspond to the original implementation:
@ -132,7 +132,7 @@ impl MultiHeadAttention
     }

     fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = x.shape().r3()?;
+        let (n_batch, n_ctx, n_state) = x.dims3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
         Ok(x.reshape(target_dims)?.transpose(1, 2)?)
     }
@ -144,7 +144,7 @@ impl MultiHeadAttention
         v: &Tensor,
         mask: Option<&Tensor>,
     ) -> Result<Tensor> {
-        let (_, n_ctx, n_state) = q.shape().r3()?;
+        let (_, n_ctx, n_state) = q.dims3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
@ -154,7 +154,7 @@ impl MultiHeadAttention
             let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = qk.softmax(candle::D::Minus1)?;
+        let w = softmax(&qk, candle::D::Minus1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
         Ok(wv)
     }
@ -270,7 +270,7 @@ impl AudioEncoder
         let x = self.conv1.forward(x)?.gelu()?;
         let x = self.conv2.forward(&x)?.gelu()?;
         let x = x.transpose(1, 2)?;
-        let (_bsize, seq_len, _hidden) = x.shape().r3()?;
+        let (_bsize, seq_len, _hidden) = x.dims3()?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
         let mut x = x.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
24
candle-flash-attn/Cargo.toml
Normal file
@ -0,0 +1,24 @@
[package]
name = "candle-flash-attn"
version = "0.1.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.1.0", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"

[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.1.0", features = ["cuda"] }
1
candle-flash-attn/README.md
Normal file
@ -0,0 +1 @@
# candle-flash-attn
252
candle-flash-attn/build.rs
Normal file
@ -0,0 +1,252 @@
// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;

const KERNEL_FILES: [&str; 9] = [
    "flash_api.cu",
    "flash_fwd_hdim128_fp16_sm80.cu",
    "flash_fwd_hdim160_fp16_sm80.cu",
    "flash_fwd_hdim192_fp16_sm80.cu",
    "flash_fwd_hdim224_fp16_sm80.cu",
    "flash_fwd_hdim256_fp16_sm80.cu",
    "flash_fwd_hdim32_fp16_sm80.cu",
    "flash_fwd_hdim64_fp16_sm80.cu",
    "flash_fwd_hdim96_fp16_sm80.cu",
    // "flash_fwd_hdim128_bf16_sm80.cu",
    // "flash_fwd_hdim160_bf16_sm80.cu",
    // "flash_fwd_hdim192_bf16_sm80.cu",
    // "flash_fwd_hdim224_bf16_sm80.cu",
    // "flash_fwd_hdim256_bf16_sm80.cu",
    // "flash_fwd_hdim32_bf16_sm80.cu",
    // "flash_fwd_hdim64_bf16_sm80.cu",
    // "flash_fwd_hdim96_bf16_sm80.cu",
];

fn main() -> Result<()> {
    let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
        |_| num_cpus::get_physical(),
        |s| usize::from_str(&s).unwrap(),
    );

    rayon::ThreadPoolBuilder::new()
        .num_threads(num_cpus)
        .build_global()
        .unwrap();

    println!("cargo:rerun-if-changed=build.rs");
    for kernel_file in KERNEL_FILES.iter() {
        println!("cargo:rerun-if-changed=kernels/{kernel_file}");
    }
    println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
    println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
    println!("cargo:rerun-if-changed=kernels/flash.h");
    println!("cargo:rerun-if-changed=kernels/philox.cuh");
    println!("cargo:rerun-if-changed=kernels/softmax.h");
    println!("cargo:rerun-if-changed=kernels/utils.h");
    println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
    println!("cargo:rerun-if-changed=kernels/block_info.h");
    println!("cargo:rerun-if-changed=kernels/static_switch.h");
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
    let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
        Err(_) =>
        {
            #[allow(clippy::redundant_clone)]
            out_dir.clone()
        }
        Ok(build_dir) => PathBuf::from(build_dir),
    };
    set_cuda_include_dir()?;
    let compute_cap = compute_cap()?;

    let out_file = build_dir.join("libflashattention.a");

    let kernel_dir = PathBuf::from("kernels");
    let cu_files: Vec<_> = KERNEL_FILES
        .iter()
        .map(|f| {
            let mut obj_file = out_dir.join(f);
            obj_file.set_extension("o");
            (kernel_dir.join(f), obj_file)
        })
        .collect();
    let should_compile = if out_file.exists() {
        cu_files.iter().any(|(cu_file, _)| {
            let out_modified = out_file.metadata().unwrap().modified().unwrap();
            let in_modified = cu_file.metadata().unwrap().modified().unwrap();
            in_modified.duration_since(out_modified).is_ok()
        })
    } else {
        true
    };
    if should_compile {
        cu_files
            .par_iter()
            .map(|(cu_file, obj_file)| {
                let mut command = std::process::Command::new("nvcc");
                command
                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
                    .arg("-c")
                    .args(["-o", obj_file.to_str().unwrap()])
                    .args(["--default-stream", "per-thread"])
                    .arg("-Icutlass/include")
                    .arg("--expt-relaxed-constexpr")
                    .arg(cu_file);
                let output = command
                    .spawn()
                    .context("failed spawning nvcc")?
                    .wait_with_output()?;
                if !output.status.success() {
                    anyhow::bail!(
                        "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                        String::from_utf8_lossy(&output.stdout),
                        String::from_utf8_lossy(&output.stderr)
                    )
                }
                Ok(())
            })
            .collect::<Result<()>>()?;
        let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
        let mut command = std::process::Command::new("nvcc");
        command
            .arg("--lib")
            .args(["-o", out_file.to_str().unwrap()])
            .args(obj_files);
        let output = command
            .spawn()
            .context("failed spawning nvcc")?
            .wait_with_output()?;
        if !output.status.success() {
            anyhow::bail!(
                "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                String::from_utf8_lossy(&output.stdout),
                String::from_utf8_lossy(&output.stderr)
            )
        }
    }
    println!("cargo:rustc-link-search={}", build_dir.display());
    println!("cargo:rustc-link-lib=flashattention");
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=stdc++");

    /* laurent: I tried using the cc cuda integration as below but this led to ptxas never
       finishing for some reason. Calling nvcc manually worked fine.
    cc::Build::new()
        .cuda(true)
        .include("cutlass/include")
        .flag("--expt-relaxed-constexpr")
        .flag("--default-stream")
        .flag("per-thread")
        .flag(&format!("--gpu-architecture=sm_{compute_cap}"))
        .file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
        .compile("flashattn");
    */
    Ok(())
}

fn set_cuda_include_dir() -> Result<()> {
    // NOTE: copied from cudarc build.rs.
    let env_vars = [
        "CUDA_PATH",
        "CUDA_ROOT",
        "CUDA_TOOLKIT_ROOT_DIR",
        "CUDNN_LIB",
    ];
    let env_vars = env_vars
        .into_iter()
        .map(std::env::var)
        .filter_map(Result::ok)
        .map(Into::<PathBuf>::into);

    let roots = [
        "/usr",
        "/usr/local/cuda",
        "/opt/cuda",
        "/usr/lib/cuda",
        "C:/Program Files/NVIDIA GPU Computing Toolkit",
        "C:/CUDA",
    ];
    let roots = roots.into_iter().map(Into::<PathBuf>::into);
    let root = env_vars
        .chain(roots)
        .find(|path| path.join("include").join("cuda.h").is_file())
        .context("cannot find include/cuda.h")?;
    println!(
        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
        root.join("include").display()
    );
    Ok(())
}

#[allow(unused)]
fn compute_cap() -> Result<usize> {
    // Grab the compute cap from nvidia-smi
    let mut compute_cap = {
        let out = std::process::Command::new("nvidia-smi")
            .arg("--query-gpu=compute_cap")
            .arg("--format=csv")
            .output()
            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
        let mut lines = out.lines();
        assert_eq!(
            lines.next().context("missing line in stdout")?,
            "compute_cap"
        );
        let cap = lines
            .next()
            .context("missing line in stdout")?
            .replace('.', "");
        cap.parse::<usize>()
            .with_context(|| format!("cannot parse as int {cap}"))?
    };

    // Grab available GPU codes from nvcc and select the highest one
    let max_nvcc_code = {
        let out = std::process::Command::new("nvcc")
            .arg("--list-gpu-code")
            .output()
            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
        let out = std::str::from_utf8(&out.stdout).unwrap();

        let out = out.lines().collect::<Vec<&str>>();
        let mut codes = Vec::with_capacity(out.len());
        for code in out {
            let code = code.split('_').collect::<Vec<&str>>();
            if !code.is_empty() && code.contains(&"sm") {
                if let Ok(num) = code[1].parse::<usize>() {
                    codes.push(num);
                }
            }
        }
        codes.sort();
        if !codes.contains(&compute_cap) {
            anyhow::bail!(
                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
            );
        }
        *codes.last().unwrap()
    };

    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
    // then choose the highest gpu code in nvcc
    if compute_cap > max_nvcc_code {
        println!(
            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
        );
        compute_cap = max_nvcc_code;
    }

    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
        compute_cap = compute_cap_str
            .parse::<usize>()
            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
    }
    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
    Ok(compute_cap)
}
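The should_compile check above rebuilds the static library whenever any kernel source is at least as new as the archive: SystemTime::duration_since errors when its argument is later, so is_ok() means "input modified at or after output". A standalone sketch of that staleness test (the paths are hypothetical, only the mtime comparison mirrors the build script):

use std::{fs, path::Path, time::SystemTime};

// Rebuild iff the output is missing or any input has a newer (or equal) mtime.
fn needs_rebuild(out: &Path, inputs: &[&Path]) -> std::io::Result<bool> {
    if !out.exists() {
        return Ok(true);
    }
    let out_mtime: SystemTime = fs::metadata(out)?.modified()?;
    for input in inputs {
        let in_mtime = fs::metadata(input)?.modified()?;
        if in_mtime.duration_since(out_mtime).is_ok() {
            return Ok(true);
        }
    }
    Ok(false)
}

fn main() -> std::io::Result<()> {
    let stale = needs_rebuild(
        Path::new("libflashattention.a"),
        &[Path::new("kernels/flash_api.cu")],
    )?;
    println!("should_compile = {stale}");
    Ok(())
}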
1
candle-flash-attn/cutlass
Submodule
Submodule candle-flash-attn/cutlass added at c4f6b8c6bc
41
candle-flash-attn/kernels/block_info.h
Normal file
@ -0,0 +1,41 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen=true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
        {
        }

    template <typename index_t>
    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template <typename index_t>
    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const uint32_t actual_seqlen_q;
    const uint32_t actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
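BlockInfo resolves per-batch offsets for variable-length batches: cu_seqlens_q/k hold b+1 prefix sums, so sequence i occupies rows cu_seqlens[i]..cu_seqlens[i+1] of the packed token dimension, and actual_seqlen is the difference of adjacent entries. A small sketch of that convention (illustrative values, not repo code):

fn main() {
    // b+1 prefix sums for three packed sequences of lengths 3, 4 and 5.
    let cu_seqlens = [0usize, 3, 7, 12];
    for i in 0..cu_seqlens.len() - 1 {
        let start = cu_seqlens[i];
        let len = cu_seqlens[i + 1] - cu_seqlens[i];
        println!("seq {i}: row offset {start}, actual_seqlen {len}");
    }
}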
141
candle-flash-attn/kernels/flash.h
Normal file
@ -0,0 +1,141 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cuda.h>
#include <vector>

// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh>


constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = uint32_t;
    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The stride between rows of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
    // different from nheads (query).
    int h_h_k_ratio; // precompute h / h_k,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the P matrix.
    void * __restrict__ p_ptr;

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;

    // The scaling factors for the kernel.
    float scale_softmax;
    float scale_softmax_log2;

    // array of length b+1 holding starting offset of each sequence.
    int * __restrict__ cu_seqlens_q;
    int * __restrict__ cu_seqlens_k;

    int *__restrict__ blockmask;

    // The dropout probability (probability of keeping an activation).
    float p_dropout;
    // uint32_t p_dropout_in_uint;
    // uint16_t p_dropout_in_uint16_t;
    uint8_t p_dropout_in_uint8_t;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;
    float scale_softmax_rp_dropout;

    // Random state.
    // at::PhiloxCudaState philox_args;

    bool is_bf16;
    bool is_causal;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

    // The dO and dQKV matrices.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // To accumulate dQ
    void *__restrict__ dq_accum_ptr;
    void *__restrict__ dk_accum_ptr;
    void *__restrict__ dv_accum_ptr;

    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
    // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
    // dv_accum_ptr;

    // The stride between rows of the dO, dQ, dK and dV matrices.
    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
    // The code probably won't work for arrays larger than 2GB.
    index_t do_batch_stride;
    index_t do_row_stride;
    index_t do_head_stride;
    index_t dq_batch_stride;
    index_t dk_batch_stride;
    index_t dv_batch_stride;
    index_t dq_row_stride;
    index_t dk_row_stride;
    index_t dv_row_stride;
    index_t dq_head_stride;
    index_t dk_head_stride;
    index_t dv_head_stride;

    // The pointer to the softmax d sum.
    void *__restrict__ dsoftmax_sum;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
112
candle-flash-attn/kernels/flash_api.cu
Normal file
@ -0,0 +1,112 @@
#include "flash_fwd_launch_template.h"

// TODO: Switch back to handling bf16.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FWD_HEADDIM_SWITCH(params.d, [&] {
        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
    });
}

// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
//     FP16_SWITCH(!params.is_bf16, [&] {
//         FWD_HEADDIM_SWITCH(params.d, [&] {
//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
//         });
//     });
// }

extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
    void *v_ptr,
    void *o_ptr,
    void *softmax_lse_ptr,

    int32_t *cu_seqlens_q_ptr,
    int32_t *cu_seqlens_k_ptr,

    uint32_t q_batch_stride,
    uint32_t k_batch_stride,
    uint32_t v_batch_stride,
    uint32_t o_batch_stride,

    uint32_t q_row_stride,
    uint32_t k_row_stride,
    uint32_t v_row_stride,
    uint32_t o_row_stride,

    uint32_t q_head_stride,
    uint32_t k_head_stride,
    uint32_t v_head_stride,
    uint32_t o_head_stride,

    uint32_t b,
    uint32_t h,
    uint32_t h_k,
    uint32_t d,
    uint32_t d_rounded,
    float softmax_scale,

    uint32_t seqlen_q,
    uint32_t seqlen_k,
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,

    int is_causal
) {
    Flash_fwd_params params;
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    // Set the pointers and strides.
    params.q_ptr = q_ptr;
    params.k_ptr = k_ptr;
    params.v_ptr = v_ptr;
    params.o_ptr = o_ptr;

    params.softmax_lse_ptr = softmax_lse_ptr;

    // All strides are in elements, not bytes.
    params.q_batch_stride = q_batch_stride;
    params.k_batch_stride = k_batch_stride;
    params.v_batch_stride = v_batch_stride;
    params.o_batch_stride = o_batch_stride;

    params.q_row_stride = q_row_stride;
    params.k_row_stride = k_row_stride;
    params.v_row_stride = v_row_stride;
    params.o_row_stride = o_row_stride;
    params.q_head_stride = q_head_stride;
    params.k_head_stride = k_head_stride;
    params.v_head_stride = v_head_stride;
    params.o_head_stride = o_head_stride;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;
    params.is_causal = is_causal;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;

    params.p_dropout = 1.; // probability to keep
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    params.is_bf16 = 0;
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;
    params.p_ptr = nullptr; // used for `return_softmax`.

    cudaStream_t stream = 0; // Use the default stream.
    run_mha_fwd(params, stream);
}
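The scale_softmax_log2 assignment above folds log2(e) into the softmax scale so the kernels can use the fast exp2 instruction instead of exp. The identity it relies on (a worked note, not repo code):

    exp(scale * x) = 2^(x * scale * log2(e)) = exp2(scale_softmax_log2 * x)

so precomputing scale_softmax_log2 = softmax_scale * M_LOG2E once on the host saves a multiply per exponential inside the attention inner loop.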
19
candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
Normal file
@ -0,0 +1,19 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     if (params.p_dropout == 1.f) {
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
//     } else {
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
//     }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
32
candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
Normal file
@ -0,0 +1,32 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     if (params.p_dropout == 1.f) {
//         // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
//         // 1st ones are good for H100, A100
//         // 2nd one is good for A6000 bc we get slightly better occupancy
//     } else {
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
//         // 1st one is good for H100, A100, A6000
//     }
// }

template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
17
candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
Normal file
@ -0,0 +1,17 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
}
27
candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
Normal file
@ -0,0 +1,27 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
//         // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
//         // For A100, H100, 1st is fastest.
//     });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
}
16
candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
Normal file
@ -0,0 +1,16 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
//     });
// }
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
}
27
candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
Normal file
@ -0,0 +1,27 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         // This one is slightly faster for causal?
//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
//     });
//     // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
//     // For A6000, 1st is faster when causal, 3rd is faster when not causal
// }
template<>
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
}
9
candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
Normal file
@ -0,0 +1,9 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
}
Some files were not shown because too many files have changed in this diff.