mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Compare commits: w-uncond...matmul-slo (136 commits)
SHA1:
69c1fb1ee8, c55ebaf477, 4c91dd2ff4, bc3351bce4, b34d7f0248, 4d04ac83c7, 392fe02fba, 59ab6d7832,
783735cf22, 9abeddd750, 2e5fb0b251, 823fe23f9b, d833527fda, a4967600d0, aa53368aeb, 955e00b2e8,
d5f7267087, 904bbdae65, b0442eff8a, 4631c48273, 716883e9b0, 47c25a567b, 7f7d95e2c3, f47bd9bab5,
8f7973958c, f0c619a4af, b86ac0c507, 27e70a5093, c18a856e76, 3349c89252, 11d3687cc6, dac73edb34,
b4da19d1be, ff513314fc, 043cc25766, 7b06872f90, 65825e7240, 7670fe7d1f, cddfc3944c, 089fc3b584,
e04c789230, 263a172202, 638ccf9f46, 0baf5a1e19, 5130a7da32, 41143db1af, 096dee7073, f6054e9d60,
328167ec04, 4e55aaa51f, deee7612da, 06207332bc, 4021272875, 87e3a4e175, 6203ced495, 34842fb234,
d188d6a764, 0ac2db577b, fc59bc31bf, 03348e2e6f, 49fa184a35, 6f17ef82be, 01b92cd959, 53510ce427,
23b3576c47, 716ab2ccdc, ada8851a23, c05a348e36, 25657804ef, 5e1c595e00, 8a49e01b9d, 9cb110c44c,
667f01c173, e59784e353, 29bd6b2979, 9571b200c9, ce0a4e3a85, 4abc1ea34d, 2dd43d6cdd, 1fcac4afed,
a084f65f9a, c798184c2b, c78a294323, a36d883254, 7f2bbcf746, dc47224ab9, 1ce7fe2543, 402ddcfcb4,
f5069dd354, 0007ae9c11, e15862cfdb, 4aeb449017, bcb0ed8f1c, 7edd755756, e32c89d90c, bb3471ea31,
890d069092, 5dbe46b389, ccf352f3d1, 402d207f0f, 7582937a32, b54acfa3d0, cda1786eed, 912a3d63b0,
3ef328c53d, 0c8e983514, df6f5240ba, a46b1b4657, 19e52e5007, 8601537e31, 4ac6039a42, 52a60ca3ad,
a96878f235, aa8ec06fd2, b43ca493f6, 3b557765e8, 2619c4307f, c89b82b2d4, 7b26e513f1, ab1d40ea97,
3a0d3e05df, 9b24d89d2d, fb1c2ac535, 728e167334, 7b1ddcff47, f685b2231c, c0b49d5a50, 098dd0d1e9,
05626ef492, 67a486d18d, 7ad82b87e4, 8696f64bae, d7e48234d4, 34f2ecbc3b, 4f91c8e109, 06e46d7c3b

.gitignore (vendored): 8 changes
@@ -23,14 +23,16 @@ flamegraph.svg
*.dylib
*.so
*.swp
*.swo
trace-*.json

candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/audios/*.wav
candle-wasm-examples/**/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/package-lock.json

candle-wasm-examples/**/config*.json
.DS_Store
.idea/*

.vscode/settings.json (vendored, new file): 11 lines
@@ -0,0 +1,11 @@
+{
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter"
+    },
+    "python.formatting.provider": "none",
+    "python.testing.pytestArgs": [
+        "candle-pyo3"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
+}

CHANGELOG.md: 28 changes
@@ -1,12 +1,38 @@
# Changelog
This documents the main changes to the `candle` crate.

-## v0.2.3 - Unreleased
+## v0.3.1 - Unreleased

### Added

### Modified

+## v0.3.0 - 2023-10-01
+
+### Added
+
+- Added the Mistral 7b v0.1 model
+  [983](https://github.com/huggingface/candle/pull/983).
+- Quantized version of the Mistral model
+  [1009](https://github.com/huggingface/candle/pull/1009).
+- Add the gelu-erf op and activation function
+  [969](https://github.com/huggingface/candle/pull/969).
+- Add the mixformer/phi-v1.5 model
+  [930](https://github.com/huggingface/candle/pull/930).
+- Add the slice-scatter op
+  [927](https://github.com/huggingface/candle/pull/927).
+- Add the Wuerstchen diffusion model
+  [911](https://github.com/huggingface/candle/pull/911).
+
+### Modified
+
+- Support for simd128 intrinsics in some quantized vecdots
+  [982](https://github.com/huggingface/candle/pull/982).
+- Optimize the index-select cuda kernel
+  [976](https://github.com/huggingface/candle/pull/976).
+- Self-contained safetensor wrappers
+  [946](https://github.com/huggingface/candle/pull/946).

## v0.2.2 - 2023-09-18

### Added

Cargo.toml: 16 changes
@@ -11,15 +11,16 @@ members = [
    "candle-wasm-examples/segment-anything",
    "candle-wasm-examples/whisper",
    "candle-wasm-examples/yolo",
    "candle-wasm-examples/bert",
    "candle-wasm-examples/phi",
    "candle-wasm-examples/t5",
    "candle-wasm-tests",
]
-exclude = [
-    "candle-flash-attn",
-    "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"

[workspace.package]
-version = "0.2.3"
+version = "0.3.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -42,9 +43,10 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
-memmap2 = "0.7.1"
+memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
+parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
@ -58,8 +60,8 @@ tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
parquet = { version = "45.0.0" }
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|

README.md: 27 changes
@@ -8,6 +8,7 @@ Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
+[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
@@ -52,14 +53,20 @@ These online demos run entirely in your browser:
  object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
+- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
+- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.

We also provide some command line based examples using state of the art models:

- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
-  generation.
+- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
+  pre-trained on 1T tokens of English and code datasets.
+- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
+  performance better than all publicly available 13b models as of 2023-09-28.
+- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
  the LLaMA model using the same quantization techniques as
  [llama.cpp](https://github.com/ggerganov/llama.cpp).
@@ -71,6 +78,11 @@ We also provide some command line based examples using state of the art models

  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

+- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
+  image generative model.
+
+  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
+
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
  [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
  estimation models.
@@ -100,6 +112,8 @@ There are also some wasm examples for whisper and
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
+[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
+[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

For LLaMA2, run the following command to retrieve the weight files and start a
@@ -138,10 +152,14 @@ If you have an addition to this list, please submit a pull request.
- LLaMA v1 and v2.
- Falcon.
- StarCoder.
+- Phi v1.5.
+- Mistral 7b v0.1.
+- StableLM-3B-4E1T.
- T5.
- Bert.
- Whisper (multi-lingual support).
- Stable Diffusion v1.5, v2.1, XL v1.0.
+- Wurstchen v2.
- Computer Vision Models.
  - DINOv2.
  - EfficientNet.
@@ -306,6 +324,11 @@ mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```

+#### Extremely slow model load time with WSL
+
+This may be caused by the models being loaded from `/mnt/c`, more details on
+[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
+
#### Tracking down errors

You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@@ -24,9 +24,10 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
+anyhow = { workspace = true }
+tokio = "1.29.1"

[dev-dependencies]
-anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
@@ -38,7 +39,6 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }

@@ -14,6 +14,7 @@
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/training.md)
+    - [Simplified](training/simplified.md)
    - [MNIST](training/mnist.md)
    - [Fine-tuning]()
    - [Serialization]()

@@ -1,3 +1,6 @@
+#[cfg(test)]
+pub mod simplified;
+
#[cfg(test)]
mod tests {
    use anyhow::Result;

candle-book/src/simplified.rs (new file): 196 lines
@@ -0,0 +1,196 @@
//! # A simplified example in Rust of training a neural network and then using it, based on the Candle framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
//!
//! ## Key points:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first round.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second round, 0 means that he will lose.
//! For training, samples with real data on the results of the first and second rounds of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! Model parameters (the neuron weights) are initialized randomly, then optimized during training.
//! After training, the model is tested on a held-out sample to evaluate the accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

#[rustfmt::skip]
mod tests {

use candle::{DType, Result, Tensor, D, Device};
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};

// ANCHOR: book_training_simplified1
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;

#[derive(Clone)]
pub struct Dataset {
    pub train_votes: Tensor,
    pub train_results: Tensor,
    pub test_votes: Tensor,
    pub test_results: Tensor,
}

struct MultiLevelPerceptron {
    ln1: Linear,
    ln2: Linear,
    ln3: Linear,
}

impl MultiLevelPerceptron {
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
        let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
        Ok(Self { ln1, ln2, ln3 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        let xs = self.ln2.forward(&xs)?;
        let xs = xs.relu()?;
        self.ln3.forward(&xs)
    }
}
// ANCHOR_END: book_training_simplified1

// ANCHOR: book_training_simplified3
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_votes_vec: Vec<u32> = vec![
        15, 10,
        10, 15,
        5, 12,
        30, 20,
        16, 12,
        13, 25,
        6, 14,
        31, 21,
    ];
    let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let train_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
        1,
        1,
        0,
        0,
        1,
    ];
    let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;

    let test_votes_vec: Vec<u32> = vec![
        13, 9,
        8, 14,
        3, 10,
    ];
    let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let test_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
    ];
    let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;

    let m = Dataset {
        train_votes: train_votes_tensor,
        train_results: train_results_tensor,
        test_votes: test_votes_tensor,
        test_results: test_results_tensor,
    };

    let trained_model: MultiLevelPerceptron;
    loop {
        println!("Trying to train neural network.");
        match train(m.clone(), &dev) {
            Ok(model) => {
                trained_model = model;
                break;
            },
            Err(e) => {
                println!("Error: {}", e);
                continue;
            }
        }
    }

    let real_world_votes: Vec<u32> = vec![
        13, 22,
    ];

    let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let final_result = trained_model.forward(&tensor_test_votes)?;

    let result = final_result
        .argmax(D::Minus1)?
        .to_dtype(DType::F32)?
        .get(0).map(|x| x.to_scalar::<f32>())??;
    println!("real_life_votes: {:?}", real_world_votes);
    println!("neural_network_prediction_result: {:?}", result);

    Ok(())
}
// ANCHOR_END: book_training_simplified3

// ANCHOR: book_training_simplified2
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
    let train_results = m.train_results.to_device(dev)?;
    let train_votes = m.train_votes.to_device(dev)?;
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
    let model = MultiLevelPerceptron::new(vs.clone())?;
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
    let test_votes = m.test_votes.to_device(dev)?;
    let test_results = m.test_results.to_device(dev)?;
    let mut final_accuracy: f32 = 0.0;
    for epoch in 1..EPOCHS + 1 {
        let logits = model.forward(&train_votes)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        let loss = loss::nll(&log_sm, &train_results)?;
        sgd.backward_step(&loss)?;

        let test_logits = model.forward(&test_votes)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_results)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_results.dims1()? as f32;
        final_accuracy = 100. * test_accuracy;
        println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
            loss.to_scalar::<f32>()?,
            final_accuracy
        );
        if final_accuracy == 100.0 {
            break;
        }
    }
    if final_accuracy < 100.0 {
        Err(anyhow::Error::msg("The model is not trained well enough."))
    } else {
        Ok(model)
    }
}
// ANCHOR_END: book_training_simplified2

}

candle-book/src/training/simplified.md (new file): 45 lines
@@ -0,0 +1,45 @@
# Simplified

## How it works

This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.

Key points:

1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first round.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second round, 0 means that he will lose.
4. For training, samples with real data on the results of the first and second rounds of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. Model parameters (the neuron weights) are initialized randomly, then optimized during training.
7. After training, the model is tested on a held-out sample to evaluate the accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.

Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```

## Example output

```bash
Trying to train neural network.
Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```
@ -12,7 +12,7 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
half = { workspace = true }
|
||||
@ -26,6 +26,7 @@ rand_distr = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|

@@ -103,8 +103,10 @@ enum Command {

    Quantize {
        /// The input file, in gguf format.
-       in_file: std::path::PathBuf,
+       in_file: Vec<std::path::PathBuf>,

        /// The output file, in gguf format.
        #[arg(long)]
        out_file: std::path::PathBuf,

        /// The quantization schema to apply.
@@ -150,8 +152,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()>
        }
    }
    Format::Safetensors => {
-       let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
-       let tensors = tensors.deserialize()?;
+       let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
        let mut tensors = tensors.tensors();
        tensors.sort_by(|a, b| a.0.cmp(&b.0));
        for (name, view) in tensors.iter() {
@@ -218,15 +219,99 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()>
    Ok(())
}

fn run_quantize_safetensors(
    in_files: &[std::path::PathBuf],
    out_file: std::path::PathBuf,
    q: Quantization,
) -> Result<()> {
    let mut out_file = std::fs::File::create(out_file)?;
    let mut tensors = std::collections::HashMap::new();
    for in_file in in_files.iter() {
        let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
        tensors.extend(in_tensors)
    }
    println!("tensors: {}", tensors.len());

    let quantize_fn = match q {
        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
        Quantization::F16 => QTensor::quantize::<half::f16>,
        Quantization::F32 => QTensor::quantize::<f32>,
    };
    let block_size = match q {
        Quantization::Q4_0 => k_quants::QK4_0,
        Quantization::Q4_1 => k_quants::QK4_1,
        Quantization::Q5_0 => k_quants::QK5_0,
        Quantization::Q5_1 => k_quants::QK5_1,
        Quantization::Q8_0 => k_quants::QK8_0,
        Quantization::Q8_1 => k_quants::QK8_1,
        Quantization::Q2k
        | Quantization::Q3k
        | Quantization::Q4k
        | Quantization::Q5k
        | Quantization::Q6k
        | Quantization::Q8k => k_quants::QK_K,
        Quantization::F16 | Quantization::F32 => 1,
    };

    let qtensors = tensors
        .into_par_iter()
        .map(|(name, tensor)| {
            let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
            println!("  quantizing {name} {tensor:?} {should_quantize}");
            let tensor = if should_quantize {
                quantize_fn(&tensor)?
            } else {
                QTensor::quantize::<f32>(&tensor)?
            };
            Ok((name, tensor))
        })
        .collect::<Result<Vec<_>>>()?;
    let qtensors = qtensors
        .iter()
        .map(|(k, v)| (k.as_str(), v))
        .collect::<Vec<_>>();
    gguf_file::write(&mut out_file, &[], &qtensors)?;
    Ok(())
}

fn run_quantize(
-   in_file: std::path::PathBuf,
+   in_files: &[std::path::PathBuf],
    out_file: std::path::PathBuf,
    q: Quantization,
    qmode: QuantizationMode,
) -> Result<()> {
    if in_files.is_empty() {
        candle_core::bail!("no specified input files")
    }
    if let Some(extension) = out_file.extension() {
        if extension == "safetensors" {
            candle_core::bail!("the generated file cannot use the safetensors extension")
        }
    }
    if let Some(extension) = in_files[0].extension() {
        if extension == "safetensors" {
            return run_quantize_safetensors(in_files, out_file, q);
        }
    }

    if in_files.len() != 1 {
        candle_core::bail!("only a single in-file can be used when quantizing gguf files")
    }

    // Open the out file early so as to fail directly on missing directories etc.
    let mut out_file = std::fs::File::create(out_file)?;
-   let mut in_ = std::fs::File::open(&in_file)?;
+   let mut in_ = std::fs::File::open(&in_files[0])?;
    let content = gguf_file::Content::read(&mut in_)?;
    println!("tensors: {}", content.tensor_infos.len());
@@ -252,7 +337,7 @@ fn run_quantize(
        .par_iter()
        .map(|(name, _)| {
            println!("  quantizing {name}");
-           let mut in_file = std::fs::File::open(&in_file)?;
+           let mut in_file = std::fs::File::open(&in_files[0])?;
            let tensor = content.tensor(&mut in_file, name)?;
            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
            Ok((name, tensor))
@@ -293,7 +378,7 @@ fn main() -> anyhow::Result<()> {
            out_file,
            quantization,
            mode,
-       } => run_quantize(in_file, out_file, quantization, mode)?,
+       } => run_quantize(&in_file, out_file, quantization, mode)?,
    }
    Ok(())
}
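The safetensors path above quantizes each 2D tensor whose trailing dimension is a multiple of the block size, and falls back to F32 blocks otherwise. A minimal sketch of that underlying primitive in isolation, using the same `candle_core::quantized` items the tool imports (the shape here is illustrative, not from the source):

```rust
use candle_core::quantized::{k_quants, QTensor};
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // QK4_0 is the Q4_0 block size; the trailing dim must be a multiple of it.
    let t = Tensor::randn(0f32, 1.0, (4, k_quants::QK4_0), &Device::Cpu)?;
    let q = QTensor::quantize::<k_quants::BlockQ4_0>(&t)?;
    println!("quantized a tensor of shape {:?} to {:?}", t.shape(), q.dtype());
    Ok(())
}
```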

@@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+   fn set_seed(&self, _: u64) -> Result<()>;
}
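The CPU and CUDA implementations of this new method appear further down in this diff. A short sketch of the intended use, assuming the public `Device` wrapper forwards to the backend's `set_seed` (hypothetical wiring; the trait method is all this hunk adds):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // On a CUDA device this reseeds curand so rand/randn become reproducible;
    // per the CPU implementation below, set_seed on the CPU backend bails.
    let dev = Device::cuda_if_available(0)?;
    if dev.is_cuda() {
        dev.set_seed(42)?;
    }
    let xs = Tensor::rand(0f32, 1f32, (2, 3), &dev)?;
    println!("{xs}");
    Ok(())
}
```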

@@ -69,7 +69,8 @@ impl Tensor {
            | Op::Binary(lhs, rhs, _)
            | Op::Gather(lhs, rhs, _)
            | Op::IndexSelect(lhs, rhs, _)
-           | Op::Matmul(lhs, rhs) => {
+           | Op::Matmul(lhs, rhs)
+           | Op::SliceScatter0(lhs, rhs, _) => {
                let (tg, nodes) = walk(lhs, nodes, already_seen);
                track_grad |= tg;
                let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -90,6 +91,9 @@
                    nodes
                }
            }
+           Op::Unary(_node, UnaryOp::Ceil)
+           | Op::Unary(_node, UnaryOp::Floor)
+           | Op::Unary(_node, UnaryOp::Round) => nodes,
            Op::Reshape(node)
            | Op::UpsampleNearest1D(node)
            | Op::UpsampleNearest2D(node)
@@ -270,6 +274,15 @@
            Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                op: "upsample-nearest2d",
            })?,
+           Op::SliceScatter0(lhs, rhs, start_rhs) => {
+               let rhs_sum_grad = grads.or_insert(rhs)?;
+               let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+               *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+               let lhs_sum_grad = grads.or_insert(lhs)?;
+               let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+               *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+           }
            Op::Gather(arg, indexes, dim) => {
                let sum_grad = grads.or_insert(arg)?;
                *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@@ -441,7 +454,18 @@
                let sum_grad = grads.or_insert(arg)?;
                *sum_grad = sum_grad.add(&arg_grad)?
            }
+           Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
+           Op::Unary(_, UnaryOp::Floor) => {
+               Err(Error::BackwardNotSupported { op: "floor" })?
+           }
+           Op::Unary(_, UnaryOp::Round) => {
+               Err(Error::BackwardNotSupported { op: "round" })?
+           }
            Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+           Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
+           Op::Unary(_, UnaryOp::GeluErf) => {
+               Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+           }
            Op::Unary(arg, UnaryOp::Relu) => {
                let sum_grad = grads.or_insert(arg)?;
                let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
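The new `SliceScatter0` backward rule above splits the incoming gradient: `narrow` routes the rows that were overwritten back to `rhs`, while `lhs` receives the gradient with that window zeroed out via `slice_scatter0` of a zeros tensor. For context, a small forward-pass sketch of the op itself, using the same public `Tensor` methods the diff calls (shapes are illustrative):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let lhs = Tensor::zeros((4, 2), DType::F32, &dev)?;
    let rhs = Tensor::ones((2, 2), DType::F32, &dev)?;
    // Rows 1..3 of the result come from `rhs`, the rest from `lhs`.
    let ys = lhs.slice_scatter0(&rhs, 1)?;
    println!("{ys}");
    Ok(())
}
```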

@@ -25,6 +25,19 @@ impl ParamsConv1D {
    }
}

+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum CudnnFwdAlgo {
+    ImplicitGemm,
+    ImplicitPrecompGemm,
+    Gemm,
+    Direct,
+    Fft,
+    FftTiling,
+    Winograd,
+    WinogradNonFused,
+    Count,
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    pub(crate) b_size: usize,
@@ -37,6 +50,7 @@ pub struct ParamsConv2D {
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
+   pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}

impl ParamsConv2D {
@@ -188,6 +202,7 @@ impl Tensor {
            padding,
            stride,
            dilation,
+           cudnn_fwd_algo: None,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
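With the new field in place, `conv2d` fills in `cudnn_fwd_algo: None`, leaving the forward-algorithm choice to cuDNN. A sketch of the unchanged public call for reference (illustrative shapes, not from the source):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1.0, (1, 3, 8, 8), &dev)?; // NCHW input
    let kernel = Tensor::randn(0f32, 1.0, (2, 3, 3, 3), &dev)?; // OIHW kernel
    // padding = 1, stride = 1, dilation = 1, groups = 1
    let ys = xs.conv2d(&kernel, 1, 1, 1, 1)?;
    println!("{:?}", ys.shape()); // (1, 2, 8, 8)
    Ok(())
}
```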

candle-core/src/cpu/erf.rs (new file): 763 lines
@ -0,0 +1,763 @@
|
||||
#![allow(clippy::excessive_precision)]
|
||||
// Code taken from https://github.com/statrs-dev/statrs
|
||||
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
|
||||
//! related functions
|
||||
|
||||
mod evaluate {
|
||||
//! Provides functions that don't have a numerical solution and must
|
||||
//! be solved computationally (e.g. evaluation of a polynomial)
|
||||
|
||||
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
|
||||
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
|
||||
/// coeffecient
|
||||
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
|
||||
/// `2z^2 - z + 3`
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// Returns 0 for a 0 length coefficient slice
|
||||
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
|
||||
let n = coeff.len();
|
||||
if n == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut sum = *coeff.last().unwrap();
|
||||
for c in coeff[0..n - 1].iter().rev() {
|
||||
sum = *c + z * sum;
|
||||
}
|
||||
sum
|
||||
}
|
||||
}
|
||||
use std::f64;
|
||||
|
||||
/// `erf` calculates the error function at `x`.
|
||||
pub fn erf(x: f64) -> f64 {
|
||||
if x.is_nan() {
|
||||
f64::NAN
|
||||
} else if x >= 0.0 && x.is_infinite() {
|
||||
1.0
|
||||
} else if x <= 0.0 && x.is_infinite() {
|
||||
-1.0
|
||||
} else if x == 0. {
|
||||
0.0
|
||||
} else {
|
||||
erf_impl(x, false)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erf_inv` calculates the inverse error function
|
||||
/// at `x`.
|
||||
pub fn erf_inv(x: f64) -> f64 {
|
||||
if x == 0.0 {
|
||||
0.0
|
||||
} else if x >= 1.0 {
|
||||
f64::INFINITY
|
||||
} else if x <= -1.0 {
|
||||
f64::NEG_INFINITY
|
||||
} else if x < 0.0 {
|
||||
erf_inv_impl(-x, 1.0 + x, -1.0)
|
||||
} else {
|
||||
erf_inv_impl(x, 1.0 - x, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erfc` calculates the complementary error function
|
||||
/// at `x`.
|
||||
pub fn erfc(x: f64) -> f64 {
|
||||
if x.is_nan() {
|
||||
f64::NAN
|
||||
} else if x == f64::INFINITY {
|
||||
0.0
|
||||
} else if x == f64::NEG_INFINITY {
|
||||
2.0
|
||||
} else {
|
||||
erf_impl(x, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// `erfc_inv` calculates the complementary inverse
|
||||
/// error function at `x`.
|
||||
pub fn erfc_inv(x: f64) -> f64 {
|
||||
if x <= 0.0 {
|
||||
f64::INFINITY
|
||||
} else if x >= 2.0 {
|
||||
f64::NEG_INFINITY
|
||||
} else if x > 1.0 {
|
||||
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
|
||||
} else {
|
||||
erf_inv_impl(1.0 - x, x, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_impl polynomial **********
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_impl`
|
||||
/// in the interval [1e-10, 0.5].
|
||||
const ERF_IMPL_AN: &[f64] = &[
|
||||
0.00337916709551257388990745,
|
||||
-0.00073695653048167948530905,
|
||||
-0.374732337392919607868241,
|
||||
0.0817442448733587196071743,
|
||||
-0.0421089319936548595203468,
|
||||
0.0070165709512095756344528,
|
||||
-0.00495091255982435110337458,
|
||||
0.000871646599037922480317225,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_impl`
|
||||
/// in the interval [1e-10, 0.5]
|
||||
const ERF_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.218088218087924645390535,
|
||||
0.412542972725442099083918,
|
||||
-0.0841891147873106755410271,
|
||||
0.0655338856400241519690695,
|
||||
-0.0120019604454941768171266,
|
||||
0.00408165558926174048329689,
|
||||
-0.000615900721557769691924509,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_IMPL_BN: &[f64] = &[
|
||||
-0.0361790390718262471360258,
|
||||
0.292251883444882683221149,
|
||||
0.281447041797604512774415,
|
||||
0.125610208862766947294894,
|
||||
0.0274135028268930549240776,
|
||||
0.00250839672168065762786937,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
1.8545005897903486499845,
|
||||
1.43575803037831418074962,
|
||||
0.582827658753036572454135,
|
||||
0.124810476932949746447682,
|
||||
0.0113724176546353285778481,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [0.75, 1.25].
|
||||
const ERF_IMPL_CN: &[f64] = &[
|
||||
-0.0397876892611136856954425,
|
||||
0.153165212467878293257683,
|
||||
0.191260295600936245503129,
|
||||
0.10276327061989304213645,
|
||||
0.029637090615738836726027,
|
||||
0.0046093486780275489468812,
|
||||
0.000307607820348680180548455,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [0.75, 1.25].
|
||||
const ERF_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
1.95520072987627704987886,
|
||||
1.64762317199384860109595,
|
||||
0.768238607022126250082483,
|
||||
0.209793185936509782784315,
|
||||
0.0319569316899913392596356,
|
||||
0.00213363160895785378615014,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [1.25, 2.25].
|
||||
const ERF_IMPL_DN: &[f64] = &[
|
||||
-0.0300838560557949717328341,
|
||||
0.0538578829844454508530552,
|
||||
0.0726211541651914182692959,
|
||||
0.0367628469888049348429018,
|
||||
0.00964629015572527529605267,
|
||||
0.00133453480075291076745275,
|
||||
0.778087599782504251917881e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [1.25, 2.25].
|
||||
const ERF_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.75967098147167528287343,
|
||||
1.32883571437961120556307,
|
||||
0.552528596508757581287907,
|
||||
0.133793056941332861912279,
|
||||
0.0179509645176280768640766,
|
||||
0.00104712440019937356634038,
|
||||
-0.106640381820357337177643e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [2.25, 3.5].
|
||||
const ERF_IMPL_EN: &[f64] = &[
|
||||
-0.0117907570137227847827732,
|
||||
0.014262132090538809896674,
|
||||
0.0202234435902960820020765,
|
||||
0.00930668299990432009042239,
|
||||
0.00213357802422065994322516,
|
||||
0.00025022987386460102395382,
|
||||
0.120534912219588189822126e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [2.25, 3.5].
|
||||
const ERF_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
1.50376225203620482047419,
|
||||
0.965397786204462896346934,
|
||||
0.339265230476796681555511,
|
||||
0.0689740649541569716897427,
|
||||
0.00771060262491768307365526,
|
||||
0.000371421101531069302990367,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [3.5, 5.25].
|
||||
const ERF_IMPL_FN: &[f64] = &[
|
||||
-0.00546954795538729307482955,
|
||||
0.00404190278731707110245394,
|
||||
0.0054963369553161170521356,
|
||||
0.00212616472603945399437862,
|
||||
0.000394984014495083900689956,
|
||||
0.365565477064442377259271e-4,
|
||||
0.135485897109932323253786e-5,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [3.5, 5.25].
|
||||
const ERF_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
1.21019697773630784832251,
|
||||
0.620914668221143886601045,
|
||||
0.173038430661142762569515,
|
||||
0.0276550813773432047594539,
|
||||
0.00240625974424309709745382,
|
||||
0.891811817251336577241006e-4,
|
||||
-0.465528836283382684461025e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [5.25, 8].
|
||||
const ERF_IMPL_GN: &[f64] = &[
|
||||
-0.00270722535905778347999196,
|
||||
0.0013187563425029400461378,
|
||||
0.00119925933261002333923989,
|
||||
0.00027849619811344664248235,
|
||||
0.267822988218331849989363e-4,
|
||||
0.923043672315028197865066e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [5.25, 8].
|
||||
const ERF_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.814632808543141591118279,
|
||||
0.268901665856299542168425,
|
||||
0.0449877216103041118694989,
|
||||
0.00381759663320248459168994,
|
||||
0.000131571897888596914350697,
|
||||
0.404815359675764138445257e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [8, 11.5].
|
||||
const ERF_IMPL_HN: &[f64] = &[
|
||||
-0.00109946720691742196814323,
|
||||
0.000406425442750422675169153,
|
||||
0.000274499489416900707787024,
|
||||
0.465293770646659383436343e-4,
|
||||
0.320955425395767463401993e-5,
|
||||
0.778286018145020892261936e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [8, 11.5].
|
||||
const ERF_IMPL_HD: &[f64] = &[
|
||||
1.0,
|
||||
0.588173710611846046373373,
|
||||
0.139363331289409746077541,
|
||||
0.0166329340417083678763028,
|
||||
0.00100023921310234908642639,
|
||||
0.24254837521587225125068e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [11.5, 17].
|
||||
const ERF_IMPL_IN: &[f64] = &[
|
||||
-0.00056907993601094962855594,
|
||||
0.000169498540373762264416984,
|
||||
0.518472354581100890120501e-4,
|
||||
0.382819312231928859704678e-5,
|
||||
0.824989931281894431781794e-7,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [11.5, 17].
|
||||
const ERF_IMPL_ID: &[f64] = &[
|
||||
1.0,
|
||||
0.339637250051139347430323,
|
||||
0.043472647870310663055044,
|
||||
0.00248549335224637114641629,
|
||||
0.535633305337152900549536e-4,
|
||||
-0.117490944405459578783846e-12,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [17, 24].
|
||||
const ERF_IMPL_JN: &[f64] = &[
|
||||
-0.000241313599483991337479091,
|
||||
0.574224975202501512365975e-4,
|
||||
0.115998962927383778460557e-4,
|
||||
0.581762134402593739370875e-6,
|
||||
0.853971555085673614607418e-8,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [17, 24].
|
||||
const ERF_IMPL_JD: &[f64] = &[
|
||||
1.0,
|
||||
0.233044138299687841018015,
|
||||
0.0204186940546440312625597,
|
||||
0.000797185647564398289151125,
|
||||
0.117019281670172327758019e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [24, 38].
|
||||
const ERF_IMPL_KN: &[f64] = &[
|
||||
-0.000146674699277760365803642,
|
||||
0.162666552112280519955647e-4,
|
||||
0.269116248509165239294897e-5,
|
||||
0.979584479468091935086972e-7,
|
||||
0.101994647625723465722285e-8,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [24, 38].
|
||||
const ERF_IMPL_KD: &[f64] = &[
|
||||
1.0,
|
||||
0.165907812944847226546036,
|
||||
0.0103361716191505884359634,
|
||||
0.000286593026373868366935721,
|
||||
0.298401570840900340874568e-5,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LN: &[f64] = &[
|
||||
-0.583905797629771786720406e-4,
|
||||
0.412510325105496173512992e-5,
|
||||
0.431790922420250949096906e-6,
|
||||
0.993365155590013193345569e-8,
|
||||
0.653480510020104699270084e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [38, 60].
|
||||
const ERF_IMPL_LD: &[f64] = &[
|
||||
1.0,
|
||||
0.105077086072039915406159,
|
||||
0.00414278428675475620830226,
|
||||
0.726338754644523769144108e-4,
|
||||
0.477818471047398785369849e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MN: &[f64] = &[
|
||||
-0.196457797609229579459841e-4,
|
||||
0.157243887666800692441195e-5,
|
||||
0.543902511192700878690335e-7,
|
||||
0.317472492369117710852685e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [60, 85].
|
||||
const ERF_IMPL_MD: &[f64] = &[
|
||||
1.0,
|
||||
0.052803989240957632204885,
|
||||
0.000926876069151753290378112,
|
||||
0.541011723226630257077328e-5,
|
||||
0.535093845803642394908747e-15,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_NN: &[f64] = &[
|
||||
-0.789224703978722689089794e-5,
|
||||
0.622088451660986955124162e-6,
|
||||
0.145728445676882396797184e-7,
|
||||
0.603715505542715364529243e-10,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||
/// in the interval [85, 110].
|
||||
const ERF_IMPL_ND: &[f64] = &[
|
||||
1.0,
|
||||
0.0375328846356293715248719,
|
||||
0.000467919535974625308126054,
|
||||
0.193847039275845656900547e-5,
|
||||
];
|
||||
|
||||
// **********************************************************
|
||||
// ********** Coefficients for erf_inv_impl polynomial ******
|
||||
// **********************************************************
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AN: &[f64] = &[
|
||||
-0.000508781949658280665617,
|
||||
-0.00836874819741736770379,
|
||||
0.0334806625409744615033,
|
||||
-0.0126926147662974029034,
|
||||
-0.0365637971411762664006,
|
||||
0.0219878681111168899165,
|
||||
0.00822687874676915743155,
|
||||
-0.00538772965071242932965,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0, 0.5].
|
||||
const ERF_INV_IMPL_AD: &[f64] = &[
|
||||
1.0,
|
||||
-0.970005043303290640362,
|
||||
-1.56574558234175846809,
|
||||
1.56221558398423026363,
|
||||
0.662328840472002992063,
|
||||
-0.71228902341542847553,
|
||||
-0.0527396382340099713954,
|
||||
0.0795283687341571680018,
|
||||
-0.00233393759374190016776,
|
||||
0.000886216390456424707504,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BN: &[f64] = &[
|
||||
-0.202433508355938759655,
|
||||
0.105264680699391713268,
|
||||
8.37050328343119927838,
|
||||
17.6447298408374015486,
|
||||
-18.8510648058714251895,
|
||||
-44.6382324441786960818,
|
||||
17.445385985570866523,
|
||||
21.1294655448340526258,
|
||||
-3.67192254707729348546,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.5, 0.75].
|
||||
const ERF_INV_IMPL_BD: &[f64] = &[
|
||||
1.0,
|
||||
6.24264124854247537712,
|
||||
3.9713437953343869095,
|
||||
-28.6608180499800029974,
|
||||
-20.1432634680485188801,
|
||||
48.5609213108739935468,
|
||||
10.8268667355460159008,
|
||||
-22.6436933413139721736,
|
||||
1.72114765761200282724,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CN: &[f64] = &[
|
||||
-0.131102781679951906451,
|
||||
-0.163794047193317060787,
|
||||
0.117030156341995252019,
|
||||
0.387079738972604337464,
|
||||
0.337785538912035898924,
|
||||
0.142869534408157156766,
|
||||
0.0290157910005329060432,
|
||||
0.00214558995388805277169,
|
||||
-0.679465575181126350155e-6,
|
||||
0.285225331782217055858e-7,
|
||||
-0.681149956853776992068e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x less than 3.
|
||||
const ERF_INV_IMPL_CD: &[f64] = &[
|
||||
1.0,
|
||||
3.46625407242567245975,
|
||||
5.38168345707006855425,
|
||||
4.77846592945843778382,
|
||||
2.59301921623620271374,
|
||||
0.848854343457902036425,
|
||||
0.152264338295331783612,
|
||||
0.01105924229346489121,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DN: &[f64] = &[
|
||||
-0.0350353787183177984712,
|
||||
-0.00222426529213447927281,
|
||||
0.0185573306514231072324,
|
||||
0.00950804701325919603619,
|
||||
0.00187123492819559223345,
|
||||
0.000157544617424960554631,
|
||||
0.460469890584317994083e-5,
|
||||
-0.230404776911882601748e-9,
|
||||
0.266339227425782031962e-11,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 3 and 6.
|
||||
const ERF_INV_IMPL_DD: &[f64] = &[
|
||||
1.0,
|
||||
1.3653349817554063097,
|
||||
0.762059164553623404043,
|
||||
0.220091105764131249824,
|
||||
0.0341589143670947727934,
|
||||
0.00263861676657015992959,
|
||||
0.764675292302794483503e-4,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_EN: &[f64] = &[
|
||||
-0.0167431005076633737133,
|
||||
-0.00112951438745580278863,
|
||||
0.00105628862152492910091,
|
||||
0.000209386317487588078668,
|
||||
0.149624783758342370182e-4,
|
||||
0.449696789927706453732e-6,
|
||||
0.462596163522878599135e-8,
|
||||
-0.281128735628831791805e-13,
|
||||
0.99055709973310326855e-16,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 6 and 18.
|
||||
const ERF_INV_IMPL_ED: &[f64] = &[
|
||||
1.0,
|
||||
0.591429344886417493481,
|
||||
0.138151865749083321638,
|
||||
0.0160746087093676504695,
|
||||
0.000964011807005165528527,
|
||||
0.275335474764726041141e-4,
|
||||
0.282243172016108031869e-6,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FN: &[f64] = &[
|
||||
-0.0024978212791898131227,
|
||||
-0.779190719229053954292e-5,
|
||||
0.254723037413027451751e-4,
|
||||
0.162397777342510920873e-5,
|
||||
0.396341011304801168516e-7,
|
||||
0.411632831190944208473e-9,
|
||||
0.145596286718675035587e-11,
|
||||
-0.116765012397184275695e-17,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x between 18 and 44.
|
||||
const ERF_INV_IMPL_FD: &[f64] = &[
|
||||
1.0,
|
||||
0.207123112214422517181,
|
||||
0.0169410838120975906478,
|
||||
0.000690538265622684595676,
|
||||
0.145007359818232637924e-4,
|
||||
0.144437756628144157666e-6,
|
||||
0.509761276599778486139e-9,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a numerator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GN: &[f64] = &[
|
||||
-0.000539042911019078575891,
|
||||
-0.28398759004727721098e-6,
|
||||
0.899465114892291446442e-6,
|
||||
0.229345859265920864296e-7,
|
||||
0.225561444863500149219e-9,
|
||||
0.947846627503022684216e-12,
|
||||
0.135880130108924861008e-14,
|
||||
-0.348890393399948882918e-21,
|
||||
];
|
||||
|
||||
/// Polynomial coefficients for a denominator of `erf_inv_impl`
|
||||
/// in the interval [0.75, 1] with x greater than 44.
|
||||
const ERF_INV_IMPL_GD: &[f64] = &[
|
||||
1.0,
|
||||
0.0845746234001899436914,
|
||||
0.00282092984726264681981,
|
||||
0.468292921940894236786e-4,
|
||||
0.399968812193862100054e-6,
|
||||
0.161809290887904476097e-8,
|
||||
0.231558608310259605225e-11,
|
||||
];
|
||||
|
||||
/// `erf_impl` computes the error function at `z`.
|
||||
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
|
||||
fn erf_impl(z: f64, inv: bool) -> f64 {
|
||||
if z < 0.0 {
|
||||
if !inv {
|
||||
return -erf_impl(-z, false);
|
||||
}
|
||||
if z < -0.5 {
|
||||
return 2.0 - erf_impl(-z, true);
|
||||
}
|
||||
return 1.0 + erf_impl(-z, false);
|
||||
}
|
||||
|
||||
let result = if z < 0.5 {
|
||||
if z < 1e-10 {
|
||||
z * 1.125 + z * 0.003379167095512573896158903121545171688
|
||||
} else {
|
||||
z * 1.125
|
||||
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
|
||||
}
|
||||
} else if z < 110.0 {
|
||||
let (r, b) = if z < 0.75 {
|
||||
(
|
||||
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
|
||||
                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
                0.3440242112,
            )
        } else if z < 1.25 {
            (
                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
                0.419990927,
            )
        } else if z < 2.25 {
            (
                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
                0.4898625016,
            )
        } else if z < 3.5 {
            (
                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
                0.5317370892,
            )
        } else if z < 5.25 {
            (
                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
                0.5489973426,
            )
        } else if z < 8.0 {
            (
                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
                0.5571740866,
            )
        } else if z < 11.5 {
            (
                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
                0.5609807968,
            )
        } else if z < 17.0 {
            (
                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
                    / evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
                0.5626493692,
            )
        } else if z < 24.0 {
            (
                evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
                    / evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
                0.5634598136,
            )
        } else if z < 38.0 {
            (
                evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
                    / evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
                0.5638477802,
            )
        } else if z < 60.0 {
            (
                evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
                    / evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
                0.5640528202,
            )
        } else if z < 85.0 {
            (
                evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
                    / evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
                0.5641309023,
            )
        } else {
            (
                evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
                    / evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
                0.5641584396,
            )
        };
        let g = (-z * z).exp() / z;
        g * b + g * r
    } else {
        0.0
    };

    if inv && z >= 0.5 {
        result
    } else if z >= 0.5 || inv {
        1.0 - result
    } else {
        result
    }
}

// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
    let result = if p <= 0.5 {
        let y = 0.0891314744949340820313;
        let g = p * (p + 10.0);
        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
        g * y + g * r
    } else if q >= 0.25 {
        let y = 2.249481201171875;
        let g = (-2.0 * q.ln()).sqrt();
        let xs = q - 0.25;
        let r =
            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
        g / (y + r)
    } else {
        let x = (-q.ln()).sqrt();
        if x < 3.0 {
            let y = 0.807220458984375;
            let xs = x - 1.125;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
            y * x + r * x
        } else if x < 6.0 {
            let y = 0.93995571136474609375;
            let xs = x - 3.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
            y * x + r * x
        } else if x < 18.0 {
            let y = 0.98362827301025390625;
            let xs = x - 6.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
            y * x + r * x
        } else if x < 44.0 {
            let y = 0.99714565277099609375;
            let xs = x - 18.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
            y * x + r * x
        } else {
            let y = 0.99941349029541015625;
            let xs = x - 44.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
            y * x + r * x
        }
    };
    s * result
}
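Note: erf and erf_inv are (approximate) inverses on (-1, 1), so a round-trip is a cheap sanity check for the rational approximations above. A minimal sketch follows; `erf` and `erf_inv` stand for whatever wrappers this module exposes around erf_impl/erf_inv_impl (the wrapper names are assumptions, not confirmed API):

// Sketch only: pass in the module's erf/erf_inv wrappers, whatever they are
// named; the tolerance is illustrative.
fn check_round_trip(erf: impl Fn(f64) -> f64, erf_inv: impl Fn(f64) -> f64) {
    for x in [0.1f64, 0.5, 0.9, 0.999] {
        let y = erf_inv(x);
        assert!((erf(y) - x).abs() < 1e-12);
    }
}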
@@ -1,3 +1,4 @@
+pub mod erf;
 pub mod kernels;
 
 trait Cpu<const ARR: usize> {
@@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice {
         Ok(Self)
     }
 
+    fn set_seed(&self, _seed: u64) -> Result<()> {
+        crate::bail!("cannot seed the CPU rng with set_seed")
+    }
+
     fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
         use rand::prelude::*;
 
@@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
         })
     }
 
+    fn set_seed(&self, seed: u64) -> Result<()> {
+        let mut curand = self.curand.lock().unwrap();
+        curand.0.set_seed(seed).w()?;
+        Ok(())
+    }
+
     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Cuda {
             gpu_id: self.device.ordinal(),
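Note: seeding the curand generator makes rand_uniform/rand_normal reproducible on CUDA. A minimal usage sketch, assuming a public `Device::set_seed` wrapper forwards to the backend method added above (the wrapper name is an assumption here):

use candle_core::{Device, Result, Tensor};

fn reproducible_draws(dev: &Device) -> Result<()> {
    dev.set_seed(42)?;
    let a = Tensor::rand(0f32, 1f32, (2, 2), dev)?;
    dev.set_seed(42)?;
    let b = Tensor::rand(0f32, 1f32, (2, 2), dev)?;
    // With the same seed, the two draws should match.
    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}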
@@ -884,8 +890,6 @@ impl<'a> Map1 for IndexSelect<'a> {
         };
         let ids_shape = ids_l.shape();
         let ids_dims = ids_shape.dims();
-        let ids_el = ids_shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(ids_el as u32);
         let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
@@ -893,19 +897,23 @@ impl<'a> Map1 for IndexSelect<'a> {
         };
         let left_size: usize = src_l.dims()[..self.2].iter().product();
         let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
-        let dim_size = src_l.dims()[self.2];
+        let src_dim_size = src_l.dims()[self.2];
+        let ids_dim_size = ids_shape.elem_count();
+        let dst_el = ids_shape.elem_count() * left_size * right_size;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let params = (
-            ids_el,
+            dst_el,
             ids_dims.len(),
             &ds,
             ids,
            &src,
            &out,
            left_size,
-            dim_size,
+            src_dim_size,
+            ids_dim_size,
            right_size,
        );
        // SAFETY: ffi.
@@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d<
     params: &crate::conv::ParamsConv2D,
     dev: &crate::cuda_backend::CudaDevice,
 ) -> crate::Result<()> {
+    use crate::conv::CudnnFwdAlgo as CandleAlgo;
+    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
+
     let device_id = dev.id();
     let cudnn = CUDNN.with(|cudnn| {
         if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d<
         w: &w,
         y: &y,
     };
-    let alg = conv2d.pick_algorithm()?;
+    let alg = match params.cudnn_fwd_algo {
+        None => conv2d.pick_algorithm()?,
+        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+        Some(CandleAlgo::ImplicitPrecompGemm) => {
+            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+        }
+        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
+        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
+        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
+    };
     let workspace_size = conv2d.get_workspace_size(alg)?;
    let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
    unsafe {
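Note: `params.cudnn_fwd_algo` lets a caller pin a specific cuDNN forward algorithm; `None` keeps the previous behavior of letting `pick_algorithm` autotune. A standalone sketch of the dispatch pattern (names here are illustrative, not candle API):

// An optional override falls back to autotuning when unset.
enum Algo {
    ImplicitGemm,
    Gemm,
}

fn select(overridden: Option<Algo>, autotune: impl Fn() -> Algo) -> Algo {
    match overridden {
        None => autotune(),
        Some(a) => a,
    }
}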
@@ -67,6 +67,20 @@ impl DType {
             Self::F64 => 8,
         }
     }
 
+    pub fn is_int(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => true,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
+        }
+    }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => false,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
+        }
+    }
 }
 
 pub trait WithDType:
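Usage sketch for the new helpers (the method names match the diff; the rest is illustrative):

use candle_core::DType;

fn main() {
    assert!(DType::U32.is_int());
    assert!(!DType::U32.is_float());
    assert!(DType::BF16.is_float());
    // Handy for guarding ops that only make sense on floats, e.g. the
    // erf/gelu_erf unary ops added in this change, which return 0 on ints.
    let dt = DType::F32;
    if dt.is_float() {
        // apply a float-only op here
    }
}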
@@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn set_seed(&self, _: u64) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn location(&self) -> crate::DeviceLocation {
         fail!()
     }
@@ -58,8 +58,13 @@ pub enum UnaryOp {
     Sqr,
     Sqrt,
     Gelu,
+    GeluErf,
+    Erf,
     Relu,
     Tanh,
+    Floor,
+    Ceil,
+    Round,
 }
 
 #[derive(Clone)]
@@ -131,6 +136,7 @@ pub enum Op {
     Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
+    SliceScatter0(Tensor, Tensor, usize),
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
@@ -325,8 +331,13 @@ pub(crate) struct Recip;
 pub(crate) struct Sqr;
 pub(crate) struct Sqrt;
 pub(crate) struct Gelu;
+pub(crate) struct GeluErf;
+pub(crate) struct Erf;
 pub(crate) struct Relu;
 pub(crate) struct Tanh;
+pub(crate) struct Floor;
+pub(crate) struct Ceil;
+pub(crate) struct Round;
 
 macro_rules! bin_op {
     ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -621,6 +632,176 @@ impl UnaryOpT for Gelu {
     }
 }
 
+impl UnaryOpT for Erf {
+    const NAME: &'static str = "erf";
+    const KERNEL: &'static str = "uerf";
+    const V: Self = Erf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        crate::cpu::erf::erf(v)
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+}
+
+impl UnaryOpT for Ceil {
+    const NAME: &'static str = "ceil";
+    const KERNEL: &'static str = "uceil";
+    const V: Self = Ceil;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.ceil()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
+}
+
+impl UnaryOpT for Floor {
+    const NAME: &'static str = "floor";
+    const KERNEL: &'static str = "ufloor";
+    const V: Self = Floor;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.floor()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
+}
+
+impl UnaryOpT for Round {
+    const NAME: &'static str = "round";
+    const KERNEL: &'static str = "uround";
+    const V: Self = Round;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.round()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.round()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.round()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.round()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
+}
+
+impl UnaryOpT for GeluErf {
+    const NAME: &'static str = "gelu_erf";
+    const KERNEL: &'static str = "ugelu_erf";
+    const V: Self = GeluErf;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        bf16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        f16::from_f64(Self::f64(v.to_f64()))
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        Self::f64(v as f64) as f32
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
+    }
+    #[inline(always)]
+    fn u8(_: u8) -> u8 {
+        0
+    }
+    #[inline(always)]
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
+}
+
 impl UnaryOpT for Relu {
     const NAME: &'static str = "relu";
     const KERNEL: &'static str = "urelu";
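Note: GeluErf is the exact GELU, gelu_erf(v) = 0.5 * v * (1 + erf(v / sqrt(2))) = v * Phi(v) where Phi is the standard normal CDF, as opposed to the tanh approximation used by the existing `Gelu`. A quick numeric check of the formula (the erf value is a hard-coded standard constant, used only for this check):

fn gelu_erf(v: f64, erf: impl Fn(f64) -> f64) -> f64 {
    (erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}

fn main() {
    // erf(1/sqrt(2)) ~= 0.6826894921, so gelu_erf(1) ~= 0.8413447461 = Phi(1).
    let approx_erf = |_: f64| 0.6826894921f64; // stub for this check only
    let y = gelu_erf(1.0, approx_erf);
    assert!((y - 0.8413447461).abs() < 1e-9);
}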
@@ -638,3 +638,35 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
         Ok(hsum_float_8(acc) + summs)
     }
 }
+
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+    let qk = QK_K;
+    if n % qk != 0 {
+        crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
+    }
+
+    unsafe {
+        let mut acc = _mm256_setzero_ps();
+        for (xs, ys) in xs.iter().zip(ys.iter()) {
+            let mut sumi = _mm256_setzero_si256();
+            let x_qs = xs.qs.as_ptr();
+            let y_qs = ys.qs.as_ptr();
+            for j in (0..QK_K).step_by(32) {
+                let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
+                let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
+
+                let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
+                let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
+                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
+
+                let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
+                let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
+                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
+            }
+            let d = _mm256_set1_ps(xs.d * ys.d);
+            acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
+        }
+        Ok(hsum_float_8(acc))
+    }
+}
@@ -135,7 +135,13 @@ pub fn qtensor_from_ggml(
     dims: Vec<usize>,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+    let blck_size = ggml_dtype.blck_size();
+    if tensor_elems % blck_size != 0 {
+        crate::bail!(
+            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+        )
+    }
+    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
 
     match ggml_dtype {
         GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
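Note: the old expression multiplied before dividing, so a non-block-multiple element count silently rounded the byte size down; the new order divides first and bails on non-multiples. Worked example with Q4_0-style constants (block size 32, 18 bytes per block; the constants are illustrative of the ggml layout, not pulled from this diff):

fn main() {
    let (type_size, blck_size) = (18usize, 32usize);
    let tensor_elems = 4096usize;
    assert_eq!(tensor_elems % blck_size, 0);
    let size_in_bytes = tensor_elems / blck_size * type_size; // 128 blocks * 18
    assert_eq!(size_in_bytes, 2304);
    // With 4095 elements the old formula would quietly truncate
    // (4095 * 18 / 32 = 2303), while the new code bails instead.
}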
@@ -59,8 +59,13 @@ impl TensorInfo {
         tensor_data_offset: u64,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let size_in_bytes =
-            tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
+        let blck_size = self.ggml_dtype.blck_size();
+        if tensor_elems % blck_size != 0 {
+            crate::bail!(
+                "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            )
+        }
+        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
@@ -34,6 +34,9 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
     /// Dot product used as a building block for quantized mat-mul.
     /// n is the number of elements to be considered.
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
 
+    /// Generic implementation of the dot product without simd optimizations.
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -225,6 +228,13 @@ impl GgmlType for BlockQ4_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
         let nb = n / qk;
         if n % QK8_0 != 0 {
@@ -255,6 +265,10 @@ impl GgmlType for BlockQ4_1 {
     type VecDotType = BlockQ8_1;
 
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         // ggml_vec_dot_q4_1_q8_1
         let qk = QK8_1;
         if n % qk != 0 {
@@ -354,7 +368,10 @@ impl GgmlType for BlockQ5_0 {
         if nb % 2 != 0 {
             crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
         }
+        Self::vec_dot_unopt(n, xs, ys)
+    }
 
+    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         // Generic implementation.
         let mut sumf = 0f32;
 
@@ -445,6 +462,10 @@ impl GgmlType for BlockQ5_1 {
     type VecDotType = BlockQ8_1;
 
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = Self::BLCK_SIZE;
         if n % Self::BLCK_SIZE != 0 {
             crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
@@ -606,6 +627,13 @@ impl GgmlType for BlockQ8_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         let qk = QK8_0;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
@@ -631,7 +659,11 @@ impl GgmlType for BlockQ8_1 {
     const BLCK_SIZE: usize = QK8_1;
     type VecDotType = BlockQ8_1;
 
-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
         unimplemented!("no support for vec-dot on Q8_1")
     }
 
@@ -681,6 +713,13 @@ impl GgmlType for BlockQ2K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q2k_q8k(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -701,18 +740,17 @@ impl GgmlType for BlockQ2K {
 
         let mut isum = 0;
         let mut is = 0;
-        let mut d;
         for _ in 0..(QK_K / 128) {
             let mut shift = 0;
             for _ in 0..4 {
-                d = (sc[is] & 0xF) as i32;
+                let d = (sc[is] & 0xF) as i32;
                 is += 1;
                 let mut isuml = 0;
                 for l in 0..16 {
                     isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                 }
                 isum += d * isuml;
-                d = (sc[is] & 0xF) as i32;
+                let d = (sc[is] & 0xF) as i32;
                 is += 1;
                 isuml = 0;
                 for l in 16..32 {
@@ -851,6 +889,10 @@ impl GgmlType for BlockQ3K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q3k_q8k(n, xs, ys);
 
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1077,7 +1119,6 @@ impl GgmlType for BlockQ3K {
         let d_all = block.d.to_f32();
         let mut m = 1;
         let mut is = 0;
-        let mut dl;
 
         // Dequantize both 128 long blocks
         // 32 qs values per 128 long block
@@ -1088,7 +1129,7 @@ impl GgmlType for BlockQ3K {
                 for (scale_index, scale_scoped_y) in
                     shift_scoped_y.chunks_exact_mut(16).enumerate()
                 {
-                    dl = d_all * (scales[is] as f32 - 32.0);
+                    let dl = d_all * (scales[is] as f32 - 32.0);
                     for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
                         let new_y = dl
                             * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
@@ -1126,6 +1167,13 @@ impl GgmlType for BlockQ4K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4k_q8k(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1312,6 +1360,10 @@ impl GgmlType for BlockQ5K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q5k_q8k(n, xs, ys);
 
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1529,6 +1581,13 @@ impl GgmlType for BlockQ6K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q6k_q8k(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
         }
@@ -1697,8 +1756,38 @@ impl GgmlType for BlockQ8K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
-    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
-        unreachable!()
+    #[allow(unreachable_code)]
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "avx")]
+        return super::avx::vec_dot_q8k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_q8k_q8k(n, xs, ys);
+
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
+
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        let qk = QK_K;
+        if n % QK_K != 0 {
+            crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+        }
+
+        // Generic implementation.
+        let mut sumf = 0f32;
+        for (xs, ys) in xs.iter().zip(ys.iter()) {
+            let sum_i = xs
+                .qs
+                .iter()
+                .zip(ys.qs.iter())
+                .map(|(&x, &y)| x as i32 * y as i32)
+                .sum::<i32>();
+            sumf += sum_i as f32 * xs.d * ys.d
+        }
+        Ok(sumf)
     }
 
     fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
@@ -1804,6 +1893,10 @@ impl GgmlType for f32 {
     type VecDotType = f32;
 
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if xs.len() < n {
            crate::bail!("size mismatch {} < {n}", xs.len())
        }
@@ -1838,6 +1931,10 @@ impl GgmlType for f16 {
     type VecDotType = f16;
 
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if xs.len() < n {
            crate::bail!("size mismatch {} < {n}", xs.len())
        }
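Note: every vec_dot in this file now follows the same shape: cfg-gated early returns into the SIMD paths, with vec_dot_unopt as the portable fallback (which the tests also use to cross-check the optimized paths). A standalone sketch of the idiom, with illustrative names only:

// The cfg'd `return`s are compiled in only when the target feature is
// enabled; otherwise control falls through to the generic implementation.
#[allow(unreachable_code)]
fn dot(n: usize, xs: &[f32], ys: &[f32]) -> f32 {
    #[cfg(target_feature = "avx")]
    return dot_avx(n, xs, ys); // hypothetical SIMD path

    dot_generic(n, xs, ys)
}

fn dot_generic(n: usize, xs: &[f32], ys: &[f32]) -> f32 {
    xs[..n].iter().zip(&ys[..n]).map(|(x, y)| x * y).sum()
}

#[cfg(target_feature = "avx")]
fn dot_avx(n: usize, xs: &[f32], ys: &[f32]) -> f32 {
    dot_generic(n, xs, ys) // stand-in body for the sketch
}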
@@ -7,6 +7,8 @@ pub mod gguf_file;
 pub mod k_quants;
 #[cfg(target_feature = "neon")]
 pub mod neon;
+#[cfg(target_feature = "simd128")]
+pub mod simd128;
 pub mod utils;
 
 pub use k_quants::GgmlType;
@@ -229,20 +231,40 @@ impl QTensor {
     }
 }
 
-#[derive(Debug)]
-pub struct QMatMul(std::sync::Arc<QTensor>);
+#[derive(Clone, Debug)]
+pub enum QMatMul {
+    QTensor(std::sync::Arc<QTensor>),
+    Tensor(Tensor),
+}
+
+thread_local! {
+    static DEQUANTIZE_ALL: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
 
 impl QMatMul {
-    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
-        Self(qtensor)
+    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
+        let dequantize = match qtensor.dtype() {
+            GgmlDType::F32 | GgmlDType::F16 => true,
+            _ => DEQUANTIZE_ALL.with(|b| *b),
+        };
+        let t = if dequantize {
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
+            Self::Tensor(tensor)
+        } else {
+            Self::QTensor(qtensor)
+        };
+        Ok(t)
     }
 
-    pub fn from_qtensor(qtensor: QTensor) -> Self {
-        Self(std::sync::Arc::new(qtensor))
-    }
-
-    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
-        &self.0
+    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
+        Self::from_arc(std::sync::Arc::new(qtensor))
     }
 }
 
@@ -287,6 +309,16 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply_op1_no_bwd(self.0.as_ref())
+        match self {
+            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
+            Self::Tensor(w) => {
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.matmul(&w)
+            }
+        }
     }
 }
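Usage sketch: `from_qtensor` is now fallible because construction may dequantize eagerly — always for F16/F32 tensors, or for every tensor when the CANDLE_DEQUANTIZE_ALL environment variable is set to a non-empty, non-"0" value (useful for A/B-ing quantized kernels against plain matmul):

use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Result, Tensor};

fn apply(qtensor: QTensor, xs: &Tensor) -> Result<Tensor> {
    let mm = QMatMul::from_qtensor(qtensor)?;
    mm.forward(xs)
}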
@@ -148,6 +148,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
     }
 }
 
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+    let qk = QK_K;
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+    }
+
+    let mut sumf = 0f32;
+    for (xs, ys) in xs.iter().zip(ys.iter()) {
+        unsafe {
+            let mut sum_i = vdupq_n_s32(0);
+            let scale = xs.d * ys.d;
+            let xs = xs.qs.as_ptr();
+            let ys = ys.qs.as_ptr();
+            for i in (0..QK_K).step_by(16) {
+                let xs = vld1q_s8(xs.add(i));
+                let ys = vld1q_s8(ys.add(i));
+                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
+                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
+
+                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+                sum_i = vaddq_s32(sum_i, xy)
+            }
+            sumf += vaddvq_s32(sum_i) as f32 * scale
+        }
+    }
+    Ok(sumf)
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
candle-core/src/quantized/simd128.rs (new file, 427 lines)
@@ -0,0 +1,427 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;

use core::arch::wasm32::*;

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
            let x12 = v128_and(x1234, u8x16_splat(0x0F));
            let x12 = i8x16_sub(x12, i8x16_splat(8));
            let x34 = u8x16_shr(x1234, 4);
            let x34 = i8x16_sub(x34, i8x16_splat(8));

            let x1 = i16x8_extend_low_i8x16(x12);
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_extend_high_i8x16(x12);
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_extend_low_i8x16(x34);
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_extend_high_i8x16(x34);
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    unsafe {
        let mut sumf = f32x4_splat(0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q2: &[_] = &x.qs;
            let mut q8: &[_] = &y.qs;
            let sc = &x.scales;

            let mut summs = i32x4_splat(0);
            for i in (0..(QK_K / 16)).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
                let scales = i32x4_shr(
                    i32x4(
                        sc[i] as i32,
                        sc[i + 1] as i32,
                        sc[i + 2] as i32,
                        sc[i + 3] as i32,
                    ),
                    4,
                );
                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
            }
            let summs = f32x4_convert_i32x4(summs);

            let dall = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let mut isum = i32x4_splat(0);
            let mut is = 0;
            for _ in 0..(QK_K / 128) {
                let mut shift = 0;
                for _ in 0..4 {
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (0..16).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (16..32).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    shift += 2;
                    // adjust the indexing
                    q8 = &q8[32..];
                }
                // adjust the indexing
                q2 = &q2[32..];
            }
            let isum = f32x4_convert_i32x4(isum);
            sumf = f32x4_add(
                sumf,
                f32x4_sub(
                    f32x4_mul(isum, f32x4_splat(dall)),
                    f32x4_mul(summs, f32x4_splat(dmin)),
                ),
            );
        }
        let sumf = f32x4_extract_lane::<0>(sumf)
            + f32x4_extract_lane::<1>(sumf)
            + f32x4_extract_lane::<2>(sumf)
            + f32x4_extract_lane::<3>(sumf);
        Ok(sumf)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }

    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    let mut utmp: [u32; 4] = [0; 4];
    let mut scales: [u8; 8] = [0; 8];
    let mut mins: [u8; 8] = [0; 8];

    let mut aux8: [u8; QK_K] = [0; QK_K];
    let mut sums = f32x4_splat(0f32);
    unsafe {
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;

            for j in 0..QK_K / 64 {
                let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
                let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
                v128_store(
                    aux8.as_mut_ptr().add(64 * j) as *mut v128,
                    v128_and(q4_1, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
                    v128_and(q4_2, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
                    u8x16_shr(q4_1, 4),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
                    u8x16_shr(q4_2, 4),
                );
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            //extract scales and mins
            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K / 16).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
                let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
                let mins = i32x4(m1, m1, m2, m2);
                sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
            }

            let mut aux32 = i32x4_splat(0i32);
            for (scale_i, scale) in scales.iter().enumerate() {
                let scale = i32x4_splat(*scale as i32);
                for j in 0..4 {
                    let i = 32 * scale_i + 8 * j;
                    let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
                    let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
                    let aux16 = i16x8_mul(q8, aux8);
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
                }
            }
            let aux32 = f32x4_convert_i32x4(aux32);
            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
            let dmin = x.dmin.to_f32() * y.d;
            let dmin = f32x4_splat(dmin);
            let sumi = f32x4_convert_i32x4(sumi);
            sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }

    let mut aux8 = [0i8; QK_K];
    unsafe {
        let mut sums = f32x4_splat(0f32);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let q4 = &x.ql;
            let qh = &x.qh;
            let q8 = &y.qs;
            let mut aux32 = f32x4_splat(0f32);

            for j in (0..QK_K).step_by(128) {
                let aux8 = aux8.as_mut_ptr().add(j);
                let q4 = &q4.as_ptr().add(j / 2);
                let qh = &qh.as_ptr().add(j / 4);
                for l in (0..32).step_by(16) {
                    // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 32] =
                    //     (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 32) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 64) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 96] =
                    //     (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 96) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );
                }
            }

            for (j, &scale) in x.scales.iter().enumerate() {
                let scale = f32x4_splat(scale as f32);
                for offset in [0, 8] {
                    let aux16 = i16x8_mul(
                        i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
                        i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
                    );
                }
            }

            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let x_qs = xs.qs.as_ptr();
            let y_qs = ys.qs.as_ptr();
            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K).step_by(8) {
                let xs = i16x8_load_extend_i8x8(x_qs.add(j));
                let ys = i16x8_load_extend_i8x8(y_qs.add(j));
                let sum_xy = i32x4_dot_i16x8(xs, ys);
                sumi = i32x4_add(sumi, sum_xy)
            }
            let d = f32x4_splat(xs.d * ys.d);
            acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}
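Note: wasm32 simd128 has no horizontal-add intrinsic, so each kernel above finishes by extracting the four f32 lanes and summing them. The same reduction, pulled out as a helper for clarity (the file inlines it instead):

#[cfg(target_arch = "wasm32")]
fn hsum_f32x4(v: core::arch::wasm32::v128) -> f32 {
    use core::arch::wasm32::f32x4_extract_lane;
    f32x4_extract_lane::<0>(v)
        + f32x4_extract_lane::<1>(v)
        + f32x4_extract_lane::<2>(v)
        + f32x4_extract_lane::<3>(v)
}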
@@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
     let expected_blocks = xs.len() / block_size;
     let actual_blocks = ys.len();
 
-    //validate that the input is the right size
+    // Validate that the input is the right size
     if expected_blocks != actual_blocks {
         crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
     }
@@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
 
     let actual_output_len = ys.len()
     let expected_output_len = xs.len() * block_size;
-    //validate that the output is the right size
+    // Validate that the output is the right size
     if expected_output_len != actual_output_len {
         crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
     }
 
-    //zip the blocks and outputs together
+    // Zip the blocks and outputs together
     Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
 }
@@ -251,6 +251,134 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
     Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
 }
 
+#[derive(yoke::Yokeable)]
+struct SafeTensors_<'a>(SafeTensors<'a>);
+
+pub struct MmapedSafetensors {
+    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
+    routing: Option<HashMap<String, usize>>,
+}
+
+impl MmapedSafetensors {
+    /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
+        let p = p.as_ref();
+        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+        let file = memmap2::MmapOptions::new()
+            .map(&file)
+            .map_err(|e| Error::from(e).with_path(p))?;
+        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
+            file,
+            |data: &[u8]| {
+                let st = safetensors::SafeTensors::deserialize(data)
+                    .map_err(|e| Error::from(e).with_path(p))?;
+                Ok::<_, Error>(SafeTensors_(st))
+            },
+        )?;
+        Ok(Self {
+            safetensors: vec![safetensors],
+            routing: None,
+        })
+    }
+
+    /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
+    ///
+    /// If a tensor name appears in multiple files, the last entry is returned.
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
+        let mut routing = HashMap::new();
+        let mut safetensors = vec![];
+        for (index, p) in paths.iter().enumerate() {
+            let p = p.as_ref();
+            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+            let file = memmap2::MmapOptions::new()
+                .map(&file)
+                .map_err(|e| Error::from(e).with_path(p))?;
+            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
+                file,
+                |data: &[u8]| {
+                    let st = safetensors::SafeTensors::deserialize(data)
+                        .map_err(|e| Error::from(e).with_path(p))?;
+                    Ok::<_, Error>(SafeTensors_(st))
+                },
+            )?;
+            for k in data.get().0.names() {
+                routing.insert(k.to_string(), index);
+            }
+            safetensors.push(data)
+        }
+        Ok(Self {
+            safetensors,
+            routing: Some(routing),
+        })
+    }
+
+    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+        self.get(name)?.load(dev)
+    }
+
+    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+        let mut tensors = vec![];
+        for safetensors in self.safetensors.iter() {
+            tensors.push(safetensors.get().0.tensors())
+        }
+        tensors.into_iter().flatten().collect()
+    }
+
+    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
+        let index = match &self.routing {
+            None => 0,
+            Some(routing) => {
+                let index = routing.get(name).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: name.to_string(),
+                    }
+                    .bt()
+                })?;
+                *index
+            }
+        };
+        Ok(self.safetensors[index].get().0.tensor(name)?)
+    }
+}
+
+pub struct BufferedSafetensors {
+    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
+}
+
+impl BufferedSafetensors {
+    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
+    pub fn new(buffer: Vec<u8>) -> Result<Self> {
+        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
+            buffer,
+            |data: &[u8]| {
+                let st = safetensors::SafeTensors::deserialize(data)?;
+                Ok::<_, Error>(SafeTensors_(st))
+            },
+        )?;
+        Ok(Self { safetensors })
+    }
+
+    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+        self.get(name)?.load(dev)
+    }
+
+    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+        self.safetensors.get().0.tensors()
+    }
+
+    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
+        Ok(self.safetensors.get().0.tensor(name)?)
+    }
+}
+
 pub struct MmapedFile {
     path: std::path::PathBuf,
     inner: memmap2::Mmap,
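Usage sketch for the new loaders; the file names below are placeholders. The constructors are `unsafe` because a memory map can be invalidated if the underlying file changes while mapped (the yoke crate is what lets the borrowed SafeTensors view live alongside its mmap):

use candle_core::safetensors::MmapedSafetensors;
use candle_core::{Device, Result};

fn load_weights() -> Result<()> {
    let paths = ["model-00001.safetensors", "model-00002.safetensors"];
    let st = unsafe { MmapedSafetensors::multi(&paths)? };
    // With multiple shards, a tensor is routed to the file that defines it
    // (the last file wins on duplicate names).
    let _w = st.load("some.tensor.name", &Device::Cpu)?;
    Ok(())
}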
@@ -177,14 +177,9 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let none = BackpropOp::none();
-        if is_variable {
-            let shape = shape.into();
-            let storage = device.ones(&shape, dtype)?;
-            Ok(from_storage(storage, shape, none, is_variable))
-        } else {
-            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
-        }
+        let shape = shape.into();
+        let storage = device.ones(&shape, dtype)?;
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor filled with ones.
@@ -222,14 +217,9 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let none = BackpropOp::none();
-        if is_variable {
-            let shape = shape.into();
-            let storage = device.zeros(&shape, dtype)?;
-            Ok(from_storage(storage, shape, none, is_variable))
-        } else {
-            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
-        }
+        let shape = shape.into();
+        let storage = device.zeros(&shape, dtype)?;
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor filled with zeros.
@@ -489,7 +479,21 @@ impl Tensor {
     unary_op!(sqr, Sqr);
     unary_op!(sqrt, Sqrt);
     unary_op!(gelu, Gelu);
+    unary_op!(gelu_erf, GeluErf);
+    unary_op!(erf, Erf);
     unary_op!(relu, Relu);
+    unary_op!(ceil, Ceil);
+    unary_op!(floor, Floor);
+    unary_op!(round, Round);
+
+    /// Round element of the input tensor to the nearest integer.
+    ///
+    /// If the number of decimals is negative, it specifies the number of positions to the left of
+    /// the decimal point.
+    pub fn round_to(&self, decimals: i32) -> Result<Self> {
+        let mult = 10f64.powi(decimals);
+        (self * mult)?.round()? * (1f64 / mult)
+    }
 
     /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
     /// dimensions, an error is returned instead.
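Usage sketch for `round_to`: decimals = 2 scales by 100, rounds, then scales back; negative decimals round to the left of the decimal point. Exact printed values may differ in the last bits since this goes through f64 multiplication:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[3.14159f64, 271.828], &Device::Cpu)?;
    let hundredths = t.round_to(2)?; // approximately [3.14, 271.83]
    let tens = t.round_to(-1)?; // approximately [0.0, 270.0]
    println!("{hundredths}\n{tens}");
    Ok(())
}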
@@ -1130,6 +1134,74 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "slice-scatter")?;
+        if dim == 0 {
+            self.slice_scatter0(src, start)
+        } else {
+            // TODO: Maybe we want to add a more efficient implementation at some point.
+            self.transpose(0, dim)?
+                .slice_scatter0(&src.transpose(0, dim)?, start)?
+                .transpose(0, dim)
+        }
+    }
+
+    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device.location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: src.shape().clone(),
+            }
+            .bt())?
+        }
+        let shape_ok =
+            self.dims()
+                .iter()
+                .zip(src.dims().iter())
+                .enumerate()
+                .all(|(dim_idx, (&d1, &d2))| {
+                    if 0 == dim_idx {
+                        d2 + start <= d1
+                    } else {
+                        d1 == d2
+                    }
+                });
+        if !shape_ok {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "slice-scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: src.shape().clone(),
+            })?
+        }
+        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let offset = start * src.dims()[1..].iter().product::<usize>();
+        src.storage()
+            .copy_strided_src(&mut storage, offset, src.layout())?;
+        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
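Usage sketch mirroring the gradient test added further below: overwrite row 1 of a 2x3 tensor with a 1x3 tensor (`src` must have the same rank and match `self` on every dimension except the scatter one):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let base = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let src = Tensor::new(&[[1f32, 2., 3.]], &dev)?;
    let out = base.slice_scatter(&src, 0, 1)?;
    assert_eq!(out.to_vec2::<f32>()?, [[0., 0., 0.], [1., 2., 3.]]);
    Ok(())
}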
@@ -1546,6 +1618,9 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
+        if dim1 == dim2 {
+            return Ok(self.clone());
+        }
         let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -1907,6 +1982,34 @@ impl Tensor {
         for arg in args {
             arg.as_ref().check_dim(dim, "cat")?;
         }
+        for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
+            if arg0.rank() != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: arg0.rank(),
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
+            for (dim_idx, (v1, v2)) in arg0
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim_idx != dim && v1 != v2 {
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
+                }
+            }
+        }
         if dim == 0 {
             Self::cat0(args)
         } else {
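Note: the new validation turns a late or confusing failure into a clear ShapeMismatchCat error whenever the non-concatenated dimensions disagree. Sketch:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::zeros((2, 3), DType::F32, &dev)?;
    let b = Tensor::zeros((2, 4), DType::F32, &dev)?;
    assert!(Tensor::cat(&[&a, &b], 0).is_err()); // dim 1 differs: 3 vs 4
    assert!(Tensor::cat(&[&a, &b], 1).is_ok()); // concatenating on dim 1 is fine
    Ok(())
}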
@@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
     assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
 
+    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
+    let x = x_var.as_tensor();
+    let y_var = Var::new(&[2f32, 7., 1.], device)?;
+    let y = y_var.as_tensor();
+
+    let ss = x
+        .reshape((2, 3))?
+        .slice_scatter0(&y.reshape((1, 3))?, 1)?
+        .sqr()?;
+    let grads = ss.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    let grad_y = grads.get(y).context("no grad for y")?;
+    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
+    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
     Ok(())
 }
@@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,
@@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
     );
 
     let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor);
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
         to_vec2_round(&res, 0)?,
@@ -491,6 +491,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
         GgmlDType::Q5_0 => 0.001353,
         GgmlDType::Q5_1 => 0.001363,
         GgmlDType::Q8_0 => 0.000092,
+
+        // Not from the ggml repo.
+        GgmlDType::Q8K => 0.00065,
         _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
     };
     Ok(err)
@@ -508,17 +511,22 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
     T::VecDotType::from_float(&b, &mut b_quant)?;
 
     let result = T::vec_dot(length, &a_quant, &b_quant)?;
+    let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
     let reference_result = vec_dot_reference(&a, &b);
 
+    if (result - result_unopt).abs() / length as f32 > 1e-6 {
+        candle_core::bail!(
+            "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
+        )
+    }
+
     let error = (result - reference_result).abs() / length as f32;
 
     let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
 
-    if error > GGML_MAX_DOT_PRODUCT_ERROR {
+    if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
         candle_core::bail!(
-            "Dot product error {} exceeds max error {}",
-            error,
-            GGML_MAX_DOT_PRODUCT_ERROR
+            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
         );
     }
 
@@ -571,7 +579,7 @@ fn quantized_matmul_q2k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
     assert_eq!(mm.dims(), [m, n]);
@@ -597,7 +605,7 @@ fn quantized_matmul_q3k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
    assert_eq!(mm.dims(), [m, n]);
@@ -623,7 +631,7 @@ fn quantized_matmul_q4k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
    assert_eq!(mm.dims(), [m, n]);
@@ -649,7 +657,7 @@ fn quantized_matmul_q5k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
    assert_eq!(mm.dims(), [m, n]);
@@ -676,7 +684,7 @@ fn quantized_matmul_q6k() -> Result<()> {
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
     let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
-    let rhs = quantized::QMatMul::from_qtensor(rhs);
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
    assert_eq!(mm.dims(), [m, n]);
@@ -687,3 +695,28 @@ fn quantized_matmul_q6k() -> Result<()> {
     ggml_matmul_error_test::<BlockQ6K>()?;
     Ok(())
 }
+
+#[test]
+fn quantized_matmul_q8k() -> Result<()> {
+    use k_quants::BlockQ8K;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);
+
+    ggml_matmul_error_test::<BlockQ8K>()?;
+    Ok(())
+}
@ -1,4 +1,4 @@
|
||||
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ones(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
|
||||
[[1, 1, 1], [1, 1, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_mul(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
let dim1 = tensor.dims1()?;
|
||||
@ -44,6 +69,54 @@ fn clamp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
|
||||
[
|
||||
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
|
||||
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||
[
|
||||
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
|
||||
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
|
||||
[
|
||||
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.round()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
|
||||
);
|
||||
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
|
||||
[2997.92, 314.16]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
||||
[3000.0, 300.]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
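Note: scalar references for two of the ops exercised above, matching the asserted values; `gelu` is the tanh approximation while `gelu_erf` uses the exact erf form, and `round_to` accepts negative digit counts. This is a sketch of the expected semantics, not the tensor implementation:

```rust
// Tanh approximation of GELU, which is what `gelu` computes above; the
// `gelu_erf` values differ from these in roughly the 4th decimal.
fn gelu_tanh(x: f32) -> f32 {
    0.5 * x * (1.0 + ((2.0f32 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
}

// `round_to(2)` keeps two decimals; a negative count rounds to tens,
// hundreds, ... matching the `[3000.0, 300.]` expectation above.
fn round_to(x: f32, digits: i32) -> f32 {
    let mult = 10f32.powi(digits);
    (x * mult).round() / mult
}

fn main() {
    assert!((gelu_tanh(1.0) - 0.8412).abs() < 1e-3);
    assert!((round_to(2997.9246, 2) - 2997.92).abs() < 1e-2);
    assert!((round_to(314.15926, -2) - 300.0).abs() < 1e-3);
}
```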
fn binary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor1 = Tensor::new(data, device)?;
|
||||
@ -601,6 +674,30 @@ fn index_select(device: &Device) -> Result<()> {
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||
);
|
||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
]
|
||||
);
|
||||
|
||||
// Test when selecting dim > 0 with ids size different from elem count of
|
||||
// target dim in source/input.
|
||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||
let hs = t.index_select(&ids, 1)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
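Note: the expected semantics of `index_select` along dim 0, as a plain-Rust sketch; ids may repeat and each id picks a whole row (the pre-#1022 bug instead zeroed the trailing rows of the result):

```rust
fn index_select_dim0(rows: &[Vec<f32>], ids: &[u32]) -> Vec<Vec<f32>> {
    // One output row per id, in id order; repeated ids duplicate rows.
    ids.iter().map(|&i| rows[i as usize].clone()).collect()
}
```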
@ -647,6 +744,48 @@ fn index_add(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_scatter(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
[6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
[9.0, 10.0, 11.0]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 2.0],
|
||||
[3.0, 4.0, 5.0],
|
||||
[100.0, 101.0, 102.0],
|
||||
[103.0, 104.0, 105.0],
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
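Note: `slice_scatter0` as asserted above copies the rows of `src` into the target starting at the given row offset and leaves the other rows untouched. A plain-Rust sketch of the semantics:

```rust
fn slice_scatter0_ref(dst: &mut Vec<Vec<f32>>, src: &[Vec<f32>], offset: usize) {
    // Overwrite rows offset..offset+src.len(); everything else is unchanged.
    for (i, row) in src.iter().enumerate() {
        dst[offset + i] = row.clone();
    }
}
```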
fn scatter_add(device: &Device) -> Result<()> {
|
||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||
assert_eq!(
|
||||
@ -897,6 +1036,7 @@ fn randn(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||
test_device!(ones, ones_cpu, ones_gpu);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
||||
@ -908,6 +1048,7 @@ test_device!(max, max_cpu, max_gpu);
|
||||
test_device!(argmax, argmax_cpu, argmax_gpu);
|
||||
test_device!(argmin, argmin_cpu, argmin_gpu);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
@ -918,6 +1059,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||
test_device!(gather, gather_cpu, gather_gpu);
|
||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
||||
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
||||
test_device!(randn, randn_cpu, randn_gpu);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu);
|
||||
|
||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.3" }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true }
|
||||
@ -25,6 +25,7 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -35,7 +36,6 @@ imageproc = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rusttype = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
@ -51,7 +51,7 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
|
||||
|
@ -86,9 +86,8 @@ impl Args {
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
|
||||
let model = BertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
|
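Note: this hunk is the recurring migration in this compare: the three-step `MmapedFile::new` / `deserialize` / `from_safetensors` load is replaced by a single mmap-backed constructor. A minimal sketch of the new pattern, with a placeholder file name:

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn build_vb(device: &Device) -> Result<()> {
    let filenames = ["model.safetensors"]; // placeholder path
    // Safety: the safetensors files are memory-mapped and must not be
    // modified while the returned VarBuilder is in use.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, device)? };
    // `vb` is then handed to a model constructor, e.g. `BertModel::load(vb, &config)?`.
    let _ = vb;
    Ok(())
}
```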
@ -138,18 +138,9 @@ fn main() -> Result<()> {
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let config = Config::starcoder_1b();
|
||||
let model = GPTBigCode::load(vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
@ -42,9 +42,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = dinov2::vit_small(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
|
@ -68,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let cfg = match args.which {
|
||||
Which::B0 => MBConvConfig::b0(),
|
||||
Which::B1 => MBConvConfig::b1(),
|
||||
|
@ -177,21 +177,12 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let dtype = if args.use_f32 {
|
||||
DType::F32
|
||||
} else {
|
||||
DType::BF16
|
||||
};
|
||||
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let config = Config::falcon7b();
|
||||
config.validate()?;
|
||||
let model = Falcon::load(vb, config)?;
|
||||
|
@ -172,17 +172,9 @@ fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
}
|
||||
};
|
||||
|
@ -89,6 +89,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
@ -201,16 +205,9 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
|
||||
let vb = unsafe {
|
||||
candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
|
||||
};
|
||||
let llama = Llama::load(vb, &cache, &config, comm)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
@ -222,7 +219,7 @@ fn main() -> Result<()> {
|
||||
.to_vec();
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
|
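Note: `LogitsProcessor::new` gains a nucleus-sampling (`top_p`) argument in the hunk above. Usage in isolation, with illustrative values:

```rust
use candle::{Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn sample_next(logits: &Tensor) -> Result<u32> {
    // `None` for the third argument keeps the old behavior; `Some(0.9)`
    // samples from the smallest token set with cumulative probability > 0.9.
    let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.9));
    processor.sample(logits)
}
```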
90 candle-examples/examples/mistral/README.md Normal file
@ -0,0 +1,90 @@
|
||||
# candle-mistral: 7b LLM with Apache 2.0 licensed weights
|
||||
|
||||
Mistral-7B-v0.1 is a pretrained generative LLM with 7 billion parameters. It outperforms all the publicly available 13b models
|
||||
as of 2023-09-28. Weights (and the original Python model code) are released under the permissive Apache 2.0 license.
|
||||
|
||||
- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
|
||||
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
|
||||
HuggingFace Hub.
|
||||
This example supports the initial model as well as a quantized variant.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150
|
||||
|
||||
Generated text:
|
||||
Write helloworld code in Rust
|
||||
=============================
|
||||
|
||||
This is a simple example of how to write "Hello, world!" program in Rust.
|
||||
|
||||
## Compile and run
|
||||
|
||||
``bash
|
||||
$ cargo build --release
|
||||
Compiling hello-world v0.1.0 (/home/user/rust/hello-world)
|
||||
Finished release [optimized] target(s) in 0.26s
|
||||
$ ./target/release/hello-world
|
||||
Hello, world!
|
||||
``
|
||||
|
||||
## Source code
|
||||
|
||||
``rust
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
}
|
||||
``
|
||||
|
||||
## License
|
||||
|
||||
This example is released under the terms
|
||||
```
|
||||
|
||||
## Running the quantized version of the model
|
||||
|
||||
```bash
|
||||
$ cargo run --example mistral --features accelerate --release -- \
|
||||
    --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
|
||||
avx: false, neon: true, simd128: false, f16c: false
|
||||
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
|
||||
retrieved the files in 562.292µs
|
||||
loaded the model in 1.100323667s
|
||||
Here is a sample quick sort implementation in rust
|
||||
|
||||
``rust
|
||||
fn quick_sort(arr: &mut [i32]) {
|
||||
if arr.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
|
||||
let pivot = arr[0];
|
||||
let mut left = vec![];
|
||||
let mut right = vec![];
|
||||
|
||||
for i in 1..arr.len() {
|
||||
if arr[i] < pivot {
|
||||
left.push(arr[i]);
|
||||
} else {
|
||||
right.push(arr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
quick_sort(&mut left);
|
||||
quick_sort(&mut right);
|
||||
|
||||
let mut i = 0;
|
||||
for _ in &left {
|
||||
arr[i] = left.pop().unwrap();
|
||||
i += 1;
|
||||
}
|
||||
|
||||
for _ in &right {
|
||||
arr[i] = right.pop().unwrap();
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
``
|
||||
226 tokens generated (10.91 token/s)
|
||||
```
|
271 candle-examples/examples/mistral/main.rs Normal file
@ -0,0 +1,271 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::mistral::{Config, Model as Mistral};
|
||||
use candle_transformers::models::quantized_mistral::Model as QMistral;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
Mistral(Mistral),
|
||||
Quantized(QMistral),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("</s>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the </s> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::Mistral(m) => m.forward(&input, start_pos)?,
|
||||
Model::Quantized(m) => m.forward(&input, start_pos)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "lmz/candle-mistral")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
vec![repo.get("model-q4k.gguf")?]
|
||||
} else {
|
||||
vec![
|
||||
repo.get("pytorch_model-00001-of-00002.safetensors")?,
|
||||
repo.get("pytorch_model-00002-of-00002.safetensors")?,
|
||||
]
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let model = QMistral::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Mistral::new(&config, vb)?;
|
||||
(Model::Mistral(model), device)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
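Note: the generation loop above damps the logits of recently generated tokens via `apply_repeat_penalty`. A scalar sketch of that transform (the real helper operates on tensors):

```rust
fn apply_repeat_penalty_ref(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
    for &tok in recent_tokens {
        let l = &mut logits[tok as usize];
        // Positive logits are divided by the penalty, negative ones
        // multiplied, pushing repeated tokens towards lower probability.
        *l = if *l >= 0.0 { *l / penalty } else { *l * penalty };
    }
}
```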
@ -1,6 +1,6 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
|
||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||
#[derive(Debug)]
|
||||
struct EncodecLSTM {
|
||||
layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
|
||||
layers: Vec<candle_nn::LSTM>,
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for i in 0..cfg.num_lstm_layers {
|
||||
let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
|
||||
let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
|
||||
let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
|
||||
let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
|
||||
layers.push((w_hh, w_ih, b_hh, b_ih))
|
||||
for layer_idx in 0..cfg.num_lstm_layers {
|
||||
let config = candle_nn::LSTMConfig {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||
layers.push(lstm)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
impl Module for EncodecLSTM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
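Note: the EncodecLSTM rewrite above stacks `candle_nn::LSTM` layers and runs them through the `RNN` trait. A self-contained sketch of that API, using a fresh `VarMap` for the weights and illustrative shapes:

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{RNN, VarBuilder, VarMap};

fn run_lstm() -> Result<Tensor> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let lstm = candle_nn::lstm(16, 16, Default::default(), vb)?;
    let xs = Tensor::zeros((1, 10, 16), DType::F32, &dev)?; // (batch, seq, dim)
    // `seq` unrolls the recurrence over the time dimension; the hidden states
    // are then stacked back into a tensor.
    let states = lstm.seq(&xs)?;
    lstm.states_to_tensor(&states)
}
```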
@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
@ -299,7 +310,9 @@ impl EncodecConv1d {
|
||||
conv,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
@ -340,7 +353,9 @@ impl EncodecResnetBlock {
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
@ -439,8 +454,17 @@ impl EncodecEncoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?;
|
||||
for (resnets, conv) in self.sampling_layers.iter() {
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?;
|
||||
}
|
||||
xs = xs.elu(1.0)?.apply(conv)?;
|
||||
}
|
||||
xs.apply(&self.final_lstm)?
|
||||
.elu(1.0)?
|
||||
.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
@ -507,8 +531,15 @@ impl EncodecDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||
for (conv, resnets) in self.sampling_layers.iter() {
|
||||
xs = xs.elu(1.)?.apply(conv)?;
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?
|
||||
}
|
||||
}
|
||||
xs.elu(1.)?.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,9 +73,7 @@ fn main() -> Result<()> {
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
|
||||
let config = GenConfig::small();
|
||||
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
|
@ -40,7 +40,7 @@ impl Default for Config {
|
||||
num_attention_heads: 16,
|
||||
layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||
activation_function: Activation::Gelu,
|
||||
hidden_size: 1024,
|
||||
dropout: 0.1,
|
||||
attention_dropout: 0.0,
|
||||
@ -66,7 +66,7 @@ impl Config {
|
||||
num_attention_heads: 16,
|
||||
layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||
activation_function: Activation::Gelu,
|
||||
hidden_size: 1024,
|
||||
dropout: 0.1,
|
||||
attention_dropout: 0.0,
|
||||
|
43 candle-examples/examples/phi/README.md Normal file
@ -0,0 +1,43 @@
|
||||
# candle-phi: 1.3b LLM with state-of-the-art performance for <10b models.
|
||||
|
||||
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
|
||||
only 1.3 billion parameters but with state of the art performance compared to
|
||||
models with up to 10 billion parameters.
|
||||
|
||||
The candle implementation provides both the standard version as well as a
|
||||
quantized variant.
|
||||
|
||||
## Running some examples
|
||||
|
||||
```bash
|
||||
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
|
||||
|
||||
def print_prime(n):
|
||||
print("Printing prime numbers")
|
||||
for i in range(2, n+1):
|
||||
if is_prime(i):
|
||||
print(i)
|
||||
|
||||
def is_prime(n):
|
||||
if n <= 1:
|
||||
return False
|
||||
for i in range(2, int(math.sqrt(n))+1):
|
||||
if n % i == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
$ cargo run --example phi --release -- \
|
||||
--prompt "Explain how to find the median of an array and write the corresponding python function.\nAnswer:" \
|
||||
--quantized --sample-len 200
|
||||
|
||||
Explain how to find the median of an array and write the corresponding python function.
|
||||
Answer: The median is the middle value in an array. If the array has an even number of elements, the median is the average of the two middle values.
|
||||
|
||||
def median(arr):
|
||||
arr.sort()
|
||||
n = len(arr)
|
||||
if n % 2 == 0:
|
||||
return (arr[n//2 - 1] + arr[n//2]) / 2
|
||||
else:
|
||||
return arr[n//2]
|
||||
```
|
238 candle-examples/examples/phi/main.rs Normal file
@ -0,0 +1,238 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
MixFormer(MixFormer),
|
||||
Quantized(QMixFormer),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the endoftext token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::MixFormer(m) => m.forward(&input)?,
|
||||
Model::Quantized(m) => m.forward(&input)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "microsoft/phi-1_5")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "refs/pr/18")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filename = match args.weight_file {
|
||||
Some(weight_file) => std::path::PathBuf::from(weight_file),
|
||||
None => {
|
||||
if args.quantized {
|
||||
api.model("lmz/candle-quantized-phi".to_string())
|
||||
.get("model-q4k.gguf")?
|
||||
} else {
|
||||
repo.get("model.safetensors")?
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::v1_5();
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||
let model = QMixFormer::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
let model = MixFormer::new(&config, vb)?;
|
||||
(Model::MixFormer(model), device)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
42 candle-examples/examples/quantized-t5/README.md Normal file
@ -0,0 +1,42 @@
|
||||
# candle-quantized-t5
|
||||
|
||||
This example uses a quantized version of the t5 model.
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
|
||||
...
|
||||
Eine schöne Kerze.
|
||||
```
|
||||
|
||||
The weight file is automatically retrieved from the hub. It is also possible to
|
||||
generate quantized weight files from the original safetensors file by using the
|
||||
`tensor-tools` command line utility via:
|
||||
|
||||
```bash
|
||||
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||
```
|
||||
|
||||
To use a different model, specify the `model-id`. For example, you can use
|
||||
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
|
||||
|
||||
```bash
|
||||
$ cargo run --example quantized-t5 --release -- \
|
||||
--model-id "jbochi/candle-coedit-quantized" \
|
||||
--prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
|
||||
--temperature 0
|
||||
...
|
||||
Although their flight is weak, they run quickly through the tree canopy.
```
|
||||
|
||||
By default, it will look for `model.gguf` and `config.json`, but you can specify
|
||||
custom local or remote `weight-file` and `config-file`s:
|
||||
|
||||
```bash
|
||||
cargo run --example quantized-t5 --release -- \
|
||||
--model-id "jbochi/candle-coedit-quantized" \
|
||||
--weight-file "model-xl.gguf" \
|
||||
--config-file "config-xl.json" \
|
||||
--prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
|
||||
--temperature 0
|
||||
...
|
||||
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
|
||||
```
|
228 candle-examples/examples/quantized-t5/main.rs Normal file
@ -0,0 +1,228 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use candle_transformers::models::quantized_t5 as t5;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
T5Small,
|
||||
FlanT5Small,
|
||||
FlanT5Base,
|
||||
FlanT5Large,
|
||||
FlanT5Xl,
|
||||
FlanT5Xxl,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The model repository to use on the HuggingFace hub.
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
// Enable/disable the key-value cache during decoding.
|
||||
#[arg(long, default_value = "false")]
|
||||
disable_cache: bool,
|
||||
|
||||
/// The prompt to run the model on.
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "t5-small")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
struct T5ModelBuilder {
|
||||
device: Device,
|
||||
config: t5::Config,
|
||||
weights_filename: PathBuf,
|
||||
}
|
||||
|
||||
impl T5ModelBuilder {
|
||||
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
|
||||
let device = Device::Cpu;
|
||||
let default_model = "lmz/candle-quantized-t5".to_string();
|
||||
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, "main".to_string()),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config_filename = match &args.config_file {
|
||||
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
|
||||
None => match args.which {
|
||||
Which::T5Small => api.get("config.json")?,
|
||||
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
|
||||
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
|
||||
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
|
||||
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
|
||||
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
|
||||
},
|
||||
};
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let weights_filename = match &args.weight_file {
|
||||
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
|
||||
None => match args.which {
|
||||
Which::T5Small => api.get("model.gguf")?,
|
||||
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
|
||||
Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
|
||||
Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
|
||||
Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
|
||||
Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
|
||||
},
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: t5::Config = serde_json::from_str(&config)?;
|
||||
config.use_cache = !args.disable_cache;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
Ok((
|
||||
Self {
|
||||
device,
|
||||
config,
|
||||
weights_filename,
|
||||
},
|
||||
tokenizer,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
|
||||
let local_filename = std::path::PathBuf::from(filename);
|
||||
if local_filename.exists() {
|
||||
Ok(local_filename)
|
||||
} else {
|
||||
Ok(api.get(filename)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
|
||||
let device = &builder.device;
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mut model = builder.build_model()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
Some(args.temperature)
|
||||
};
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
|
||||
let encoder_output = model.encode(&input_token_ids)?;
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for index in 0.. {
|
||||
if output_token_ids.len() > 512 {
|
||||
break;
|
||||
}
|
||||
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
|
||||
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
|
||||
} else {
|
||||
let last_token = *output_token_ids.last().unwrap();
|
||||
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
||||
};
|
||||
let logits = model
|
||||
.decode(&decoder_token_ids, &encoder_output)?
|
||||
.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&output_token_ids[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token_id = logits_processor.sample(&logits)?;
|
||||
if next_token_id as usize == builder.config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
output_token_ids.push(next_token_id);
|
||||
if let Some(text) = tokenizer.id_to_token(next_token_id) {
|
||||
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
print!("{text}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start.elapsed();
|
||||
println!(
|
||||
"\n{} tokens generated ({:.2} token/s)\n",
|
||||
output_token_ids.len(),
|
||||
output_token_ids.len() as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
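Note: the T5 decode loop above feeds only the most recent token once the KV cache is enabled, and re-encodes the whole prefix otherwise. That branch in isolation (hypothetical helper name):

```rust
use candle::{Device, Result, Tensor};

fn decoder_input(output_token_ids: &[u32], index: usize, use_cache: bool, device: &Device) -> Result<Tensor> {
    if index == 0 || !use_cache {
        // First step (or no cache): feed the full prefix.
        Tensor::new(output_token_ids, device)?.unsqueeze(0)
    } else {
        // Cached steps: a single-token batch is enough.
        let last_token = *output_token_ids.last().unwrap();
        Tensor::new(&[last_token], device)?.unsqueeze(0)
    }
}
```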
@ -44,6 +44,27 @@ enum Which {
|
||||
L13bCode,
|
||||
#[value(name = "32b-code")]
|
||||
L34bCode,
|
||||
#[value(name = "7b-mistral")]
|
||||
Mistral7b,
|
||||
#[value(name = "7b-mistral-instruct")]
|
||||
Mistral7bInstruct,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn is_mistral(&self) -> bool {
|
||||
match self {
|
||||
Self::L7b
|
||||
| Self::L13b
|
||||
| Self::L70b
|
||||
| Self::L7bChat
|
||||
| Self::L13bChat
|
||||
| Self::L70bChat
|
||||
| Self::L7bCode
|
||||
| Self::L13bCode
|
||||
| Self::L34bCode => false,
|
||||
Self::Mistral7b | Self::Mistral7bInstruct => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -110,7 +131,12 @@ impl Args {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
|
||||
let repo = if self.which.is_mistral() {
|
||||
"mistralai/Mistral-7B-v0.1"
|
||||
} else {
|
||||
"hf-internal-testing/llama-tokenizer"
|
||||
};
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
@ -140,6 +166,14 @@ impl Args {
|
||||
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
|
||||
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
|
||||
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
|
||||
Which::Mistral7b => (
|
||||
"TheBloke/Mistral-7B-v0.1-GGUF",
|
||||
"mistral-7b-v0.1.Q4_K_S.gguf",
|
||||
),
|
||||
Which::Mistral7bInstruct => (
|
||||
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
||||
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
||||
),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(repo.to_string());
|
||||
@ -261,7 +295,7 @@ fn main() -> anyhow::Result<()> {
|
||||
| Which::L7bCode
|
||||
| Which::L13bCode
|
||||
| Which::L34bCode => 1,
|
||||
Which::L70b | Which::L70bChat => 8,
|
||||
Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
|
||||
};
|
||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||
}
|
||||
@ -291,7 +325,11 @@ fn main() -> anyhow::Result<()> {
|
||||
prompt.pop();
|
||||
}
|
||||
}
|
||||
prompt
|
||||
if args.which.is_mistral() {
|
||||
format!("[INST] {prompt} [/INST]")
|
||||
} else {
|
||||
prompt
|
||||
}
|
||||
}
|
||||
};
|
||||
print!("{}", &prompt_str);
|
||||
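Note: the hunk above wraps prompts for the instruct-tuned Mistral weights in `[INST]` tags. As a hypothetical standalone helper:

```rust
fn format_mistral_instruct(prompt: &str) -> String {
    // Instruct-tuned Mistral models expect this chat template.
    format!("[INST] {prompt} [/INST]")
}
```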
@ -327,6 +365,8 @@ fn main() -> anyhow::Result<()> {
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
|
||||
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
@ -345,6 +385,9 @@ fn main() -> anyhow::Result<()> {
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
|
@ -16,25 +16,29 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg
--use-tiny
--point-x 0.4
--point-y 0.3
--point 0.6,0.6 --point 0.6,0.55
```

Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dot represents the prompt
specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
image with a blue overlay of the selected mask. The red dots represent the prompt
specified by `--point 0.6,0.6 --point 0.6,0.55`; this prompt is assumed to be part
of the target mask.

The values used for `--point-x` and `--point-y` should be between 0 and 1 and
are proportional to the image dimension, i.e. use 0.5 for the image center.
The values used for `--point` should be a comma-delimited pair of float values.
They are proportional to the image dimension, i.e. use 0.5 for the image center.

Original image:


Segment results by prompting with a single point `--point 0.6,0.55`:


Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:


### Command-line flags
- `--use-tiny`: use the TinyViT-based MobileSAM backbone rather than the default
one.
- `--point-x`, `--point-y`: specify the location of the target point.
- `--point`: specifies the location of the target points.
- `--threshold`: sets the threshold value to be part of the mask, a negative
value results in a larger mask and can be specified via `--threshold=-1.2`.
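The example can also take negative points that mark background regions to exclude from the mask (the `neg_point` argument added in `main.rs` below; red dots mark positive points, green dots negative ones). A hypothetical invocation, assuming clap's default kebab-case flag naming (`--neg-point`):

```bash
cargo run --example segment-anything --release -- \
  --use-tiny \
  --point 0.6,0.55 \
  --neg-point 0.1,0.1
```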
Two binary image files added (158 KiB each; binary files not shown).
@@ -27,13 +27,15 @@ struct Args {
#[arg(long)]
generate_masks: bool,

/// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_x: f64,
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
/// should be part of the generated mask.
#[arg(long)]
point: Vec<String>,

/// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_y: f64,
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
/// should not be part of the generated mask and should be part of the background instead.
#[arg(long)]
neg_point: Vec<String>,

/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
@@ -82,9 +84,7 @@ pub fn main() -> anyhow::Result<()> {
api.get(filename)?
}
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let sam = if args.use_tiny {
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
@@ -113,15 +113,27 @@ pub fn main() -> anyhow::Result<()> {
)?;
}
} else {
let point = Some((args.point_x, args.point_y));
let iter_points = args.point.iter().map(|p| (p, true));
let iter_neg_points = args.neg_point.iter().map(|p| (p, false));
let points = iter_points
.chain(iter_neg_points)
.map(|(point, b)| {
use std::str::FromStr;
let xy = point.split(',').collect::<Vec<_>>();
if xy.len() != 2 {
anyhow::bail!("expected format for points is 0.4,0.2")
}
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
println!(
"mask generated in {:.2}s",
start_time.elapsed().as_secs_f32()
);
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
println!("iou_predictions: {iou_predictions}");

let mask = (mask.ge(args.threshold)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
@@ -153,12 +165,17 @@ pub fn main() -> anyhow::Result<()> {
}
}
}
let (x, y) = (
(args.point_x * img.width() as f64) as i32,
(args.point_y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
.save("sam_merged.jpg")?
for (x, y, b) in points {
let x = (x * img.width() as f64) as i32;
let y = (y * img.height() as f64) as i32;
let color = if b {
image::Rgba([255, 0, 0, 200])
} else {
image::Rgba([0, 255, 0, 200])
};
imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
}
img.save("sam_merged.jpg")?
}
Ok(())
}

@@ -97,14 +97,13 @@ struct Args {
img2img_strength: f64,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum)]
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
}

#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
@@ -204,7 +203,18 @@ impl ModelFile {
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
if version == StableDiffusionVersion::Xl && use_f16 {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
)
} else {
(version.repo(), version.vae_file(use_f16))
}
}
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
@@ -484,9 +494,8 @@ fn run(args: Args) -> Result<()> {
num_samples
);
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
}

candle-examples/examples/stable-lm/README.md (new file, 25 lines)
@@ -0,0 +1,25 @@
# candle-stable-lm

StableLM-3B-4E1T is a 3 billion parameter decoder-only language model
pre-trained on 1 trillion tokens of diverse English and code datasets for 4
epochs. See the [HuggingFace Hub Model
Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).

Note that this model is gated so you will have to request access on the Hub in
order to be able to use it.

## Running an example

```bash
$ cargo run --example stable-lm --release --features cuda -- --prompt 'What is the most efficient programming language in use?' --sample-len 150
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 126.593µs
loaded the model in 3.474148965s
What is the most efficient programming language in use?
The answer to this question depends on what you mean by "efficient". If you're talking about speed, then C++ and Java are probably your best bets. But if you're talking about ease of development, then Python is probably the way to go.
Python is a high-level, interpreted language that is easy to learn and use. It has a large community of developers who are always working on new features and improvements.
C++ is a low-level, compiled language that can be used for both desktop applications and web development. It's more difficult to learn than Python but offers greater control over the code.
Java is another high-level language that is popular with programmers because it runs on many different platforms (including Android phones
150 tokens generated (37.61 token/s)
```
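The change below also adds a `--quantized` flag to `main.rs` that fetches `model-q4k.gguf` from the same repo and runs on the CPU; a sketch of the invocation under that assumption:

```bash
$ cargo run --example stable-lm --release -- --prompt 'What is the most efficient programming language in use?' --quantized
```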
candle-examples/examples/stable-lm/main.rs (new file, 268 lines)
@@ -0,0 +1,268 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
StableLM(StableLM),
Quantized(QStableLM),
}

struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}

impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}

fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;

let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::StableLM(m) => m.forward(&input, start_pos)?,
Model::Quantized(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};

let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

#[arg(long)]
use_flash_attn: bool,

#[arg(long)]
prompt: String,

/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,

#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
model_id: String,

#[arg(long, default_value = "main")]
revision: String,

#[arg(long)]
tokenizer_file: Option<String>,

#[arg(long)]
weight_files: Option<String>,

#[arg(long)]
quantized: bool,

/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,

/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}

fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);

let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![repo.get("model.safetensors")?]
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = StableLM::new(&config, vb)?;
(Model::StableLM(model), device)
};

println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
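As a side note on the repeat penalty used in the generation loop above: a minimal standalone sketch of the conventional formulation, which shrinks the scores of recently seen tokens (the actual `candle_transformers::utils::apply_repeat_penalty` operates on tensors and may differ in details):

```rust
// Hypothetical helper operating on a plain logits slice rather than a Tensor.
fn apply_repeat_penalty_sketch(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
    for &token in recent_tokens {
        if let Some(logit) = logits.get_mut(token as usize) {
            // Push the logit towards "less likely": divide when positive,
            // multiply when negative, so the penalty always lowers the score.
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
```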
@@ -8,12 +8,12 @@ use std::path::PathBuf;

use candle_transformers::models::t5;

use anyhow::{anyhow, Error as E, Result};
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

const DTYPE: DType = DType::F32;
@@ -25,10 +25,6 @@ struct Args {
#[arg(long)]
cpu: bool,

/// Run offline (you must have the files already cached)
#[arg(long)]
offline: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -46,12 +42,16 @@ struct Args {

// Enable/disable decoding.
#[arg(long, default_value = "false")]
use_cache: bool,
disable_cache: bool,

/// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: Option<String>,

/// If set along with --decode, will use this prompt to initialize the decoder.
#[arg(long)]
decoder_prompt: Option<String>,

/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
@@ -76,7 +76,7 @@ struct Args {
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: PathBuf,
weights_filename: Vec<PathBuf>,
}

impl T5ModelBuilder {
@@ -91,32 +91,25 @@ impl T5ModelBuilder {
(None, None) => (default_model, default_revision),
};

let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
let cache = Cache::default().repo(repo);
(
cache
.get("config.json")
.ok_or(anyhow!("Missing config file in cache"))?,
cache
.get("tokenizer.json")
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
cache
.get("model.safetensors")
.ok_or(anyhow!("Missing weights file in cache"))?,
)
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" {
vec![
api.get("model-00001-of-00005.safetensors")?,
api.get("model-00002-of-00005.safetensors")?,
api.get("model-00003-of-00005.safetensors")?,
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else {
let api = Api::new()?;
let api = api.repo(repo);
(
api.get("config.json")?,
api.get("tokenizer.json")?,
api.get("model.safetensors")?,
)
vec![api.get("model.safetensors")?]
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
config.use_cache = args.use_cache;
config.use_cache = !args.disable_cache;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
@@ -129,24 +122,35 @@ impl T5ModelBuilder {
}

pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let weights =
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}

pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let weights =
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}

fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();

let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};

let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
@@ -170,6 +174,16 @@ fn main() -> Result<()> {
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(
tokenizer
.encode(decoder_prompt.to_string(), false)
.map_err(E::msg)?
.get_ids()
.to_vec(),
);
}
let temperature = if args.temperature <= 0. {
None
} else {
@@ -195,11 +209,11 @@ fn main() -> Result<()> {
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
&output_token_ids[start_at..],
)?
};

@@ -217,8 +231,8 @@ fn main() -> Result<()> {
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
tokens.len(),
tokens.len() as f64 / dt.as_secs_f64(),
output_token_ids.len(),
output_token_ids.len() as f64 / dt.as_secs_f64(),
);
}
}

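For reference, a hypothetical invocation exercising the new `--decoder-prompt` option (flag names assume clap's default kebab-case mapping; `--decode` is the pre-existing decoding switch referenced in the doc comment above):

```bash
cargo run --example t5 --release -- \
  --model-id google/flan-t5-small \
  --prompt "translate English to German: The house is wonderful." \
  --decode --decoder-prompt "Das"
```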
@@ -18,8 +18,48 @@ use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, model};
use model::{Config, Whisper};
use candle_transformers::models::whisper::{self as m, audio, Config};

pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}

// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}

pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}

pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}

pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}

#[allow(dead_code)]
#[derive(Debug, Clone)]
@@ -41,7 +81,7 @@ struct Segment {
}

struct Decoder {
model: Whisper,
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
@@ -60,7 +100,7 @@ struct Decoder {
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Whisper,
model: Model,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
@@ -72,9 +112,9 @@ impl Decoder {
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if model.config.suppress_tokens.contains(&i)
if model.config().suppress_tokens.contains(&i)
|| timestamps && i == no_timestamps_token
{
f32::NEG_INFINITY
@@ -109,11 +149,11 @@ impl Decoder {

fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
let audio_features = model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config.max_target_positions / 2;
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
@@ -133,12 +173,12 @@ impl Decoder {
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;

// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the corresponding token.
if i == 0 {
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
@@ -146,8 +186,7 @@ impl Decoder {

let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder
.final_linear(&ys.i((..1, seq_len - 1..))?)?
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
@@ -176,7 +215,7 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
@@ -333,6 +372,7 @@ impl WhichModel {
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}

fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
@@ -382,6 +422,9 @@ struct Args {
#[arg(long)]
tracing: bool,

#[arg(long)]
quantized: bool,

/// Language.
#[arg(long)]
language: Option<String>,
@@ -413,10 +456,13 @@ fn main() -> Result<()> {
None
};
let device = candle_examples::device(args.cpu)?;
let (default_model, default_revision) = args.model.model_and_revision();
let (default_model, default_revision) = if args.quantized {
("lmz/candle-whisper", "main")
} else {
args.model.model_and_revision()
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let path = std::path::PathBuf::from(default_model.clone());
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
@@ -424,20 +470,7 @@ fn main() -> Result<()> {
(None, None) => (default_model, default_revision),
};

let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
let mut config_filename = path.clone();
config_filename.push("config.json");
let mut tokenizer_filename = path.clone();
tokenizer_filename.push("tokenizer.json");
let mut model_filename = path;
model_filename.push("model.safetensors");
(
config_filename,
tokenizer_filename,
model_filename,
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
)
} else {
let (config_filename, tokenizer_filename, weights_filename, input) = {
let api = Api::new()?;
let dataset = api.dataset("Narsil/candle-examples".to_string());
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@@ -451,12 +484,25 @@ fn main() -> Result<()> {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
dataset.get("samples_jfk.wav")?
};
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
sample,
)
let (config, tokenizer, model) = if args.quantized {
let ext = match args.model {
WhichModel::TinyEn => "tiny-en",
WhichModel::Tiny => "tiny",
_ => unimplemented!("no quantized support for {:?}", args.model),
};
(
repo.get(&format!("config-{ext}.json"))?,
repo.get(&format!("tokenizer-{ext}.json"))?,
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
};
(config, tokenizer, model, sample)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

@@ -481,11 +527,16 @@ fn main() -> Result<()> {
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());

let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config)?)
};

let language_token = match (args.model.is_multilingual(), args.language) {
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),

@@ -1,4 +1,3 @@
use crate::Whisper;
use candle::{IndexOp, Result, Tensor, D};
use tokenizers::Tokenizer;

@@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
];

/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
pub fn detect_language(
model: &mut super::Model,
tokenizer: &Tokenizer,
mel: &Tensor,
) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;

candle-examples/examples/wuerstchen/README.md (new file, 27 lines)
@@ -0,0 +1,27 @@
# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models

![anthropomorphic cat dressed as a fire fighter](./assets/cat.jpg)

The `wuerstchen` example is a port of the [diffusers
implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
The candle implementation reproduces the same structure/files for models and
pipelines. Useful resources:

- [Official implementation](https://github.com/dome272/Wuerstchen).
- [Arxiv paper](https://arxiv.org/abs/2306.00637).
- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).

## Getting the weights

The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command-line
flags to use local files instead; run with `--help` to learn about them.

## Running an example

```bash
cargo run --example wuerstchen --release --features cuda,cudnn -- \
--prompt "Anthropomorphic cat dressed as a fire fighter"
```

The final image is named `sd_final.png` by default.
BIN candle-examples/examples/wuerstchen/assets/cat.jpg (new file, 38 KiB; binary file not shown)
@@ -1,5 +1,3 @@
#![allow(unused)]

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

@@ -10,11 +8,11 @@ use candle_transformers::models::stable_diffusion;
use candle_transformers::models::wuerstchen;

use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use candle::{DType, Device, IndexOp, Tensor};
use clap::Parser;
use tokenizers::Tokenizer;

const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
const RESOLUTION_MULTIPLE: f64 = 42.67;
const LATENT_DIM_SCALE: f64 = 10.67;
const PRIOR_CIN: usize = 16;
@@ -41,6 +39,9 @@ struct Args {
#[arg(long)]
tracing: bool,

#[arg(long)]
use_flash_attn: bool,

/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
@@ -77,14 +78,6 @@ struct Args {
/// The file specifying the tokenizer to be used for prior tokenization.
prior_tokenizer: Option<String>,

/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
sliced_attention_size: Option<usize>,

/// The number of steps to run the diffusion for.
#[arg(long, default_value_t = 30)]
n_steps: usize,

/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
num_samples: i64,
@@ -217,10 +210,8 @@ fn run(args: Args) -> Result<()> {
cpu,
height,
width,
n_steps,
tokenizer,
final_image,
sliced_attention_size,
num_samples,
clip_weights,
prior_weights,
@@ -284,23 +275,27 @@ fn run(args: Args) -> Result<()> {
)?;

let prior = {
let prior_weights = ModelFile::Prior.get(prior_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let file = ModelFile::Prior.get(prior_weights)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
};
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
/* c_in */ PRIOR_CIN,
/* c */ 1536,
/* c_cond */ 1280,
/* c_r */ 64,
/* depth */ 32,
/* nhead */ 24,
args.use_flash_attn,
vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
println!("prior denoising");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
@@ -317,10 +312,10 @@ fn run(args: Args) -> Result<()> {

println!("Building the vqgan.");
let vqgan = {
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let file = ModelFile::VqGan.get(vqgan_weights)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
};
wuerstchen::paella_vq::PaellaVQ::new(vb)?
};

@@ -328,10 +323,10 @@ fn run(args: Args) -> Result<()> {

// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
let decoder = {
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let file = ModelFile::Decoder.get(decoder_weights)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
};
wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN,
@@ -339,6 +334,7 @@ fn run(args: Args) -> Result<()> {
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
args.use_flash_attn,
vb,
)?
};
@@ -356,13 +352,11 @@ fn run(args: Args) -> Result<()> {
)?;

println!("diffusion process with prior {image_embeddings:?}");
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
let timesteps = scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
let noise_pred =
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
@@ -376,9 +370,9 @@ fn run(args: Args) -> Result<()> {
num_samples
);
let image = vqgan.decode(&(&latents * 0.3764)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image = (image.clamp(0f32, 1f32)? * 255.)?
.to_dtype(DType::U8)?
.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
}

@@ -146,9 +146,7 @@ pub fn main() -> Result<()> {

// Create the model and load the weights from the file.
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? };
let config = args.config()?;
let darknet = darknet::parse_config(config)?;
let model = darknet.build_model(vb)?;

|
||||
mod model;
|
||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||
use clap::{Parser, ValueEnum};
|
||||
@ -253,6 +253,14 @@ enum YoloTask {
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Model weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
|
||||
}
|
||||
|
||||
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
// Create the model and load the weights from the file.
|
||||
let multiples = match args.which {
|
||||
Which::N => Multiples::n(),
|
||||
@ -372,9 +381,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
Which::X => Multiples::x(),
|
||||
};
|
||||
let model = args.model()?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let model = T::load(vb, multiples)?;
|
||||
println!("model loaded");
|
||||
for image_name in args.images.iter() {
|
||||
@ -405,7 +412,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
Tensor::from_vec(
|
||||
data,
|
||||
(img.height() as usize, img.width() as usize, 3),
|
||||
&Device::Cpu,
|
||||
&device,
|
||||
)?
|
||||
.permute((2, 0, 1))?
|
||||
};
|
||||
@ -430,7 +437,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
match args.task {
|
||||
YoloTask::Detect => run::<YoloV8>(args)?,
|
||||
YoloTask::Pose => run::<YoloV8Pose>(args)?,
|
||||
|
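Both entry points now honor the `--tracing` flag; a hypothetical run (the image path is borrowed from the assets referenced earlier) that writes a `trace-timestamp.json` which can then be loaded into a Chrome-trace viewer such as `chrome://tracing` or Perfetto:

```bash
cargo run --example yolo-v8 --release -- --tracing \
  candle-examples/examples/yolo-v8/assets/bike.jpg
```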
@@ -77,6 +77,7 @@ impl Module for Upsample {
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
span: tracing::Span,
}

impl ConvBlock {
@@ -97,12 +98,17 @@ impl ConvBlock {
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
Ok(Self { conv, bn })
Ok(Self {
conv,
bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
}
}

impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
@@ -114,6 +120,7 @@ struct Bottleneck {
cv1: ConvBlock,
cv2: ConvBlock,
residual: bool,
span: tracing::Span,
}

impl Bottleneck {
@@ -123,12 +130,18 @@ impl Bottleneck {
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
let residual = c1 == c2 && shortcut;
Ok(Self { cv1, cv2, residual })
Ok(Self {
cv1,
cv2,
residual,
span: tracing::span!(tracing::Level::TRACE, "bottleneck"),
})
}
}

impl Module for Bottleneck {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
if self.residual {
xs + ys
@@ -143,6 +156,7 @@ struct C2f {
cv1: ConvBlock,
cv2: ConvBlock,
bottleneck: Vec<Bottleneck>,
span: tracing::Span,
}

impl C2f {
@@ -159,12 +173,14 @@ impl C2f {
cv1,
cv2,
bottleneck,
span: tracing::span!(tracing::Level::TRACE, "c2f"),
})
}
}

impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = self.cv1.forward(xs)?;
let mut ys = ys.chunk(2, 1)?;
for m in self.bottleneck.iter() {
@@ -180,6 +196,7 @@ struct Sppf {
cv1: ConvBlock,
cv2: ConvBlock,
k: usize,
span: tracing::Span,
}

impl Sppf {
@@ -187,12 +204,18 @@ impl Sppf {
let c_ = c1 / 2;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
Ok(Self { cv1, cv2, k })
Ok(Self {
cv1,
cv2,
k,
span: tracing::span!(tracing::Level::TRACE, "sppf"),
})
}
}

impl Module for Sppf {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_, _, _, _) = xs.dims4()?;
let xs = self.cv1.forward(xs)?;
let xs2 = xs
@@ -215,17 +238,23 @@ impl Module for Sppf {
struct Dfl {
conv: Conv2d,
num_classes: usize,
span: tracing::Span,
}

impl Dfl {
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
Ok(Self { conv, num_classes })
Ok(Self {
conv,
num_classes,
span: tracing::span!(tracing::Level::TRACE, "dfl"),
})
}
}

impl Module for Dfl {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, _channels, anchors) = xs.dims3()?;
let xs = xs
.reshape((b_sz, 4, self.num_classes, anchors))?
@@ -247,6 +276,7 @@ struct DarkNet {
b4_0: ConvBlock,
b4_1: C2f,
b5: Sppf,
span: tracing::Span,
}

impl DarkNet {
@@ -330,10 +360,12 @@ impl DarkNet {
b4_0,
b4_1,
b5,
span: tracing::span!(tracing::Level::TRACE, "darknet"),
})
}

fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let _enter = self.span.enter();
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self
.b2_2
@@ -354,6 +386,7 @@ struct YoloV8Neck {
n4: C2f,
n5: ConvBlock,
n6: C2f,
span: tracing::Span,
}

impl YoloV8Neck {
@@ -413,10 +446,12 @@ impl YoloV8Neck {
n4,
n5,
n6,
span: tracing::span!(tracing::Level::TRACE, "neck"),
})
}

fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let _enter = self.span.enter();
let x = self
.n1
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
@@ -440,6 +475,7 @@ struct DetectionHead {
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
ch: usize,
no: usize,
span: tracing::Span,
}

#[derive(Debug)]
@@ -447,6 +483,7 @@ struct PoseHead {
detect: DetectionHead,
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
kpt: (usize, usize),
span: tracing::Span,
}

fn make_anchors(
@@ -519,6 +556,7 @@ impl DetectionHead {
cv3,
ch,
no,
span: tracing::span!(tracing::Level::TRACE, "detection-head"),
})
}

@@ -547,6 +585,7 @@ impl DetectionHead {
}

fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
let _enter = self.span.enter();
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@@ -606,7 +645,12 @@ impl PoseHead {
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
];
Ok(Self { detect, cv4, kpt })
Ok(Self {
detect,
cv4,
kpt,
span: tracing::span!(tracing::Level::TRACE, "pose-head"),
})
}

fn load_cv4(
@@ -622,6 +666,7 @@ impl PoseHead {
}

fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let d = self.detect.forward(xs0, xs1, xs2)?;
let forward_cv = |xs: &Tensor, i: usize| {
let (b_sz, _, h, w) = xs.dims4()?;
@@ -650,6 +695,7 @@ pub struct YoloV8 {
net: DarkNet,
fpn: YoloV8Neck,
head: DetectionHead,
span: tracing::Span,
}

impl YoloV8 {
@@ -657,12 +703,18 @@ impl YoloV8 {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
Ok(Self { net, fpn, head })
Ok(Self {
net,
fpn,
head,
span: tracing::span!(tracing::Level::TRACE, "yolo-v8"),
})
}
}

impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
@@ -674,6 +726,7 @@ pub struct YoloV8Pose {
net: DarkNet,
fpn: YoloV8Neck,
head: PoseHead,
span: tracing::Span,
}

impl YoloV8Pose {
@@ -686,12 +739,18 @@ impl YoloV8Pose {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
Ok(Self { net, fpn, head })
Ok(Self {
net,
fpn,
head,
span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"),
})
}
}

impl Module for YoloV8Pose {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
self.head.forward(&xs1, &xs2, &xs3)

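The same instrumentation pattern repeats through every block above; a minimal standalone sketch of it, with illustrative names (the span is created once at load time, then entered for the duration of each forward pass so the work shows up as a named slice in the trace):

```rust
// Sketch only: `Layer` and "layer" are illustrative names.
struct Layer {
    span: tracing::Span,
}

impl Layer {
    fn new() -> Self {
        Self {
            // Creating the span once avoids rebuilding its metadata per call.
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        }
    }

    fn forward(&self) {
        // The guard keeps the span entered until it is dropped at the end of
        // the function, so everything below is attributed to "layer".
        let _enter = self.span.enter();
        // ... actual computation would go here ...
    }
}
```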
@@ -1,5 +1,6 @@
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;

use candle::{Device, Result, Tensor};

candle-examples/src/token_output_stream.rs (new file, 86 lines)
@@ -0,0 +1,86 @@
use candle::Result;
|
||||
|
||||
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
|
||||
/// streaming way rather than having to wait for the full decoding.
|
||||
pub struct TokenOutputStream {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
tokens: Vec<u32>,
|
||||
prev_index: usize,
|
||||
current_index: usize,
|
||||
}
|
||||
|
||||
impl TokenOutputStream {
|
||||
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
tokens: Vec::new(),
|
||||
prev_index: 0,
|
||||
current_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> tokenizers::Tokenizer {
|
||||
self.tokenizer
|
||||
}
|
||||
|
||||
fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||
match self.tokenizer.decode(tokens, true) {
|
||||
Ok(str) => Ok(str),
|
||||
Err(err) => candle::bail!("cannot decode: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
|
||||
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
self.tokens.push(token);
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
self.prev_index = self.current_index;
|
||||
self.current_index = self.tokens.len();
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_rest(&self) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_all(&self) -> Result<String> {
|
||||
self.decode(&self.tokens)
|
||||
}
|
||||
|
||||
pub fn get_token(&self, token_s: &str) -> Option<u32> {
|
||||
self.tokenizer.get_vocab(true).get(token_s).copied()
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.tokens.clear();
|
||||
self.prev_index = 0;
|
||||
self.current_index = 0;
|
||||
}
|
||||
}
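A minimal sketch of how this wrapper is meant to be driven from a generation loop (the tokenizer value and the `sampled` token ids are placeholders):

    use candle_examples::token_output_stream::TokenOutputStream;

    fn stream_decode(tokenizer: tokenizers::Tokenizer, sampled: &[u32]) -> candle::Result<String> {
        let mut tos = TokenOutputStream::new(tokenizer);
        let mut text = String::new();
        for &token in sampled {
            // Text is only emitted once the pending tokens decode to a chunk
            // ending in an ASCII character, so multi-byte sequences that span
            // several tokens are never cut in half.
            if let Some(chunk) = tos.next_token(token)? {
                text.push_str(&chunk);
            }
        }
        // Flush whatever is still buffered when generation stops.
        if let Some(rest) = tos.decode_rest()? {
            text.push_str(&rest);
        }
        Ok(text)
    }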
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.2.3"
+version = "0.3.0"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.2.3", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]
@@ -21,4 +21,4 @@ rayon = "1.7.0"

 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.2.3", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.2.3"
+version = "0.3.0"
 edition = "2021"

 description = "CUDA kernels for Candle"
@@ -77,20 +77,30 @@ CAST_OP(double, __half, cast_f64_f16)

 CAST_OP(uint32_t, uint32_t, cast_u32_u32)
 CAST_OP(uint32_t, uint8_t, cast_u32_u8 )
+CAST_OP(uint32_t, int64_t, cast_u32_i64 )
 CAST_OP(uint32_t, float, cast_u32_f32)
 CAST_OP(uint32_t, double, cast_u32_f64)

 CAST_OP(uint8_t, uint32_t, cast_u8_u32)
 CAST_OP(uint8_t, uint8_t, cast_u8_u8 )
+CAST_OP(uint8_t, int64_t, cast_u8_i64 )
 CAST_OP(uint8_t, float, cast_u8_f32)
 CAST_OP(uint8_t, double, cast_u8_f64)

+CAST_OP(int64_t, uint32_t, cast_i64_u32)
+CAST_OP(int64_t, uint8_t, cast_i64_u8 )
+CAST_OP(int64_t, int64_t, cast_i64_i64 )
+CAST_OP(int64_t, float, cast_i64_f32)
+CAST_OP(int64_t, double, cast_i64_f64)
+
 CAST_OP(float, uint8_t, cast_f32_u8 )
 CAST_OP(float, uint32_t, cast_f32_u32)
+CAST_OP(float, int64_t, cast_f32_i64 )
 CAST_OP(float, float, cast_f32_f32)
 CAST_OP(float, double, cast_f32_f64)

 CAST_OP(double, uint8_t, cast_f64_u8 )
 CAST_OP(double, uint32_t, cast_f64_u32)
+CAST_OP(double, int64_t, cast_f64_i64 )
 CAST_OP(double, float, cast_f64_f32)
 CAST_OP(double, double, cast_f64_f64)
@@ -129,6 +129,16 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
 __device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
 __device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
 __device__ __forceinline__ double tanhg(double a) { return tanh(a); }
+__device__ __forceinline__ float erfg(float a) { return erff(a); }
+__device__ __forceinline__ double erfg(double a) { return erf(a); }
+__device__ __forceinline__ float ceilg(float a) { return ceilf(a); }
+__device__ __forceinline__ double ceilg(double a) { return ceil(a); }
+__device__ __forceinline__ float floorg(float a) { return floorf(a); }
+__device__ __forceinline__ double floorg(double a) { return floor(a); }
+__device__ __forceinline__ float roundg(float a) { return roundf(a); }
+__device__ __forceinline__ double roundg(double a) { return round(a); }
+__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); }
+__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); }
 __device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
 __device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
 __device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
@@ -157,6 +167,11 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); }
 __device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
 __device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
 __device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
+__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
+__device__ __forceinline__ __half ceilg(__half a) { return __float2half(ceilf(__half2float(a))); }
+__device__ __forceinline__ __half floorg(__half a) { return __float2half(floorf(__half2float(a))); }
+__device__ __forceinline__ __half roundg(__half a) { return __float2half(roundf(__half2float(a))); }
+__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
 __device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
 __device__ __forceinline__ __half logg(__half a) { return hlog(a); }
 __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
@@ -173,6 +188,11 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); }
 __device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
 __device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
 __device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 ceilg(__nv_bfloat16 a) { return __float2bfloat16(ceilf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 floorg(__nv_bfloat16 a) { return __float2bfloat16(floorf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 roundg(__nv_bfloat16 a) { return __float2bfloat16(roundf(__bfloat162float(a))); }
+__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
 __device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
 __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
 __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
@@ -1,3 +1,4 @@
 #include<stdint.h>
+#include "cuda_fp16.h"

 template<typename T>
@@ -6,6 +7,14 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
     buf[i] = value;
   }
 }
+extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+
+#if __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
+#endif
@@ -12,25 +12,20 @@ __device__ void index_select(
     const T *inp,
     T *out,
     const size_t left_size,
-    const size_t dim_size,
+    const size_t src_dim_size,
+    const size_t ids_dim_size,
     const size_t right_size
 ) {
     const size_t *dims = info;
     const size_t *strides = info + num_dims;
-    if (is_contiguous(num_dims, dims, strides)) {
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-            for (unsigned int j = 0; j < left_size; ++j) {
-                memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[i]) * right_size], right_size * sizeof(T));
-            }
-        }
-    }
-    else {
-        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
-            for (unsigned int j = 0; j < left_size; ++j) {
-                memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[strided_i]) * right_size], right_size * sizeof(T));
-            }
-        }
-    }
+    bool b = is_contiguous(num_dims, dims, strides);
+    for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) {
+        unsigned int left_i = dst_i / (ids_dim_size * right_size);
+        unsigned int id_i = dst_i / right_size % ids_dim_size;
+        unsigned int right_i = dst_i % right_size;
+        unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
+        unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
+        out[dst_i] = inp[strided_i];
+    }
 }

@@ -43,9 +38,10 @@ extern "C" __global__ void FN_NAME( \
     const TYPENAME *inp, \
     TYPENAME *out, \
     const size_t left_size, \
-    const size_t dim_size, \
+    const size_t src_dim_size, \
+    const size_t ids_dim_size, \
     const size_t right_size \
-) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \
+) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \

 template<typename T, typename I>
 __device__ void gather(
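The rewrite also changes the parallelization: instead of looping over `left_size` inside each thread and copying `right_size` elements with `memcpy`, every thread now writes exactly one output element. A CPU mirror of the new index arithmetic for the contiguous case (illustrative only, not part of the kernels):

    fn index_select_ref(inp: &[f32], ids: &[usize], left: usize, src_dim: usize, right: usize) -> Vec<f32> {
        let mut out = vec![0f32; left * ids.len() * right];
        for dst_i in 0..out.len() {
            // Same decomposition as the kernel: dst = (left_i, id_i, right_i),
            // src = (left_i, ids[id_i], right_i).
            let left_i = dst_i / (ids.len() * right);
            let id_i = dst_i / right % ids.len();
            let right_i = dst_i % right;
            out[dst_i] = inp[left_i * (src_dim * right) + ids[id_i] * right + right_i];
        }
        out
    }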
@@ -28,6 +28,11 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \

+template<typename T>
+__device__ __forceinline__ T gelu_erf_fwd(T x) {
+    return x * normcdfg(x);
+}
+
 template<typename T>
 __device__ __forceinline__ T gelu_fwd(T x) {
     T x_sq = x * x;
@@ -86,10 +91,16 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
 UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
 UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
 UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
+UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
+UNARY_OP(__nv_bfloat16, uceil_bf16, ceilg(x))
+UNARY_OP(__nv_bfloat16, ufloor_bf16, floorg(x))
+UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x))
+UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
 UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
 UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
 UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
 UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
+UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
 UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
 UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
 UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
@@ -104,10 +115,16 @@ UNARY_OP(__half, ulog_f16, logg(x))
 UNARY_OP(__half, usin_f16, sing(x))
 UNARY_OP(__half, ucos_f16, cosg(x))
 UNARY_OP(__half, utanh_f16, tanhg(x))
+UNARY_OP(__half, uerf_f16, erfg(x))
+UNARY_OP(__half, uceil_f16, ceilg(x))
+UNARY_OP(__half, ufloor_f16, floorg(x))
+UNARY_OP(__half, uround_f16, roundg(x))
+UNARY_OP(__half, unormcdf_f16, normcdfg(x))
 UNARY_OP(__half, uabs_f16, absg(x))
 UNARY_OP(__half, usqr_f16, x*x)
 UNARY_OP(__half, usqrt_f16, sqrtg(x))
 UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
+UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
 UNARY_OP(__half, urelu_f16, relu_fwd(x))
 UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
 UNARY_OP1(__half, upowf_f16, powg(x, param))
@@ -115,6 +132,7 @@ UNARY_OP1(__half, upowf_f16, powg(x, param))

 UNARY_OP(uint8_t, ucopy_u8, x)
 UNARY_OP(uint32_t, ucopy_u32, x)
+UNARY_OP(int64_t, ucopy_i64, x)
 UNARY_OP(float, ucopy_f32, x)
 UNARY_OP(double, ucopy_f64, x)
 UNARY_OP(float, uneg_f32, -x)
@@ -131,6 +149,16 @@ UNARY_OP(float, ucos_f32, cosg(x))
 UNARY_OP(double, ucos_f64, cosg(x))
 UNARY_OP(float, utanh_f32, tanhg(x))
 UNARY_OP(double, utanh_f64, tanhg(x))
+UNARY_OP(float, uerf_f32, erfg(x))
+UNARY_OP(double, uerf_f64, erfg(x))
+UNARY_OP(float, uceil_f32, ceilg(x))
+UNARY_OP(double, uceil_f64, ceilg(x))
+UNARY_OP(float, ufloor_f32, floorg(x))
+UNARY_OP(double, ufloor_f64, floorg(x))
+UNARY_OP(float, uround_f32, roundg(x))
+UNARY_OP(double, uround_f64, roundg(x))
+UNARY_OP(float, unormcdf_f32, normcdfg(x))
+UNARY_OP(double, unormcdf_f64, normcdfg(x))
 UNARY_OP(float, uabs_f32, absg(x))
 UNARY_OP(double, uabs_f64, absg(x))
 UNARY_OP(float, usqr_f32, x*x)
@@ -139,6 +167,8 @@ UNARY_OP(float, usqrt_f32, sqrtg(x))
 UNARY_OP(double, usqrt_f64, sqrtg(x))
 UNARY_OP(float, ugelu_f32, gelu_fwd(x))
 UNARY_OP(double, ugelu_f64, gelu_fwd(x))
+UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x))
+UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x))
 UNARY_OP(float, urelu_f32, relu_fwd(x))
 UNARY_OP(double, urelu_f64, relu_fwd(x))
 UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
@@ -11,7 +11,7 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 half = { workspace = true }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
@@ -185,8 +185,8 @@ impl Benchmark for Matmul {
     type PreProcessData = (Tensor, Tensor);
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
-        let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
-        let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+        let lhs = Tensor::randn(0f32, 1., (1024 * 4, 1024 * 4), &Device::Cpu)?;
+        let rhs = Tensor::randn(0f32, 1., (1024 * 4, 1), &Device::Cpu)?;
         Ok((lhs, rhs))
     }

@@ -206,7 +206,7 @@ impl Benchmark for QMatMul {
     fn preprocess() -> Result<Self::PreProcessData> {
         let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
         let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
-        let mm = candle::quantized::QMatMul::from_qtensor(mm);
+        let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
         let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
         Ok((mm, arg))
    }
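Note that this turns the matmul benchmark into a skinny matrix-vector product: a (4096, 4096) matrix times a (4096, 1) column instead of the previous square (1024, 1024) multiply, which exercises a rather different BLAS code path.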
@@ -9,6 +9,8 @@ pub enum Activation {
     #[serde(rename = "gated-gelu")]
     NewGelu,
     Relu,
     Silu,
+    Sigmoid,
     Elu(f64),
+    LeakyRelu(f64),
 }
@@ -16,12 +18,12 @@ pub enum Activation {
 impl super::Module for Activation {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
-            Self::Gelu => xs.gelu(),
-            // TODO: This is "gelu_new", not the original "gelu".
-            // There's some small numerical difference:
+            Self::Gelu => xs.gelu_erf(),
             // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
+            Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
             Self::Silu => crate::ops::silu(xs),
+            Self::Sigmoid => crate::ops::sigmoid(xs),
             &Self::Elu(alpha) => xs.elu(alpha),
+            &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
         }
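For reference, the two GELU variants differ as follows: `gelu_erf` is the exact form, gelu(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))), while `gelu` keeps the tanh approximation, gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))), matching the "gelu_new" variant in the transformers link above. `Activation::Gelu` now maps to the exact form and `Activation::NewGelu` to the approximation.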
@@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct BatchNorm {
     running_mean: Tensor,
     running_var: Tensor,
@@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Conv1d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -88,7 +88,7 @@ impl Default for Conv2dConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Conv2d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -157,7 +157,7 @@ impl Default for ConvTranspose2dConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ConvTranspose2d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -1,7 +1,7 @@
 //! Embedding Layer.
 use candle::{Result, Tensor};

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Embedding {
     embeddings: Tensor,
     hidden_size: usize,
@@ -4,7 +4,7 @@
 use candle::{DType, Result, Tensor};

 // This group norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct GroupNorm {
     weight: Tensor,
     bias: Tensor,
@@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
     }
 }

 // This layer norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct LayerNorm {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
 }

 /// RmsNorm is a specialized version of the LayerNorm module.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct RmsNorm(LayerNorm);

 impl RmsNorm {
@@ -19,7 +19,7 @@
 //! ```
 use candle::{Result, Tensor};

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Linear {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -45,7 +45,8 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
 }

 pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
-    xs.relu()?.minimum(&(xs * negative_slope)?)
+    let zeros = xs.zeros_like()?;
+    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
 }

 pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
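The old leaky_relu expression was wrong for positive inputs whenever `negative_slope < 1`: `min(relu(x), slope * x)` evaluates to `slope * x` for `x > 0` instead of `x`. The new form implements the definition directly, leaky_relu(x) = max(x, 0) + negative_slope * min(x, 0), i.e. `x` for `x >= 0` and `negative_slope * x` otherwise.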
@@ -41,6 +41,10 @@ impl Optimizer for SGD {
     type Config = f64;

     fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let vars = vars
+            .into_iter()
+            .filter(|var| var.dtype().is_float())
+            .collect();
         Ok(Self {
             vars,
             learning_rate,
@@ -116,6 +120,7 @@ impl Optimizer for AdamW {
     fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
         let vars = vars
             .into_iter()
+            .filter(|var| var.dtype().is_float())
             .map(|var| {
                 let dtype = var.dtype();
                 let shape = var.shape();
@@ -4,7 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor};
 /// Trait for Recurrent Neural Networks.
 #[allow(clippy::upper_case_acronyms)]
 pub trait RNN {
-    type State;
+    type State: Clone;

     /// A zero state from which the recurrent network is usually initialized.
     fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
@@ -18,7 +18,7 @@ pub trait RNN {
     ///
     /// The input should have dimensions [batch_size, seq_len, features].
     /// The initial state is the result of applying zero_state.
-    fn seq(&self, input: &Tensor) -> Result<(Tensor, Self::State)> {
+    fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> {
         let batch_dim = input.dim(0)?;
         let state = self.zero_state(batch_dim)?;
         self.seq_init(input, &state)
@@ -27,7 +27,23 @@ pub trait RNN {
     /// Applies multiple steps of the recurrent network.
     ///
     /// The input should have dimensions [batch_size, seq_len, features].
-    fn seq_init(&self, input: &Tensor, state: &Self::State) -> Result<(Tensor, Self::State)>;
+    fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>> {
+        let (_b_size, seq_len, _features) = input.dims3()?;
+        let mut output = Vec::with_capacity(seq_len);
+        for seq_index in 0..seq_len {
+            let input = input.i((.., seq_index, ..))?;
+            let state = if seq_index == 0 {
+                self.step(&input, init_state)?
+            } else {
+                self.step(&input, &output[seq_index - 1])?
+            };
+            output.push(state);
+        }
+        Ok(output)
+    }
+
+    /// Converts a sequence of state to a tensor.
+    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
 }

 /// The state for a LSTM network, this contains two tensors.
@@ -57,6 +73,7 @@ pub struct LSTMConfig {
     pub w_hh_init: super::Init,
     pub b_ih_init: Option<super::Init>,
     pub b_hh_init: Option<super::Init>,
+    pub layer_idx: usize,
 }

 impl Default for LSTMConfig {
@@ -66,6 +83,7 @@ impl Default for LSTMConfig {
             w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
             b_ih_init: Some(super::Init::Const(0.)),
             b_hh_init: Some(super::Init::Const(0.)),
+            layer_idx: 0,
         }
     }
 }
@@ -77,6 +95,7 @@ impl LSTMConfig {
             w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
             b_ih_init: None,
             b_hh_init: None,
+            layer_idx: 0,
         }
     }
 }
@@ -85,7 +104,7 @@ impl LSTMConfig {
 ///
 /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
 #[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct LSTM {
     w_ih: Tensor,
     w_hh: Tensor,
@@ -104,22 +123,27 @@ pub fn lstm(
     config: LSTMConfig,
     vb: crate::VarBuilder,
 ) -> Result<LSTM> {
+    let layer_idx = config.layer_idx;
     let w_ih = vb.get_with_hints(
         (4 * hidden_dim, in_dim),
-        "weight_ih_l0", // Only a single layer is supported.
+        &format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
         config.w_ih_init,
     )?;
     let w_hh = vb.get_with_hints(
         (4 * hidden_dim, hidden_dim),
-        "weight_hh_l0", // Only a single layer is supported.
+        &format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
         config.w_hh_init,
     )?;
     let b_ih = match config.b_ih_init {
-        Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_ih_l0", init)?),
+        Some(init) => {
+            Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
+        }
         None => None,
     };
     let b_hh = match config.b_hh_init {
-        Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_hh_l0", init)?),
+        Some(init) => {
+            Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
+        }
         None => None,
     };
     Ok(LSTM {
@@ -171,18 +195,9 @@ impl RNN for LSTM {
         })
     }

-    /// The input should have dimensions [batch_size, seq_len, features].
-    fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
-        let (_b_size, seq_len, _features) = input.dims3()?;
-        let mut state = in_state.clone();
-        let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
-        for seq_index in 0..seq_len {
-            let input = input.i((.., seq_index, ..))?;
-            state = self.step(&input, &state)?;
-            output.push(state.h.clone());
-        }
-        let output = Tensor::cat(&output, 1)?;
-        Ok((output, state))
+    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
+        let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
+        Tensor::cat(&states, 1)
     }
 }

@@ -235,7 +250,7 @@ impl GRUConfig {
 ///
 /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
 #[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct GRU {
     w_ih: Tensor,
     w_hh: Tensor,
@@ -314,17 +329,8 @@ impl RNN for GRU {
         Ok(GRUState { h: next_h })
     }

-    /// The input should have dimensions [batch_size, seq_len, features].
-    fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
-        let (_b_size, seq_len, _features) = input.dims3()?;
-        let mut state = in_state.clone();
-        let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
-        for seq_index in 0..seq_len {
-            let input = input.i((.., seq_index, ..))?;
-            state = self.step(&input, &state)?;
-            output.push(state.h.clone());
-        }
-        let output = Tensor::cat(&output, 1)?;
-        Ok((output, state))
+    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
+        let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
+        Tensor::cat(&states, 1)
     }
 }
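A sketch of the reshaped API from the caller's side (assumes a `VarBuilder` holding LSTM weights; the dimensions are made up):

    use candle_nn::{lstm, LSTMConfig, RNN};

    fn run(vb: candle_nn::VarBuilder, xs: &candle::Tensor) -> candle::Result<candle::Tensor> {
        let net = lstm(16, 32, LSTMConfig::default(), vb)?;
        // `seq` now returns one state per time step instead of a single
        // (output, state) pair...
        let states = net.seq(xs)?;
        // ...and the concatenated hidden states are recovered explicitly.
        net.states_to_tensor(&states)
    }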
@@ -1,6 +1,6 @@
 //! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
-//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_safetensors`, or initialized for
-//! training, e.g. using `VarBuilder::from_varmap`.
+//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
+//! for training, e.g. using `VarBuilder::from_varmap`.
 use crate::VarMap;
 use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
 use safetensors::{slice::IndexOp, tensor::SafeTensors};
@@ -325,6 +325,58 @@ impl SimpleBackend for candle::npy::NpzTensors {
     }
 }

+impl SimpleBackend for candle::safetensors::MmapedSafetensors {
+    fn get(
+        &self,
+        s: Shape,
+        name: &str,
+        _: crate::Init,
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
+        if tensor.shape() != &s {
+            Err(candle::Error::UnexpectedShape {
+                msg: format!("shape mismatch for {name}"),
+                expected: s,
+                got: tensor.shape().clone(),
+            }
+            .bt())?
+        }
+        Ok(tensor)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        self.get(name).is_ok()
+    }
+}
+
+impl SimpleBackend for candle::safetensors::BufferedSafetensors {
+    fn get(
+        &self,
+        s: Shape,
+        name: &str,
+        _: crate::Init,
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
+        if tensor.shape() != &s {
+            Err(candle::Error::UnexpectedShape {
+                msg: format!("shape mismatch for {name}"),
+                expected: s,
+                got: tensor.shape().clone(),
+            }
+            .bt())?
+        }
+        Ok(tensor)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        self.get(name).is_ok()
+    }
+}
+
 impl<'a> VarBuilder<'a> {
     fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
         let data = TensorData {
@@ -362,18 +414,23 @@ impl<'a> VarBuilder<'a> {

     /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
     /// files.
-    pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
-        let mut routing = HashMap::new();
-        for (index, sf) in safetensors.iter().enumerate() {
-            for k in sf.names() {
-                routing.insert(k.to_string(), index);
-            }
-        }
-        let tensors = SafeTensorWithRouting {
-            routing,
-            safetensors,
-        };
-        Self::new(Box::new(tensors), dtype, dev.clone())
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
+        paths: &[P],
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Self> {
+        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
+        Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
+    }
+
+    /// Initializes a `VarBuilder` from a binary builder in the safetensor format.
+    pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
+        let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
+        Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
     }

     /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
@@ -383,27 +440,24 @@ impl<'a> VarBuilder<'a> {
     }
 }

-pub struct ShardedSafeTensors<'a>(SafeTensorWithRouting<'a>);
+pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);

-pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors<'a>>;
+pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;

-impl<'a> ShardedSafeTensors<'a> {
-    pub fn var_builder(
-        safetensors: Vec<SafeTensors<'a>>,
+impl ShardedSafeTensors {
+    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
+    /// files and make them usable in a sharded way.
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
+        paths: &[P],
         dtype: DType,
         dev: &Device,
-    ) -> ShardedVarBuilder<'a> {
-        let mut routing = HashMap::new();
-        for (index, sf) in safetensors.iter().enumerate() {
-            for k in sf.names() {
-                routing.insert(k.to_string(), index);
-            }
-        }
-        let tensors = SafeTensorWithRouting {
-            routing,
-            safetensors,
-        };
+    ) -> Result<ShardedVarBuilder<'static>> {
+        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
         let backend = ShardedSafeTensors(tensors);
-        VarBuilderArgs::new_with_args(backend, dtype, dev)
+        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
     }
 }

@@ -435,7 +489,7 @@ impl Default for Shard {
 /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
 /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
 /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
-impl<'a> Backend for ShardedSafeTensors<'a> {
+impl Backend for ShardedSafeTensors {
     type Hints = Shard;

     fn get(
@@ -451,18 +505,7 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
             rank,
             world_size,
         } = h;
-        let SafeTensorWithRouting {
-            routing,
-            safetensors,
-        } = &self.0;
-        let index = routing.get(path).ok_or_else(|| {
-            Error::CannotFindTensor {
-                path: path.to_string(),
-            }
-            .bt()
-        })?;
-
-        let view = safetensors[*index].tensor(path)?;
+        let view = self.0.get(path)?;
         let view_dtype = view.dtype();
         let mut shape = view.shape().to_vec();
         let size = shape[dim];
@@ -505,6 +548,6 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
     }

     fn contains_tensor(&self, name: &str) -> bool {
-        self.0.routing.contains_key(name)
+        self.0.get(name).is_ok()
     }
 }
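A sketch of the replacement loading path (assumes a `model.safetensors` file on disk; the `unsafe` mirrors the memory-mapping contract documented above):

    use candle::{DType, Device};
    use candle_nn::VarBuilder;

    fn load_weights() -> candle::Result<VarBuilder<'static>> {
        // Memory-maps the file(s); several shards can be listed in the slice.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &Device::Cpu)?
        };
        Ok(vb)
    }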
@@ -1,4 +1,4 @@
-use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var};
+use candle::{DType, Device, Result, Shape, Tensor, Var};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

@@ -40,18 +40,50 @@ impl VarMap {
     /// Note that values for variables that are currently not in the map are not kept.
     pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
         let path = path.as_ref();
-        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
-        let data = data.deserialize()?;
+        let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
         let mut tensor_data = self.data.lock().unwrap();
         for (name, var) in tensor_data.iter_mut() {
-            match data.tensor(name) {
-                Ok(data) => {
-                    let data: Tensor = data.load(var.device())?;
-                    if let Err(err) = var.set(&data) {
-                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
-                    }
-                }
-                Err(_) => candle::bail!("cannot find tensor for {name}"),
-            }
+            let data = data.load(name, var.device())?;
+            if let Err(err) = var.set(&data) {
+                candle::bail!("error setting {name} using data from {path:?}: {err}",)
+            }
         }
         Ok(())
     }

+    /// Set a named variable to some value.
+    pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let name = name.as_ref();
+        match tensor_data.get(name) {
+            None => candle::bail!("cannot find {name} in VarMap"),
+            Some(var) => {
+                if let Err(err) = var.set(value.as_ref()) {
+                    candle::bail!("error setting {name}: {err}",)
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Set some named variables to some values.
+    ///
+    /// If an error is returned, some of the variables might have already been set to their new
+    /// values.
+    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
+        &mut self,
+        iter: I,
+    ) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        for (name, value) in iter {
+            let name = name.as_ref();
+            match tensor_data.get(name) {
+                None => candle::bail!("cannot find {name} in VarMap"),
+                Some(var) => {
+                    if let Err(err) = var.set(value.as_ref()) {
+                        candle::bail!("error setting {name}: {err}",)
+                    }
+                }
+            }
+        }
+        Ok(())
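A sketch of the new single-variable setter (the variable name is hypothetical and must already be tracked by the map):

    fn overwrite(varmap: &mut candle_nn::VarMap, w: &candle::Tensor) -> candle::Result<()> {
        // Fails with "cannot find weight in VarMap" if the name is untracked.
        varmap.set_one("weight", w)
    }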
candle-pyo3/.gitignore (vendored)
@@ -1,3 +1,4 @@
+tests/_workdir
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -14,8 +14,8 @@ name = "candle"
 crate-type = ["cdylib"]

 [dependencies]
-candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.2.3" }
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
 half = { workspace = true }
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
candle-pyo3/e5.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os
import time

from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch

if __name__ == "__main__":
    model_name = "intfloat/e5-small-v2"
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
    config_file = hf_hub_download(repo_id=model_name, filename="config.json")

    tensors = load_safetensors(model_file)
    config = Config()
    with open(config_file, "r") as f:
        raw_config = json.load(f)
        for field in fields(config):
            if field.name in raw_config:
                setattr(config, field.name, raw_config[field.name])

    # Load the model
    model = BertModel(config)
    model.load_state_dict(tensors)

    hf_model = AutoModel.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)

    sentences = [
        "The cat sits outside",
        "A man is playing guitar",
        "I love pasta",
        "The new movie is awesome",
        "The cat plays in the garden",
        "A woman watches TV",
        "The new movie is so great",
        "Do you like pizza?",
    ]

    def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """Average the hidden states according to the attention mask"""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    tokenized = tokenizer(sentences, padding=True)
    tokens = Tensor(tokenized["input_ids"])
    token_type_ids = Tensor(tokenized["token_type_ids"])
    encoder_out, _ = model.forward(tokens, token_type_ids)

    hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
    hf_result = hf_model(**hf_tokenized)["last_hidden_state"]

    hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
    candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])

    loss = torch.nn.L1Loss()
    error = loss(hf_pooled, candle_pooled).mean().item()
    print(f"Mean error between torch-reference and candle: {error}")

    # Quantize all attention 'weights'
    quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"):
        if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
            # check if the tensor is k-quantizable
            if tensor.shape[-1] % 256 == 0:
                new_tensor = tensor.quantize("q4k")
            else:
                new_tensor = tensor.quantize("q5_0")
            quantized_tensors[name] = new_tensor
        else:
            quantized_tensors[name] = tensor.quantize("q8_0")

    print("Saving quantized tensors")
    # Remove all None values from the config
    config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
    # Save the model
    quantized_model_file = "e5_small.gguf"
    save_gguf(quantized_model_file, quantized_tensors, config_to_save)

    file_size_mb = os.path.getsize(model_file) / 1024 / 1024
    file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
    print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")

    # Load the model from the gguf
    tensors, raw_config = load_gguf(quantized_model_file)
    config = Config()
    for field in fields(config):
        if field.name in raw_config:
            setattr(config, field.name, raw_config[field.name])
    model = BertModel(config)
    # "embeddings.position_ids" is missing in the gguf as it is i64
    model.load_state_dict(tensors, strict=False)

    # Run the model again
    encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids)
    encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")

    candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
    error = loss(hf_pooled, candle_pooled_2).mean().item()
    print(f"Mean error between torch-reference and quantized-candle: {error}")

@@ -1,5 +1,30 @@
-from .candle import *
+import logging
+
+try:
+    from .candle import *
+except ImportError as e:
+    # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
+    logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+    import os
+    import platform
+
+    # Try to locate CUDA_PATH environment variable
+    cuda_path = os.environ.get("CUDA_PATH", None)
+    if cuda_path:
+        logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+        if platform.system() == "Windows":
+            cuda_path = os.path.join(cuda_path, "bin")
+        else:
+            cuda_path = os.path.join(cuda_path, "lib64")
+
+        logging.warning(f"Adding {cuda_path} to DLL search path...")
+        os.add_dll_directory(cuda_path)
+
+    try:
+        from .candle import *
+    except ImportError as inner_e:
+        raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")

 __doc__ = candle.__doc__
 if hasattr(candle, "__all__"):
     __all__ = candle.__all__

candle-pyo3/py_src/candle/functional/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Generated content DO NOT EDIT
from .. import functional

gelu = functional.gelu
relu = functional.relu
silu = functional.silu
softmax = functional.softmax
tanh = functional.tanh

@@ -4,6 +4,20 @@ from os import PathLike
 from candle.typing import _ArrayLike, Device
 from candle import Tensor, DType, QTensor

+@staticmethod
+def gelu(tensor: Tensor) -> Tensor:
+    """
+    Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
+    """
+    pass
+
+@staticmethod
+def relu(tensor: Tensor) -> Tensor:
+    """
+    Applies the Rectified Linear Unit (ReLU) function to a given tensor.
+    """
+    pass
+
 @staticmethod
 def silu(tensor: Tensor) -> Tensor:
     """
@@ -17,3 +31,10 @@ def softmax(tensor: Tensor, dim: int) -> Tensor:
     Applies the Softmax function to a given tensor.
     """
     pass
+
+@staticmethod
+def tanh(tensor: Tensor) -> Tensor:
+    """
+    Applies the tanh function to a given tensor.
+    """
+    pass

Some files were not shown because too many files have changed in this diff.