mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00

Compare commits: tmp_broken...w-uncond (8 commits)

Commits:
- 4114872aae
- f2a648f313
- ec895453cd
- 3769d8bf71
- 5d8e214dfe
- 576bf7c21f
- 49a4fa44bb
- b936e32e11
BIN .github/workflows/maturin.yml (vendored)
Binary file not shown.
.github/workflows/python.yml (vendored, 62 lines)
@@ -1,62 +0,0 @@

```yaml
name: PyO3-CI

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - candle-pyo3/**
  pull_request:
    paths:
      - candle-pyo3/**

jobs:
  build_and_test:
    name: Check everything builds & tests
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest] # For now, only test on Linux
    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Install Rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable

      - name: Install Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
          architecture: "x64"

      - name: Cache Cargo Registry
        uses: actions/cache@v1
        with:
          path: ~/.cargo/registry
          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

      - name: Install
        working-directory: ./candle-pyo3
        run: |
          python -m venv .env
          source .env/bin/activate
          pip install -U pip
          pip install pytest maturin black
          python -m maturin develop -r

      - name: Check style
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python stub.py --check
          black --check .

      - name: Run tests
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python -m pytest -s -v tests
```
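The same checks can be reproduced locally; a minimal sketch following the workflow steps above (assumes Python 3.11 and a stable Rust toolchain are already installed):

```bash
cd candle-pyo3
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r   # build and install the extension into the venv
python stub.py --check         # stub check, as in the CI "Check style" step
black --check .                # formatting check
python -m pytest -s -v tests   # run the test suite
```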
.gitignore (vendored, 8 lines changed)
@@ -23,16 +23,14 @@ flamegraph.svg

```
*.dylib
*.so
*.swp
*.swo
trace-*.json

candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/audios/*.wav
candle-wasm-examples/**/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json

.DS_Store
.idea/*
```
.vscode/settings.json (vendored, 11 lines)
@@ -1,11 +0,0 @@

```json
{
    "[python]": {
        "editor.defaultFormatter": "ms-python.black-formatter"
    },
    "python.formatting.provider": "none",
    "python.testing.pytestArgs": [
        "candle-pyo3"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
```
CHANGELOG.md (28 lines changed)
@@ -1,38 +1,12 @@

# Changelog
This documents the main changes to the `candle` crate.

## v0.3.1 - Unreleased
## v0.2.3 - Unreleased

### Added

### Modified

## v0.3.0 - 2023-10-01

### Added

- Added the Mistral 7b v0.1 model
  [983](https://github.com/huggingface/candle/pull/983).
- Quantized version of the Mistral model
  [1009](https://github.com/huggingface/candle/pull/1009).
- Add the gelu-erf op and activation function
  [969](https://github.com/huggingface/candle/pull/969).
- Add the mixformer/phi-v1.5 model
  [930](https://github.com/huggingface/candle/pull/930).
- Add the slice-scatter op
  [927](https://github.com/huggingface/candle/pull/927).
- Add the Wuerstchen diffusion model
  [911](https://github.com/huggingface/candle/pull/911).

### Modified

- Support for simd128 intrinsics in some quantized vecdots
  [982](https://github.com/huggingface/candle/pull/982).
- Optimize the index-select cuda kernel
  [976](https://github.com/huggingface/candle/pull/976).
- Self-contained safetensor wrappers
  [946](https://github.com/huggingface/candle/pull/946).

## v0.2.2 - 2023-09-18

### Added
Cargo.toml (22 lines changed)
@@ -7,14 +7,19 @@ members = [

```toml
    "candle-nn",
    "candle-pyo3",
    "candle-transformers",
    "candle-wasm-examples/*",
    "candle-wasm-tests",
    "candle-wasm-examples/llama2-c",
    "candle-wasm-examples/segment-anything",
    "candle-wasm-examples/whisper",
    "candle-wasm-examples/yolo",
]
exclude = [
    "candle-flash-attn",
    "candle-kernels",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"

[workspace.package]
version = "0.3.0"
version = "0.2.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
```

@@ -28,7 +33,8 @@ anyhow = { version = "1", features = ["backtrace"] }

```toml
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.16.0", package = "candle-gemm" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
```

@@ -36,10 +42,9 @@ imageproc = { version = "0.23.0", default-features = false }

```toml
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
```

@@ -53,9 +58,8 @@ tracing = "0.1.37"

```toml
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
parquet = { version = "45.0.0" }

[profile.release-with-debug]
inherits = "release"
```
README.md (59 lines changed)
@@ -8,7 +8,6 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ

and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

@@ -53,22 +52,14 @@ These online demos run entirely in your browser:

  object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

We also provide some command line based examples using state of the art models:

- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
  pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM that
  outperforms all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
  generation.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
  the LLaMA model using the same quantization techniques as
  [llama.cpp](https://github.com/ggerganov/llama.cpp).

@@ -80,11 +71,6 @@ We also provide a some command line based examples using state of the art models

  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
  image generative model.

  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">

- [yolo-v3](./candle-examples/examples/yolo-v3/) and
  [yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
  estimation models.

@@ -96,15 +82,10 @@ We also provide a some command line based examples using state of the art models

  <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
  [JinaBert](./candle-examples/examples/jina-bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
  using self-supervision (can be used for imagenet classification, depth
  evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
  generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
  model, generates the translated text from the input text.

Run them using commands like:
```

@@ -119,8 +100,6 @@ There are also some wasm examples for whisper and

`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

For LLaMA2, run the following command to retrieve the weight files and start a

@@ -136,13 +115,8 @@ And then head over to

<!--- ANCHOR: useful_libraries --->

## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
  very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
  that conforms to the official `peft` implementation.
## Useful Libraries
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.

If you have an addition to this list, please submit a pull request.

@@ -164,23 +138,15 @@ If you have an addition to this list, please submit a pull request.

  - LLaMA v1 and v2.
  - Falcon.
  - StarCoder.
  - Phi v1.5.
  - Mistral 7b v0.1.
  - StableLM-3B-4E1T.
  - Replit-code-v1.5-3B.
  - T5.
  - Bert.
- Whisper (multi-lingual support).
- Text to image.
  - Stable Diffusion v1.5, v2.1, XL v1.0.
  - Wurstchen v2.
- Image to text.
  - BLIP.
- Text to text.
  - Marian MT (Machine Translation).
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Computer Vision Models.
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
  - yolo-v3, yolo-v8.
  - DINOv2.
  - EfficientNet.
  - yolo-v3.
  - yolo-v8.
  - Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.

@@ -340,11 +306,6 @@ mdbook test candle-book -L .\target\debug\deps\ `

-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```

#### Extremely slow model load time with WSL

This may be caused by the models being loaded from `/mnt/c`, more details on
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).

#### Tracking down errors

You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
@@ -11,11 +11,11 @@ readme = "README.md"

```toml
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
```

@@ -24,10 +24,9 @@ intel-mkl-src = { workspace = true, optional = true }

```toml
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
anyhow = { workspace = true }
tokio = "1.29.1"

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features = ["tokio"] }
clap = { workspace = true }
```

@@ -39,6 +38,7 @@ tracing-chrome = { workspace = true }

```toml
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }
```
@@ -14,7 +14,6 @@

- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/training.md)
  - [Simplified](training/simplified.md)
  - [MNIST](training/mnist.md)
  - [Fine-tuning]()
  - [Serialization]()
@@ -12,9 +12,6 @@ compute_cap
8.9
```

You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.

If any of the above commands errors out, please make sure to update your Cuda version.

2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
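As a concrete illustration, a hedged sketch of such a build (assuming the variable takes the integer form, e.g. `89` for the compute cap `8.9` reported above):

```bash
# Compile the CUDA kernels for compute capability 8.9 only.
CUDA_COMPUTE_CAP=89 cargo build --release --features cuda
```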
@@ -1,6 +1,3 @@

```rust
#[cfg(test)]
pub mod simplified;

#[cfg(test)]
mod tests {
    use anyhow::Result;
```
@@ -1,196 +0,0 @@

```rust
//! # A simplified example in Rust of training a neural network and then using it, based on the Candle framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of an election based on the results of the first round.
//!
//! ## Key points:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
//! The output is the number 0 or 1, where 1 means that the first candidate will win the second round and 0 means that he will lose.
//! For training, samples with real data on the results of the first and second rounds of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! The model parameters (neuron weights) are initialized randomly, then optimized during training.
//! After training, the model is tested on a held-out sample to evaluate its accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

#[rustfmt::skip]
mod tests {

    use candle::{DType, Result, Tensor, D, Device};
    use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};

    // ANCHOR: book_training_simplified1
    const VOTE_DIM: usize = 2;
    const RESULTS: usize = 1;
    const EPOCHS: usize = 10;
    const LAYER1_OUT_SIZE: usize = 4;
    const LAYER2_OUT_SIZE: usize = 2;
    const LEARNING_RATE: f64 = 0.05;

    #[derive(Clone)]
    pub struct Dataset {
        pub train_votes: Tensor,
        pub train_results: Tensor,
        pub test_votes: Tensor,
        pub test_results: Tensor,
    }

    struct MultiLevelPerceptron {
        ln1: Linear,
        ln2: Linear,
        ln3: Linear,
    }

    impl MultiLevelPerceptron {
        fn new(vs: VarBuilder) -> Result<Self> {
            let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
            let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
            let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
            Ok(Self { ln1, ln2, ln3 })
        }

        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let xs = self.ln1.forward(xs)?;
            let xs = xs.relu()?;
            let xs = self.ln2.forward(&xs)?;
            let xs = xs.relu()?;
            self.ln3.forward(&xs)
        }
    }
    // ANCHOR_END: book_training_simplified1

    // ANCHOR: book_training_simplified3
    #[tokio::test]
    async fn simplified() -> anyhow::Result<()> {
        let dev = Device::cuda_if_available(0)?;

        let train_votes_vec: Vec<u32> = vec![
            15, 10,
            10, 15,
            5, 12,
            30, 20,
            16, 12,
            13, 25,
            6, 14,
            31, 21,
        ];
        let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

        let train_results_vec: Vec<u32> = vec![
            1,
            0,
            0,
            1,
            1,
            0,
            0,
            1,
        ];
        let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;

        let test_votes_vec: Vec<u32> = vec![
            13, 9,
            8, 14,
            3, 10,
        ];
        let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

        let test_results_vec: Vec<u32> = vec![
            1,
            0,
            0,
        ];
        let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;

        let m = Dataset {
            train_votes: train_votes_tensor,
            train_results: train_results_tensor,
            test_votes: test_votes_tensor,
            test_results: test_results_tensor,
        };

        let trained_model: MultiLevelPerceptron;
        loop {
            println!("Trying to train neural network.");
            match train(m.clone(), &dev) {
                Ok(model) => {
                    trained_model = model;
                    break;
                },
                Err(e) => {
                    println!("Error: {}", e);
                    continue;
                }
            }
        }

        let real_world_votes: Vec<u32> = vec![
            13, 22,
        ];
        let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

        let final_result = trained_model.forward(&tensor_test_votes)?;

        let result = final_result
            .argmax(D::Minus1)?
            .to_dtype(DType::F32)?
            .get(0).map(|x| x.to_scalar::<f32>())??;
        println!("real_life_votes: {:?}", real_world_votes);
        println!("neural_network_prediction_result: {:?}", result);

        Ok(())
    }
    // ANCHOR_END: book_training_simplified3

    // ANCHOR: book_training_simplified2
    fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
        let train_results = m.train_results.to_device(dev)?;
        let train_votes = m.train_votes.to_device(dev)?;
        let varmap = VarMap::new();
        let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
        let model = MultiLevelPerceptron::new(vs.clone())?;
        let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
        let test_votes = m.test_votes.to_device(dev)?;
        let test_results = m.test_results.to_device(dev)?;
        let mut final_accuracy: f32 = 0.0;
        for epoch in 1..EPOCHS + 1 {
            let logits = model.forward(&train_votes)?;
            let log_sm = ops::log_softmax(&logits, D::Minus1)?;
            let loss = loss::nll(&log_sm, &train_results)?;
            sgd.backward_step(&loss)?;

            let test_logits = model.forward(&test_votes)?;
            let sum_ok = test_logits
                .argmax(D::Minus1)?
                .eq(&test_results)?
                .to_dtype(DType::F32)?
                .sum_all()?
                .to_scalar::<f32>()?;
            let test_accuracy = sum_ok / test_results.dims1()? as f32;
            final_accuracy = 100. * test_accuracy;
            println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
                loss.to_scalar::<f32>()?,
                final_accuracy
            );
            if final_accuracy == 100.0 {
                break;
            }
        }
        if final_accuracy < 100.0 {
            Err(anyhow::Error::msg("The model is not trained well enough."))
        } else {
            Ok(model)
        }
    }
    // ANCHOR_END: book_training_simplified2
}
```
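The training step above pairs `ops::log_softmax` with `loss::nll`, which together amount to the standard cross-entropy objective over the $N$ training rows:

$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \mathrm{softmax}(z_i)_{y_i}$$

where $z_i$ is the pair of logits produced by `ln3` for row $i$ and $y_i \in \{0, 1\}$ is its label.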
@@ -1,45 +0,0 @@

# Simplified

## How it works

This program implements a neural network to predict the winner of the second round of an election based on the results of the first round.

Key points:

1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
3. The output is the number 0 or 1, where 1 means that the first candidate will win the second round and 0 means that he will lose.
4. For training, samples with real data on the results of the first and second rounds of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. The model parameters (neuron weights) are initialized randomly, then optimized during training.
7. After training, the model is tested on a held-out sample to evaluate its accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the training process is repeated.

Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.

```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```

```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```

## Example output

```bash
Trying to train neural network.
Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```
@@ -12,10 +12,7 @@ readme = "README.md"

```toml
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
tracing = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.2.3", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
```

@@ -29,7 +26,6 @@ rand_distr = { workspace = true }

```toml
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }

[dev-dependencies]
```

@@ -42,4 +38,3 @@ cuda = ["cudarc", "dep:candle-kernels"]

```toml
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:candle-metal-kernels", "dep:metal"]
```
@@ -103,10 +103,8 @@ enum Command {

```rust
    Quantize {
        /// The input file, in gguf format.
        in_file: Vec<std::path::PathBuf>,

        in_file: std::path::PathBuf,
        /// The output file, in gguf format.
        #[arg(long)]
        out_file: std::path::PathBuf,

        /// The quantization schema to apply.
```

@@ -152,7 +150,8 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R

```rust
            }
        }
        Format::Safetensors => {
            let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
            let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
            let tensors = tensors.deserialize()?;
            let mut tensors = tensors.tensors();
            tensors.sort_by(|a, b| a.0.cmp(&b.0));
            for (name, view) in tensors.iter() {
```

@@ -219,99 +218,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R

```rust
    Ok(())
}

fn run_quantize_safetensors(
    in_files: &[std::path::PathBuf],
    out_file: std::path::PathBuf,
    q: Quantization,
) -> Result<()> {
    let mut out_file = std::fs::File::create(out_file)?;
    let mut tensors = std::collections::HashMap::new();
    for in_file in in_files.iter() {
        let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
        tensors.extend(in_tensors)
    }
    println!("tensors: {}", tensors.len());

    let quantize_fn = match q {
        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
        Quantization::F16 => QTensor::quantize::<half::f16>,
        Quantization::F32 => QTensor::quantize::<f32>,
    };
    let block_size = match q {
        Quantization::Q4_0 => k_quants::QK4_0,
        Quantization::Q4_1 => k_quants::QK4_1,
        Quantization::Q5_0 => k_quants::QK5_0,
        Quantization::Q5_1 => k_quants::QK5_1,
        Quantization::Q8_0 => k_quants::QK8_0,
        Quantization::Q8_1 => k_quants::QK8_1,
        Quantization::Q2k
        | Quantization::Q3k
        | Quantization::Q4k
        | Quantization::Q5k
        | Quantization::Q6k
        | Quantization::Q8k => k_quants::QK_K,
        Quantization::F16 | Quantization::F32 => 1,
    };

    let qtensors = tensors
        .into_par_iter()
        .map(|(name, tensor)| {
            let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
            println!("  quantizing {name} {tensor:?} {should_quantize}");
            let tensor = if should_quantize {
                quantize_fn(&tensor)?
            } else {
                QTensor::quantize::<f32>(&tensor)?
            };
            Ok((name, tensor))
        })
        .collect::<Result<Vec<_>>>()?;
    let qtensors = qtensors
        .iter()
        .map(|(k, v)| (k.as_str(), v))
        .collect::<Vec<_>>();
    gguf_file::write(&mut out_file, &[], &qtensors)?;
    Ok(())
}

fn run_quantize(
    in_files: &[std::path::PathBuf],
    in_file: std::path::PathBuf,
    out_file: std::path::PathBuf,
    q: Quantization,
    qmode: QuantizationMode,
) -> Result<()> {
    if in_files.is_empty() {
        candle_core::bail!("no specified input files")
    }
    if let Some(extension) = out_file.extension() {
        if extension == "safetensors" {
            candle_core::bail!("the generated file cannot use the safetensors extension")
        }
    }
    if let Some(extension) = in_files[0].extension() {
        if extension == "safetensors" {
            return run_quantize_safetensors(in_files, out_file, q);
        }
    }

    if in_files.len() != 1 {
        candle_core::bail!("only a single in-file can be used when quantizing gguf files")
    }

    // Open the out file early so as to fail directly on missing directories etc.
    let mut out_file = std::fs::File::create(out_file)?;
    let mut in_ = std::fs::File::open(&in_files[0])?;
    let mut in_ = std::fs::File::open(&in_file)?;
    let content = gguf_file::Content::read(&mut in_)?;
    println!("tensors: {}", content.tensor_infos.len());
```

@@ -337,7 +252,7 @@ fn run_quantize(

```rust
        .par_iter()
        .map(|(name, _)| {
            println!("  quantizing {name}");
            let mut in_file = std::fs::File::open(&in_files[0])?;
            let mut in_file = std::fs::File::open(&in_file)?;
            let tensor = content.tensor(&mut in_file, name)?;
            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
            Ok((name, tensor))
```

@@ -378,7 +293,7 @@ fn main() -> anyhow::Result<()> {

```rust
            out_file,
            quantization,
            mode,
        } => run_quantize(&in_file, out_file, quantization, mode)?,
        } => run_quantize(in_file, out_file, quantization, mode)?,
    }
    Ok(())
}
```
@@ -111,6 +111,4 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {

```rust
    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

    fn set_seed(&self, _: u64) -> Result<()>;
}
```
@@ -36,8 +36,6 @@ impl Tensor {

```rust
                // Do not call recursively on the "leaf" nodes.
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    Op::IndexAdd(t1, t2, t3, _)
```

@@ -71,8 +69,7 @@

```rust
                    | Op::Binary(lhs, rhs, _)
                    | Op::Gather(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::SliceScatter0(lhs, rhs, _) => {
                    | Op::Matmul(lhs, rhs) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
```

@@ -93,9 +90,6 @@

```rust
                            nodes
                        }
                    }
                    Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round) => nodes,
                    Op::Reshape(node)
                    | Op::UpsampleNearest1D(node)
                    | Op::UpsampleNearest2D(node)
```

@@ -105,6 +99,7 @@

```rust
                    | Op::Broadcast(node)
                    | Op::Cmp(node, _)
                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                    | Op::ToDType(node)
                    | Op::ToDevice(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
```

@@ -117,15 +112,6 @@

```rust
                        track_grad |= tg;
                        nodes
                    }
                    Op::ToDType(node) => {
                        if node.dtype().is_float() {
                            let (tg, nodes) = walk(node, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        } else {
                            nodes
                        }
                    }
                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                }
            } else {
```

@@ -238,13 +224,6 @@

```rust
                        .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                        .transpose(0, 1)?;
                    let sum_grad = grads.or_insert(kernel)?;
                    let (_, _, k0, k1) = kernel.dims4()?;
                    let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                    let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                        grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                    } else {
                        grad_kernel
                    };
                    *sum_grad = sum_grad.add(&grad_kernel)?;
                }
                Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
```

@@ -291,15 +270,6 @@

```rust
                Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                    op: "upsample-nearest2d",
                })?,
                Op::SliceScatter0(lhs, rhs, start_rhs) => {
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

                    let lhs_sum_grad = grads.or_insert(lhs)?;
                    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
                }
                Op::Gather(arg, indexes, dim) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
```

@@ -391,7 +361,7 @@

```rust
                }
                Op::ToDType(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                }
                Op::Copy(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
```

@@ -471,26 +441,7 @@

```rust
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&arg_grad)?
                }
                Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
                Op::Unary(_, UnaryOp::Floor) => {
                    Err(Error::BackwardNotSupported { op: "floor" })?
                }
                Op::Unary(_, UnaryOp::Round) => {
                    Err(Error::BackwardNotSupported { op: "round" })?
                }
                Op::Unary(arg, UnaryOp::Gelu) => {
                    let sum_grad = grads.or_insert(arg)?;
                    let cube = arg.powf(3.)?;
                    let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
                    let gelu_grad = (((0.5 * &tanh)?
                        + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
                        + 0.5)?;
                    *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                }
                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
                Op::Unary(_, UnaryOp::GeluErf) => {
                    Err(Error::BackwardNotSupported { op: "gelu-erf" })?
                }
                Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
                Op::Unary(arg, UnaryOp::Relu) => {
                    let sum_grad = grads.or_insert(arg)?;
                    let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
```
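The `Gelu` branch above backpropagates through the tanh approximation of GELU; the hard-coded constants are $\sqrt{2/\pi} \approx 0.797885$ and its multiples by the usual $0.044715$ factor:

$$\mathrm{gelu}(x) \approx \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr), \qquad u = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^3\bigr) \approx 0.797885\,x + 0.0356774\,x^3$$

Differentiating gives exactly the expression built in the code:

$$\frac{d}{dx}\,\mathrm{gelu}(x) = \tfrac{1}{2}\bigl(1 + \tanh u\bigr) + \bigl(0.398942\,x + 0.0535161\,x^3\bigr)\bigl(1 - \tanh^2 u\bigr)$$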
@@ -25,19 +25,6 @@ impl ParamsConv1D {

```rust
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
    ImplicitGemm,
    ImplicitPrecompGemm,
    Gemm,
    Direct,
    Fft,
    FftTiling,
    Winograd,
    WinogradNonFused,
    Count,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
    pub(crate) b_size: usize,
```

@@ -50,7 +37,6 @@ pub struct ParamsConv2D {

```rust
    pub(crate) padding: usize,
    pub(crate) stride: usize,
    pub(crate) dilation: usize,
    pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}

impl ParamsConv2D {
```

@@ -202,7 +188,6 @@ impl Tensor {

```rust
            padding,
            stride,
            dilation,
            cudnn_fwd_algo: None,
        };
        if groups == 1 {
            self.conv2d_single_group(kernel, &params)
```
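For reference, the output spatial size implied by the `ParamsConv2D` fields above follows the usual convolution arithmetic, per spatial dimension with kernel size $k$:

$$\text{out} = \left\lfloor \frac{\text{in} + 2 \cdot \text{padding} - \text{dilation} \cdot (k - 1) - 1}{\text{stride}} \right\rfloor + 1$$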
@@ -1,763 +0,0 @@

```rust
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions

mod evaluate {
    //! Provides functions that don't have a numerical solution and must
    //! be solved computationally (e.g. evaluation of a polynomial)

    /// Evaluates a polynomial at `z` where `coeff` are the coefficients
    /// of a polynomial of order `k`, where `k` is the length of `coeff`;
    /// the coefficient of the `k`th power is the `k`th element in `coeff`.
    /// E.g. [3, -1, 2] equates to `2z^2 - z + 3`.
    ///
    /// # Remarks
    ///
    /// Returns 0 for a 0 length coefficient slice
    pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
        let n = coeff.len();
        if n == 0 {
            return 0.0;
        }

        let mut sum = *coeff.last().unwrap();
        for c in coeff[0..n - 1].iter().rev() {
            sum = *c + z * sum;
        }
        sum
    }
}
```
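A quick sanity check of the Horner-style evaluation above (a hypothetical test, not part of the original file): the coefficients `[3, -1, 2]` encode `2z^2 - z + 3`, so at `z = 2` the result is `2*4 - 2 + 3 = 9`.

```rust
// Hypothetical usage of evaluate::polynomial as documented above.
assert_eq!(evaluate::polynomial(2.0, &[3.0, -1.0, 2.0]), 9.0);
// An empty coefficient slice evaluates to 0.
assert_eq!(evaluate::polynomial(2.0, &[]), 0.0);
```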
```rust
use std::f64;

/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x >= 0.0 && x.is_infinite() {
        1.0
    } else if x <= 0.0 && x.is_infinite() {
        -1.0
    } else if x == 0. {
        0.0
    } else {
        erf_impl(x, false)
    }
}

/// `erf_inv` calculates the inverse error function at `x`.
pub fn erf_inv(x: f64) -> f64 {
    if x == 0.0 {
        0.0
    } else if x >= 1.0 {
        f64::INFINITY
    } else if x <= -1.0 {
        f64::NEG_INFINITY
    } else if x < 0.0 {
        erf_inv_impl(-x, 1.0 + x, -1.0)
    } else {
        erf_inv_impl(x, 1.0 - x, 1.0)
    }
}

/// `erfc` calculates the complementary error function at `x`.
pub fn erfc(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x == f64::INFINITY {
        0.0
    } else if x == f64::NEG_INFINITY {
        2.0
    } else {
        erf_impl(x, true)
    }
}

/// `erfc_inv` calculates the complementary inverse error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
    if x <= 0.0 {
        f64::INFINITY
    } else if x >= 2.0 {
        f64::NEG_INFINITY
    } else if x > 1.0 {
        erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
    } else {
        erf_inv_impl(1.0 - x, x, 1.0)
    }
}

// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************

/// Polynomial coefficients for a numerator of `erf_impl` in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
    0.00337916709551257388990745,
    -0.00073695653048167948530905,
    -0.374732337392919607868241,
    0.0817442448733587196071743,
    -0.0421089319936548595203468,
    0.0070165709512095756344528,
    -0.00495091255982435110337458,
    0.000871646599037922480317225,
];

/// Polynomial coefficients for a denominator of `erf_impl` in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
    1.0,
    -0.218088218087924645390535,
    0.412542972725442099083918,
    -0.0841891147873106755410271,
    0.0655338856400241519690695,
    -0.0120019604454941768171266,
    0.00408165558926174048329689,
    -0.000615900721557769691924509,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
    -0.0361790390718262471360258,
    0.292251883444882683221149,
    0.281447041797604512774415,
    0.125610208862766947294894,
    0.0274135028268930549240776,
    0.00250839672168065762786937,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
    1.0,
    1.8545005897903486499845,
    1.43575803037831418074962,
    0.582827658753036572454135,
    0.124810476932949746447682,
    0.0113724176546353285778481,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
    -0.0397876892611136856954425,
    0.153165212467878293257683,
    0.191260295600936245503129,
    0.10276327061989304213645,
    0.029637090615738836726027,
    0.0046093486780275489468812,
    0.000307607820348680180548455,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
    1.0,
    1.95520072987627704987886,
    1.64762317199384860109595,
    0.768238607022126250082483,
    0.209793185936509782784315,
    0.0319569316899913392596356,
    0.00213363160895785378615014,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
    -0.0300838560557949717328341,
    0.0538578829844454508530552,
    0.0726211541651914182692959,
    0.0367628469888049348429018,
    0.00964629015572527529605267,
    0.00133453480075291076745275,
    0.778087599782504251917881e-4,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
    1.0,
    1.75967098147167528287343,
    1.32883571437961120556307,
    0.552528596508757581287907,
    0.133793056941332861912279,
    0.0179509645176280768640766,
    0.00104712440019937356634038,
    -0.106640381820357337177643e-7,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
    -0.0117907570137227847827732,
    0.014262132090538809896674,
    0.0202234435902960820020765,
    0.00930668299990432009042239,
    0.00213357802422065994322516,
    0.00025022987386460102395382,
    0.120534912219588189822126e-4,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
    1.0,
    1.50376225203620482047419,
    0.965397786204462896346934,
    0.339265230476796681555511,
    0.0689740649541569716897427,
    0.00771060262491768307365526,
    0.000371421101531069302990367,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
    -0.00546954795538729307482955,
    0.00404190278731707110245394,
    0.0054963369553161170521356,
    0.00212616472603945399437862,
    0.000394984014495083900689956,
    0.365565477064442377259271e-4,
    0.135485897109932323253786e-5,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
    1.0,
    1.21019697773630784832251,
    0.620914668221143886601045,
    0.173038430661142762569515,
    0.0276550813773432047594539,
    0.00240625974424309709745382,
    0.891811817251336577241006e-4,
    -0.465528836283382684461025e-11,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
    -0.00270722535905778347999196,
    0.0013187563425029400461378,
    0.00119925933261002333923989,
    0.00027849619811344664248235,
    0.267822988218331849989363e-4,
    0.923043672315028197865066e-6,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
    1.0,
    0.814632808543141591118279,
    0.268901665856299542168425,
    0.0449877216103041118694989,
    0.00381759663320248459168994,
    0.000131571897888596914350697,
    0.404815359675764138445257e-11,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
    -0.00109946720691742196814323,
    0.000406425442750422675169153,
    0.000274499489416900707787024,
    0.465293770646659383436343e-4,
    0.320955425395767463401993e-5,
    0.778286018145020892261936e-7,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
    1.0,
    0.588173710611846046373373,
    0.139363331289409746077541,
    0.0166329340417083678763028,
    0.00100023921310234908642639,
    0.24254837521587225125068e-4,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
    -0.00056907993601094962855594,
    0.000169498540373762264416984,
    0.518472354581100890120501e-4,
    0.382819312231928859704678e-5,
    0.824989931281894431781794e-7,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
    1.0,
    0.339637250051139347430323,
    0.043472647870310663055044,
    0.00248549335224637114641629,
    0.535633305337152900549536e-4,
    -0.117490944405459578783846e-12,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
    -0.000241313599483991337479091,
    0.574224975202501512365975e-4,
    0.115998962927383778460557e-4,
    0.581762134402593739370875e-6,
    0.853971555085673614607418e-8,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
    1.0,
    0.233044138299687841018015,
    0.0204186940546440312625597,
    0.000797185647564398289151125,
    0.117019281670172327758019e-4,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
    -0.000146674699277760365803642,
    0.162666552112280519955647e-4,
    0.269116248509165239294897e-5,
    0.979584479468091935086972e-7,
    0.101994647625723465722285e-8,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
    1.0,
    0.165907812944847226546036,
    0.0103361716191505884359634,
    0.000286593026373868366935721,
    0.298401570840900340874568e-5,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
    -0.583905797629771786720406e-4,
    0.412510325105496173512992e-5,
    0.431790922420250949096906e-6,
    0.993365155590013193345569e-8,
    0.653480510020104699270084e-10,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
    1.0,
    0.105077086072039915406159,
    0.00414278428675475620830226,
    0.726338754644523769144108e-4,
    0.477818471047398785369849e-6,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
    -0.196457797609229579459841e-4,
    0.157243887666800692441195e-5,
    0.543902511192700878690335e-7,
    0.317472492369117710852685e-9,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
    1.0,
    0.052803989240957632204885,
    0.000926876069151753290378112,
    0.541011723226630257077328e-5,
    0.535093845803642394908747e-15,
];

/// Polynomial coefficients for a numerator in `erf_impl` in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
    -0.789224703978722689089794e-5,
    0.622088451660986955124162e-6,
    0.145728445676882396797184e-7,
    0.603715505542715364529243e-10,
];

/// Polynomial coefficients for a denominator in `erf_impl` in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
    1.0,
    0.0375328846356293715248719,
    0.000467919535974625308126054,
    0.193847039275845656900547e-5,
];

// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
    -0.000508781949658280665617,
    -0.00836874819741736770379,
    0.0334806625409744615033,
    -0.0126926147662974029034,
    -0.0365637971411762664006,
    0.0219878681111168899165,
    0.00822687874676915743155,
    -0.00538772965071242932965,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
    1.0,
    -0.970005043303290640362,
    -1.56574558234175846809,
    1.56221558398423026363,
    0.662328840472002992063,
    -0.71228902341542847553,
    -0.0527396382340099713954,
    0.0795283687341571680018,
    -0.00233393759374190016776,
    0.000886216390456424707504,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
    -0.202433508355938759655,
    0.105264680699391713268,
    8.37050328343119927838,
    17.6447298408374015486,
    -18.8510648058714251895,
    -44.6382324441786960818,
    17.445385985570866523,
    21.1294655448340526258,
    -3.67192254707729348546,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
    1.0,
    6.24264124854247537712,
    3.9713437953343869095,
    -28.6608180499800029974,
    -20.1432634680485188801,
    48.5609213108739935468,
    10.8268667355460159008,
    -22.6436933413139721736,
    1.72114765761200282724,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
    -0.131102781679951906451,
    -0.163794047193317060787,
    0.117030156341995252019,
    0.387079738972604337464,
    0.337785538912035898924,
    0.142869534408157156766,
    0.0290157910005329060432,
    0.00214558995388805277169,
    -0.679465575181126350155e-6,
    0.285225331782217055858e-7,
    -0.681149956853776992068e-9,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
    1.0,
    3.46625407242567245975,
    5.38168345707006855425,
    4.77846592945843778382,
    2.59301921623620271374,
    0.848854343457902036425,
    0.152264338295331783612,
    0.01105924229346489121,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
    -0.0350353787183177984712,
    -0.00222426529213447927281,
    0.0185573306514231072324,
    0.00950804701325919603619,
    0.00187123492819559223345,
    0.000157544617424960554631,
    0.460469890584317994083e-5,
    -0.230404776911882601748e-9,
    0.266339227425782031962e-11,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
    1.0,
    1.3653349817554063097,
    0.762059164553623404043,
    0.220091105764131249824,
    0.0341589143670947727934,
    0.00263861676657015992959,
    0.764675292302794483503e-4,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
    -0.0167431005076633737133,
    -0.00112951438745580278863,
    0.00105628862152492910091,
    0.000209386317487588078668,
    0.149624783758342370182e-4,
    0.449696789927706453732e-6,
    0.462596163522878599135e-8,
    -0.281128735628831791805e-13,
    0.99055709973310326855e-16,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
    1.0,
    0.591429344886417493481,
    0.138151865749083321638,
    0.0160746087093676504695,
    0.000964011807005165528527,
    0.275335474764726041141e-4,
    0.282243172016108031869e-6,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
    -0.0024978212791898131227,
    -0.779190719229053954292e-5,
    0.254723037413027451751e-4,
    0.162397777342510920873e-5,
    0.396341011304801168516e-7,
    0.411632831190944208473e-9,
    0.145596286718675035587e-11,
    -0.116765012397184275695e-17,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
    1.0,
    0.207123112214422517181,
    0.0169410838120975906478,
    0.000690538265622684595676,
    0.145007359818232637924e-4,
    0.144437756628144157666e-6,
    0.509761276599778486139e-9,
];

/// Polynomial coefficients for a numerator of `erf_inv_impl` in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
    -0.000539042911019078575891,
    -0.28398759004727721098e-6,
    0.899465114892291446442e-6,
    0.229345859265920864296e-7,
    0.225561444863500149219e-9,
    0.947846627503022684216e-12,
    0.135880130108924861008e-14,
    -0.348890393399948882918e-21,
];

/// Polynomial coefficients for a denominator of `erf_inv_impl` in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
    1.0,
    0.0845746234001899436914,
    0.00282092984726264681981,
    0.468292921940894236786e-4,
    0.399968812193862100054e-6,
    0.161809290887904476097e-8,
    0.231558608310259605225e-11,
];

/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
    if z < 0.0 {
        if !inv {
            return -erf_impl(-z, false);
        }
        if z < -0.5 {
            return 2.0 - erf_impl(-z, true);
        }
        return 1.0 + erf_impl(-z, false);
    }

    let result = if z < 0.5 {
        if z < 1e-10 {
            z * 1.125 + z * 0.003379167095512573896158903121545171688
        } else {
            z * 1.125
                + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
        }
    } else if z < 110.0 {
        let (r, b) = if z < 0.75 {
            (
                evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
                    / evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
                0.3440242112,
            )
        } else if z < 1.25 {
            (
                evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
                    / evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
                0.419990927,
            )
        } else if z < 2.25 {
            (
                evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
                    / evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
                0.4898625016,
            )
        } else if z < 3.5 {
            (
                evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
                    / evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
                0.5317370892,
            )
        } else if z < 5.25 {
            (
                evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
                    / evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
                0.5489973426,
            )
        } else if z < 8.0 {
            (
                evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
                    / evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
                0.5571740866,
            )
        } else if z < 11.5 {
            (
                evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
                    / evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
                0.5609807968,
            )
        } else if z < 17.0 {
            (
                evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
```
|
||||
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
|
||||
0.5626493692,
|
||||
)
|
||||
} else if z < 24.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
|
||||
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
|
||||
0.5634598136,
|
||||
)
|
||||
} else if z < 38.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
|
||||
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
|
||||
0.5638477802,
|
||||
)
|
||||
} else if z < 60.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
|
||||
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
|
||||
0.5640528202,
|
||||
)
|
||||
} else if z < 85.0 {
|
||||
(
|
||||
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
|
||||
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
|
||||
0.5641309023,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
|
||||
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
|
||||
0.5641584396,
|
||||
)
|
||||
};
|
||||
let g = (-z * z).exp() / z;
|
||||
g * b + g * r
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
if inv && z >= 0.5 {
|
||||
result
|
||||
} else if z >= 0.5 || inv {
|
||||
1.0 - result
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
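// `evaluate::polynomial` is called throughout but not defined in this hunk.
// A minimal sketch of such a helper, assuming the coefficient slices above
// are stored lowest-degree first (name and signature assumed), is Horner's
// rule:
fn polynomial_sketch(z: f64, coeffs: &[f64]) -> f64 {
    // c0 + z * (c1 + z * (c2 + ...)), folding from the highest coefficient down.
    coeffs.iter().rev().fold(0.0, |acc, &c| acc * z + c)
}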

// `erf_inv_impl` computes the inverse error function where
// `p`, `q`, and `s` are the first, second, and third intermediate
// parameters respectively.
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
    let result = if p <= 0.5 {
        let y = 0.0891314744949340820313;
        let g = p * (p + 10.0);
        let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
        g * y + g * r
    } else if q >= 0.25 {
        let y = 2.249481201171875;
        let g = (-2.0 * q.ln()).sqrt();
        let xs = q - 0.25;
        let r =
            evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
        g / (y + r)
    } else {
        let x = (-q.ln()).sqrt();
        if x < 3.0 {
            let y = 0.807220458984375;
            let xs = x - 1.125;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_CD);
            y * x + r * x
        } else if x < 6.0 {
            let y = 0.93995571136474609375;
            let xs = x - 3.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_DD);
            y * x + r * x
        } else if x < 18.0 {
            let y = 0.98362827301025390625;
            let xs = x - 6.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_ED);
            y * x + r * x
        } else if x < 44.0 {
            let y = 0.99714565277099609375;
            let xs = x - 18.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_FD);
            y * x + r * x
        } else {
            let y = 0.99941349029541015625;
            let xs = x - 44.0;
            let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
                / evaluate::polynomial(xs, ERF_INV_IMPL_GD);
            y * x + r * x
        }
    };
    s * result
}
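A quick sanity check for the rational approximations above is the round-trip identity erf_inv(erf(z)) ≈ z. A test-style sketch, assuming public `erf` and `erf_inv` wrappers over the two implementations above:

#[test]
fn erf_round_trip() {
    for z in [0.1f64, 0.5, 1.0, 2.0] {
        let x = erf(z); // forward error function
        let z2 = erf_inv(x); // assumed wrapper around erf_inv_impl
        assert!((z - z2).abs() < 1e-9);
    }
}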
@ -1,4 +1,3 @@
pub mod erf;
pub mod kernels;

trait Cpu<const ARR: usize> {

@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "gather" })?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "gather" })?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "index-select" })?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
            Some((o1, o2)) => &src[o1..o2],
        };

@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "gather" })?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            None => Err(Error::RequiresContiguous { op: "index-add" })?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
        }
    }

@ -2603,10 +2603,6 @@ impl BackendDevice for CpuDevice {
        Ok(Self)
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

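The hunks above only drop the `.bt()` calls (which attach a backtrace to the error); the surrounding pattern is unchanged: indexing ops require their id/source buffers to be contiguous, and `contiguous_offsets` yields the `[start, end)` range when that holds. The same match-or-fail shape, as a self-contained sketch with hypothetical names:

fn contiguous_slice<T>(data: &[T], offsets: Option<(usize, usize)>) -> Result<&[T], String> {
    match offsets {
        Some((a, b)) => Ok(&data[a..b]),
        // In candle this is `Error::RequiresContiguous { op }`.
        None => Err("requires a contiguous layout".to_string()),
    }
}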
@ -223,14 +223,6 @@ impl BackendDevice for CudaDevice {
        })
    }

    fn set_seed(&self, seed: u64) -> Result<()> {
        // We do not call set_seed but instead create a new curand object. This ensures that the
        // state will be identical and the same random numbers will be generated.
        let mut curand = self.curand.lock().unwrap();
        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
        Ok(())
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),
@ -892,6 +884,8 @@ impl<'a> Map1 for IndexSelect<'a> {
        };
        let ids_shape = ids_l.shape();
        let ids_dims = ids_shape.dims();
        let ids_el = ids_shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(ids_el as u32);
        let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
        let src = match src_l.contiguous_offsets() {
            Some((o1, o2)) => src.slice(o1..o2),
@ -899,23 +893,19 @@ impl<'a> Map1 for IndexSelect<'a> {
        };
        let left_size: usize = src_l.dims()[..self.2].iter().product();
        let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
        let src_dim_size = src_l.dims()[self.2];
        let ids_dim_size = ids_shape.elem_count();
        let dst_el = ids_shape.elem_count() * left_size * right_size;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let dim_size = src_l.dims()[self.2];
        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
        let params = (
            dst_el,
            ids_el,
            ids_dims.len(),
            &ds,
            ids,
            &src,
            &out,
            left_size,
            src_dim_size,
            ids_dim_size,
            dim_size,
            right_size,
        );
        // SAFETY: ffi.
@ -2171,7 +2161,7 @@ impl BackendStorage for CudaStorage {
        if src_l.is_contiguous() {
            dev.dtod_copy(&src, &mut dst).w()?
        } else {
            let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
            let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
            // SAFETY: Set later by running the kernel.
            let params = (el_count, dims.len(), &ds, &src, &mut dst);
            // SAFETY: ffi.
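The `set_seed` implementation above rebuilds the curand generator so that reseeding reproduces an identical random stream. A usage sketch at the `Device` level, assuming the public crate surface:

use candle_core::{Device, Result, Tensor};

fn reproducible_noise() -> Result<(Tensor, Tensor)> {
    let dev = Device::cuda_if_available(0)?;
    dev.set_seed(42)?;
    let a = Tensor::rand(0f32, 1f32, (2, 3), &dev)?;
    dev.set_seed(42)?; // same seed => same generator state => same values
    let b = Tensor::rand(0f32, 1f32, (2, 3), &dev)?;
    Ok((a, b))
}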
@ -34,9 +34,6 @@ pub(crate) fn launch_conv2d<
    params: &crate::conv::ParamsConv2D,
    dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
    use crate::conv::CudnnFwdAlgo as CandleAlgo;
    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;

    let device_id = dev.id();
    let cudnn = CUDNN.with(|cudnn| {
        if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@ -93,20 +90,7 @@ pub(crate) fn launch_conv2d<
        w: &w,
        y: &y,
    };
    let alg = match params.cudnn_fwd_algo {
        None => conv2d.pick_algorithm()?,
        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        Some(CandleAlgo::ImplicitPrecompGemm) => {
            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
        }
        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
    };
    let alg = conv2d.pick_algorithm()?;
    let workspace_size = conv2d.get_workspace_size(alg)?;
    let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
    unsafe {
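The removed match let callers pin a specific cudnn forward algorithm through `params.cudnn_fwd_algo`, falling back to `pick_algorithm()` when unset; after this change the autotuned pick is always used. The dispatch shape it implemented, reduced to a self-contained sketch:

#[derive(Clone, Copy)]
enum Algo {
    ImplicitGemm,
    Winograd,
}

fn choose(user_override: Option<Algo>, autotuned: Algo) -> Algo {
    match user_override {
        Some(a) => a,      // an explicit request wins
        None => autotuned, // otherwise use the benchmarked pick
    }
}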
@ -1,6 +1,6 @@
use crate::backend::BackendDevice;
use crate::cpu_backend::CpuDevice;
use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};

/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
@ -8,14 +8,12 @@ use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
    Cpu,
    Cuda { gpu_id: usize },
    Metal,
}

#[derive(Debug, Clone)]
pub enum Device {
    Cpu,
    Cuda(crate::CudaDevice),
    Metal(crate::MetalDevice),
}

pub trait NdArray {
@ -105,14 +103,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
impl<S: NdArray> NdArray for Vec<S> {
    fn shape(&self) -> Result<Shape> {
        if self.is_empty() {
            bail!("empty array")
            crate::bail!("empty array")
        }
        let shape0 = self[0].shape()?;
        let n = self.len();
        for v in self.iter() {
            let shape = v.shape()?;
            if shape != shape0 {
                bail!("two elements have different shapes {shape:?} {shape0:?}")
                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
            }
        }
        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@ -130,18 +128,6 @@ impl Device {
        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
    }

    pub fn new_metal(ordinal: usize) -> Result<Self> {
        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
    }

    pub fn set_seed(&self, seed: u64) -> Result<()> {
        match self {
            Self::Cpu => CpuDevice.set_seed(seed),
            Self::Cuda(c) => c.set_seed(seed),
            Self::Metal(m) => m.set_seed(seed),
        }
    }

    pub fn same_device(&self, rhs: &Self) -> bool {
        match (self, rhs) {
            (Self::Cpu, Self::Cpu) => true,
@ -154,16 +140,21 @@ impl Device {
        match self {
            Self::Cpu => DeviceLocation::Cpu,
            Self::Cuda(device) => device.location(),
            Device::Metal(device) => device.location(),
        }
    }

    pub fn is_cpu(&self) -> bool {
        matches!(self, Self::Cpu)
        match self {
            Self::Cpu => true,
            Self::Cuda(_) => false,
        }
    }

    pub fn is_cuda(&self) -> bool {
        matches!(self, Self::Cuda(_))
        match self {
            Self::Cpu => false,
            Self::Cuda(_) => true,
        }
    }

    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -190,11 +181,6 @@ impl Device {
                let storage = device.rand_uniform(shape, dtype, lo, up)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(_device) => {
                // let storage = device.rand_uniform(shape, dtype, lo, up)?;
                // Ok(Storage::Metal(storage))
                bail!("Metal rand_uniform not implemented")
            }
        }
    }

@ -223,10 +209,6 @@ impl Device {
                let storage = device.rand_normal(shape, dtype, mean, std)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = device.rand_normal(shape, dtype, mean, std)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -249,10 +231,6 @@ impl Device {
                let storage = device.ones_impl(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = device.ones_impl(shape, dtype)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -266,10 +244,6 @@ impl Device {
                let storage = device.zeros_impl(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = device.zeros_impl(shape, dtype)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -281,11 +255,6 @@ impl Device {
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = array.to_cpu_storage();
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

@ -297,11 +266,6 @@ impl Device {
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = S::to_cpu_storage_owned(data);
                let storage = device.storage_from_cpu_storage(&storage)?;
                Ok(Storage::Metal(storage))
            }
        }
    }
}
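A short usage sketch of the `Device` API touched above, assuming the public crate surface:

use candle_core::{Device, Result};

fn pick_device() -> Result<Device> {
    // Falls back to the CPU when no cuda device is available.
    let device = Device::cuda_if_available(0)?;
    if device.is_cuda() {
        println!("running on cuda:0");
    }
    Ok(device)
}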
@ -14,7 +14,6 @@ impl Tensor {
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            _ => todo!(),
        };

        write!(f, "Tensor[")?;
@ -477,7 +476,6 @@ impl std::fmt::Display for Tensor {
            crate::DeviceLocation::Cuda { gpu_id } => {
                format!(", cuda:{}", gpu_id)
            }
            crate::DeviceLocation::Metal => todo!(),
        };

        write!(
@ -67,20 +67,6 @@ impl DType {
            Self::F64 => 8,
        }
    }

    pub fn is_int(&self) -> bool {
        match self {
            Self::U8 | Self::U32 | Self::I64 => true,
            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
        }
    }

    pub fn is_float(&self) -> bool {
        match self {
            Self::U8 | Self::U32 | Self::I64 => false,
            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
        }
    }
}

pub trait WithDType:
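`is_int` and `is_float` partition the dtype set, which makes dtype-dependent dispatch a one-liner. A sketch:

use candle_core::DType;

// Pick a wide accumulator for a reduction based on the input dtype.
fn accumulator_dtype(dtype: DType) -> DType {
    if dtype.is_int() {
        DType::I64 // integer inputs accumulate exactly
    } else {
        DType::F32 // f16/bf16 inputs get a wider accumulator
    }
}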
@ -167,10 +167,6 @@ impl crate::backend::BackendDevice for CudaDevice {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn set_seed(&self, _: u64) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn location(&self) -> crate::DeviceLocation {
        fail!()
    }
@ -1,201 +0,0 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

#[derive(Debug, Clone)]
pub struct MetalDevice;

#[derive(Debug)]
pub struct MetalStorage;

macro_rules! fail {
    () => {
        unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
    };
}

impl crate::backend::BackendStorage for MetalStorage {
    type Device = MetalDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn dtype(&self) -> DType {
        fail!()
    }

    fn device(&self) -> &Self::Device {
        fail!()
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv1d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv2d(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn matmul(
        &self,
        _: &Self,
        _: (usize, usize, usize, usize),
        _: &Layout,
        _: &Layout,
    ) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}

impl crate::backend::BackendDevice for MetalDevice {
    type Storage = MetalStorage;
    fn new(_: usize) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn set_seed(&self, _: u64) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn location(&self) -> crate::DeviceLocation {
        fail!()
    }

    fn same_device(&self, _: &Self) -> bool {
        fail!()
    }

    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}
@ -1,4 +1,4 @@
use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};
use crate::{DType, DeviceLocation, Layout, Shape};

#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -142,9 +142,6 @@ pub enum Error {
    #[error("{op} expects at least one tensor")]
    OpRequiresAtLeastOneTensor { op: &'static str },

    #[error("{op} expects at least two tensors")]
    OpRequiresAtLeastTwoTensors { op: &'static str },

    #[error("backward is not supported for {op}")]
    BackwardNotSupported { op: &'static str },

@ -152,9 +149,6 @@ pub enum Error {
    #[error("the candle crate has not been built with cuda support")]
    NotCompiledWithCudaSupport,

    #[error("the candle crate has not been built with metal support")]
    NotCompiledWithMetalSupport,

    #[error("cannot find tensor {path}")]
    CannotFindTensor { path: String },

@ -162,9 +156,6 @@ pub enum Error {
    #[error(transparent)]
    Cuda(Box<dyn std::error::Error + Send + Sync>),

    #[error("Metal error {0}")]
    Metal(#[from] metal_backend::MetalError),

    #[error(transparent)]
    TryFromIntError(#[from] core::num::TryFromIntError),

@ -52,10 +52,6 @@ mod dummy_cuda_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
@ -91,12 +87,6 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};

#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalStorage};

#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalStorage};

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

@ -135,15 +125,3 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
        self(xs)
    }
}

// A trait defining a module with a forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

impl<M: Module> ModuleT for M {
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
        self.forward(xs)
    }
}
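Because of the blanket impls above, any `Fn(&Tensor) -> Result<Tensor>` closure is a `Module`, and every `Module` is a `ModuleT` that ignores the `train` flag. A sketch, assuming the public crate surface:

use candle_core::{Module, Result, Tensor};

fn apply_twice<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
    let ys = m.forward(xs)?;
    m.forward(&ys)
}
// e.g. `apply_twice(&|t: &Tensor| t.relu(), &xs)` works without defining a struct.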
@ -1,806 +0,0 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, Source};
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
use std::sync::Arc;
use tracing::debug;

/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
    #[error("{0}")]
    Message(String),
    #[error(transparent)]
    KernelError(#[from] candle_metal_kernels::MetalKernelError),
}

impl From<String> for MetalError {
    fn from(e: String) -> Self {
        MetalError::Message(e)
    }
}

impl MetalError {
    fn msg<S: AsRef<str>>(msg: S) -> Self {
        MetalError::Message(msg.as_ref().to_string())
    }
}

#[derive(Clone)]
pub struct MetalDevice {
    device: metal::Device,
    command_queue: metal::CommandQueue,
    kernels: Arc<candle_metal_kernels::Kernels>,
}

impl std::fmt::Debug for MetalDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetalDevice({:?})", self.device.registry_id())
    }
}

impl std::ops::Deref for MetalDevice {
    type Target = metal::DeviceRef;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl MetalDevice {
    // pub fn metal_device(&self) -> &metal::DeviceRef {
    //     self.device.as_ref()
    // }

    pub fn id(&self) -> u64 {
        self.registry_id()
    }

    fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
        let size = (element_count * dtype.size_in_bytes()) as u64;
        // debug!("Allocate 1 - buffer size {size}");
        self.device
            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
    }
}

#[derive(Debug, Clone)]
pub struct MetalStorage {
    buffer: metal::Buffer,
    device: MetalDevice,
    dtype: DType,
}

impl BackendStorage for MetalStorage {
    type Device = MetalDevice;

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn dtype(&self) -> DType {
        self.dtype
    }

    fn device(&self) -> &Self::Device {
        &self.device
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        match self.dtype {
            DType::F32 => Ok(CpuStorage::F32(
                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
            )),
            dtype => todo!("Unsupported dtype {dtype:?}"),
        }
    }

    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        let device = self.device().clone();

        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let dtype = self.dtype;

        assert!(layout.is_contiguous());
        assert_eq!(dtype, DType::F32);

        let mut buffer = device.new_buffer(el, self.dtype);
        let command_buffer = self.device.command_queue.new_command_buffer();
        candle_metal_kernels::call_affine(
            &device.device,
            &command_buffer,
            &device.kernels,
            el,
            &self.buffer,
            &mut buffer,
            mul as f32,
            add as f32,
        )
        .unwrap();
        command_buffer.commit();
        return Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        });
    }

    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
        todo!()
    }

    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        todo!()
    }

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        debug!("TODO reduce_op {op:?}");
        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
        let src_el: usize = src_dims.iter().product();
        // Source dims and strides with the sum dims at the end.
        let mut dims = vec![];
        let mut stride = vec![];
        let mut dst_el: usize = 1;
        for (dim_idx, &d) in src_dims.iter().enumerate() {
            if !sum_dims.contains(&dim_idx) {
                dst_el *= d;
                dims.push(d);
                stride.push(src_stride[dim_idx]);
            }
        }
        for &dim_idx in sum_dims.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }
        // let el_to_sum_per_block = src_el / dst_el;
        // // The reduction loop requires the shared array to be properly initialized and for
        // // this we want the number of threads to be a power of two.
        // let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
        // let cfg = LaunchConfig {
        //     // TODO: Maybe use grid_y if the output is too large?
        //     // TODO: Specialized implementation when reducing on no or all dimensions or when
        //     // reducing only aggregate a small number of elements together.
        //     grid_dim: (dst_el as u32, 1, 1),
        //     block_dim: (block_dim as u32, 1, 1),
        //     shared_mem_bytes: 0,
        // };
        // let ds = dev
        //     .htod_copy([dims.as_slice(), stride.as_slice()].concat())
        //     .w()?;
        // let src = &src.slice(layout.start_offset()..);
        // let (name, check_empty, return_index) = match self.1 {
        //     ReduceOp::Sum => ("fast_sum", false, false),
        //     ReduceOp::Min => ("fast_min", true, false),
        //     ReduceOp::Max => ("fast_max", true, false),
        //     ReduceOp::ArgMin => ("fast_argmin", true, true),
        //     ReduceOp::ArgMax => ("fast_argmax", true, true),
        // };
        // if check_empty && layout.shape().elem_count() == 0 {
        //     Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
        // }
        // let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
        // if return_index {
        //     // SAFETY: filled in by the follow up kernel.
        //     let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
        //     // SAFETY: ffi.
        //     unsafe { func.launch(cfg, params) }.w()?;
        //     Ok(S::U32(out))
        // } else {
        //     // SAFETY: filled in by the follow up kernel.
        //     let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
        //     // SAFETY: ffi.
        //     unsafe { func.launch(cfg, params) }.w()?;
        //     Ok(wrap(out))
        // }
        // Ok(self.clone())
        // todo!()
        let dtype = self.dtype;
        let device = self.device();
        let buffer = device.new_buffer(dst_el, dtype);
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
        todo!()
    }

    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        let device = self.device();
        let shape = layout.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let mut buffer = device.new_buffer(el_count, dtype);
        let command_buffer = device.command_queue.new_command_buffer();
        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;

            let kernel_name = match (self.dtype, dtype) {
                (DType::U32, DType::F32) => "cast_u32_f32",
                (left, right) => todo!("to dtype {left:?} - {right:?}"),
            };
            candle_metal_kernels::call_cast_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                &self.buffer,
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            todo!(
                "TODO Implement the kernel calling cast {:?}-{:?}",
                self.dtype,
                dtype
            );
        }

        let start = std::time::Instant::now();
        command_buffer.commit();
        // command_buffer.wait_until_scheduled();
        debug!(
            "cast {:?} - {:?} - {:?} - {:?}",
            dtype,
            start.elapsed(),
            self.buffer.length(),
            buffer.length()
        );
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        let device = self.device();
        let dtype = self.dtype;
        let shape = layout.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let mut buffer = device.new_buffer(el_count, dtype);
        // TODO remove
        // return Ok(Self {
        //     buffer,
        //     device: device.clone(),
        //     dtype,
        // });
        let command_buffer = device.command_queue.new_command_buffer();
        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;

            let kernel_name = match (B::KERNEL, dtype) {
                ("ucos", DType::F32) => contiguous::cos::FLOAT,
                ("usin", DType::F32) => contiguous::sin::FLOAT,
                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
                ("uneg", DType::F32) => contiguous::neg::FLOAT,
                ("uexp", DType::F32) => contiguous::exp::FLOAT,
                (name, dtype) => todo!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_unary_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                &self.buffer,
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            todo!("TODO Implement the kernel calling {}", B::KERNEL);
        }

        let start = std::time::Instant::now();
        command_buffer.commit();
        // command_buffer.wait_until_scheduled();
        debug!(
            "Unary {:?} - {:?} - {:?} - {:?}",
            B::KERNEL,
            start.elapsed(),
            self.buffer.length(),
            buffer.length()
        );

        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn binary_impl<B: BinaryOpT>(
        &self,
        rhs: &Self,
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let device = self.device();
        let dtype = self.dtype;
        let shape = lhs_l.shape();
        let dims = shape.dims();
        let el_count = shape.elem_count();
        let mut buffer = device.new_buffer(el_count, dtype);
        let command_buffer = device.command_queue.new_command_buffer();
        if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
            use candle_metal_kernels::binary::contiguous;

            let kernel_name = match (B::KERNEL, dtype) {
                ("add", DType::F32) => contiguous::add::FLOAT,
                ("badd", DType::F32) => contiguous::add::FLOAT,
                ("sub", DType::F32) => contiguous::sub::FLOAT,
                ("bsub", DType::F32) => contiguous::sub::FLOAT,
                ("mul", DType::F32) => contiguous::mul::FLOAT,
                ("bmul", DType::F32) => contiguous::mul::FLOAT,
                ("div", DType::F32) => contiguous::div::FLOAT,
                ("bdiv", DType::F32) => contiguous::div::FLOAT,
                (name, dtype) => todo!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_binary_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                el_count,
                &self.buffer,
                &rhs.buffer,
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            use candle_metal_kernels::binary::strided;

            let kernel_name = match (B::KERNEL, dtype) {
                ("badd", DType::F32) => strided::add::FLOAT,
                ("bsub", DType::F32) => strided::sub::FLOAT,
                ("bmul", DType::F32) => strided::mul::FLOAT,
                ("bdiv", DType::F32) => strided::div::FLOAT,
                (name, dtype) => todo!("Match {name} - {dtype:?}"),
            };
            candle_metal_kernels::call_binary_strided(
                &device.device,
                &command_buffer,
                &device.kernels,
                kernel_name,
                lhs_l.dims(),
                &self.buffer,
                &lhs_l.stride(),
                lhs_l.start_offset(),
                &rhs.buffer,
                &rhs_l.stride(),
                rhs_l.start_offset(),
                &mut buffer,
            )
            .map_err(MetalError::from)?;
        }

        let start = std::time::Instant::now();
        command_buffer.commit();
        // command_buffer.wait_until_scheduled();
        debug!(
            "Binary {:?} - {:?} - {:?} - {:?}",
            B::KERNEL,
            start.elapsed(),
            self.buffer.length(),
            buffer.length()
        );

        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
        debug!("TODO where_cond");
        Ok(rhs.clone())
        // todo!()
    }

    fn conv1d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &ParamsConv1D,
    ) -> Result<Self> {
        todo!()
    }

    fn conv2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &ParamsConv2D,
    ) -> Result<Self> {
        todo!()
    }

    fn conv_transpose2d(
        &self,
        _l: &Layout,
        _kernel: &Self,
        _kernel_l: &Layout,
        _params: &ParamsConvTranspose2D,
    ) -> Result<Self> {
        todo!()
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        todo!()
    }

    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        todo!()
    }

    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
        todo!()
    }

    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
        todo!()
    }

    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
        todo!()
    }

    fn scatter_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        todo!()
    }

    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        debug!(
            "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
            self.buffer.length(),
            ids.buffer.length(),
        );
        let src = self;
        let ids_shape = ids_l.shape();
        let ids_dims = ids_shape.dims();
        // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
        // let src = match src_l.contiguous_offsets() {
        //     Some((o1, o2)) => src.slice(o1..o2),
        //     None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
        // };
        let left_size: usize = src_l.dims()[..dim].iter().product();
        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
        let src_dim_size = src_l.dims()[dim];
        let ids_dim_size = ids_shape.elem_count();
        let dst_el = ids_shape.elem_count() * left_size * right_size;
        let dtype = self.dtype;
        let device = self.device();
        let buffer = device.new_buffer(dst_el, dtype);
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
        // todo!()
    }

    fn index_add(
        &self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: usize,
    ) -> Result<Self> {
        todo!()
    }

    fn matmul(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let transpose_left = false;
        let transpose_right = false;
        let alpha = 1.0;
        let beta = 0.0;
        self.matmul_generic(
            rhs,
            (b, m, n, k),
            lhs_l,
            rhs_l,
            transpose_left,
            transpose_right,
            alpha,
            beta,
        )
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        let src_shape = src_l.shape();
        let dims = src_shape.dims();
        let el_count = src_shape.elem_count();
        if el_count == 0 {
            return Ok(());
        }
        if src_l.is_contiguous() {
            let command_buffer = self.device.command_queue.new_command_buffer();
            let blip = command_buffer.new_blit_command_encoder();
            blip.copy_from_buffer(
                &self.buffer,
                src_l.start_offset() as u64,
                &dst.buffer,
                dst_offset as u64,
                self.buffer.length(),
            );
        } else {
            let command_buffer = self.device.command_queue.new_command_buffer();
            let kernel_name = match self.dtype {
                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
                dtype => todo!("copy_strided not implemented for {dtype:?}"),
            };
            candle_metal_kernels::call_unary_strided(
                &self.device.device,
                &command_buffer,
                &self.device.kernels,
                kernel_name,
                src_l.dims(),
                &self.buffer,
                &src_l.stride(),
                src_l.start_offset(),
                &mut dst.buffer,
                dst_offset,
            )
            .map_err(MetalError::from)?;
            command_buffer.commit();
        }
        Ok(())
    }
}

impl MetalStorage {
    pub(crate) fn matmul_t(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        let transpose_left = false;
        let transpose_right = true;
        let alpha = 1.0;
        let beta = 0.0;
        self.matmul_generic(
            rhs,
            (b, m, n, k),
            lhs_l,
            rhs_l,
            transpose_left,
            transpose_right,
            alpha,
            beta,
        )
    }
    pub(crate) fn matmul_generic(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
        transpose_left: bool,
        transpose_right: bool,
        alpha: f64,
        beta: f64,
    ) -> Result<Self> {
        let elem_count = b * m * n;
        match (self.dtype, rhs.dtype) {
            (DType::F32, DType::F32) => {
                let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
                if b != 1 {
                    debug!("TODO implement batched matmul for B={b}");
                    // bail!("Didn't implemented strided matmul yet");
                    return Ok(Self {
                        buffer: out_buffer,
                        device: self.device.clone(),
                        dtype: self.dtype(),
                    });
                }
                if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
                    debug!(
                        "TODO non contiguous matmul yet {:?} {:?}",
                        lhs_l.is_contiguous(),
                        rhs_l.is_contiguous()
                    );
                    return Ok(Self {
                        buffer: out_buffer,
                        device: self.device.clone(),
                        dtype: self.dtype(),
                    });
                }

                debug!("GEMM");
                let command_buffer = self.device.command_queue.new_command_buffer();
                encode_gemm::<Float32, Float32, Float32>(
                    &self.device,
                    &command_buffer,
                    transpose_left,
                    transpose_right,
                    &self.buffer,
                    &rhs.buffer,
                    &mut out_buffer,
                    m as NSUInteger,
                    n as NSUInteger,
                    k as NSUInteger,
                    alpha as f32,
                    beta as f32,
                    Some(b as NSUInteger),
                )
                .map_err(MetalError::from)?;

                command_buffer.commit();
                // command_buffer.wait_until_scheduled();

                Ok(Self {
                    buffer: out_buffer,
                    device: self.device.clone(),
                    dtype: self.dtype(),
                })
            }
            _ => todo!("Unimplemented matmul for this pair"),
        }
    }
}

impl BackendDevice for MetalDevice {
    type Storage = MetalStorage;

    fn new(ordinal: usize) -> Result<Self> {
        let device = metal::Device::all().swap_remove(ordinal);

        // let capture = metal::CaptureManager::shared();
        // let descriptor = metal::CaptureDescriptor::new();
        // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
        // descriptor.set_capture_device(&device);
        // let mut dir = std::env::current_dir()?;
        // dir.push("out.gputrace");
        // descriptor.set_output_url(dir);

        // capture
        //     .start_capture(&descriptor)
        //     .map_err(MetalError::from)?;
        let command_queue = device.new_command_queue();
        // let command_buffer = _command_queue.new_owned_command_buffer();
        let kernels = Arc::new(Kernels::new());
        Ok(Self {
            device,
            command_queue,
            // command_buffer,
            kernels,
        })
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        todo!("set_seed")
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Metal
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.device.registry_id() == rhs.device.registry_id()
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
        // TODO Is there a faster way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        // TODO Is there a faster way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
        let option = metal::MTLResourceOptions::StorageModeManaged;
        let buffer = match storage {
            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<u8>()) as u64,
                option,
            ),
            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<u32>()) as u64,
                option,
            ),
            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<i64>()) as u64,
                option,
            ),
            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<bf16>()) as u64,
                option,
            ),
            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<f16>()) as u64,
                option,
            ),
            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<f32>()) as u64,
                option,
            ),
            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
                storage.as_ptr() as *const core::ffi::c_void,
                (storage.len() * mem::size_of::<f64>()) as u64,
                option,
            ),
        };
        // debug!("Allocate 2 - buffer size {}", buffer.length());
        Ok(Self::Storage {
            buffer,
            device: self.clone(),
            dtype: storage.dtype(),
        })
    }

    fn rand_uniform(
        &self,
        shape: &Shape,
        dtype: DType,
        mean: f64,
        stddev: f64,
    ) -> Result<Self::Storage> {
        // TODO is there a better way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }

    fn rand_normal(
        &self,
        shape: &Shape,
        dtype: DType,
        mean: f64,
        stddev: f64,
    ) -> Result<Self::Storage> {
        // TODO is there a better way ?
        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
        self.storage_from_cpu_storage(&cpu_storage)
    }
}
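Several of the `BackendDevice` methods in this file (`zeros_impl`, `ones_impl`, `rand_uniform`, `rand_normal`) share one fallback pattern: build the storage on the CPU backend, then upload it with `storage_from_cpu_storage`. Note also that the uniform sampler's parameters are named `mean`/`stddev` here even though they receive the lower/upper bounds passed down from `Device`. The pattern in isolation, using the crate-internal names from the file above:

fn zeros_via_cpu(dev: &MetalDevice, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
    // Correct but slow: materialize on the host, then copy into a Metal buffer.
    let cpu = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
    dev.storage_from_cpu_storage(&cpu)
}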
@ -250,6 +250,8 @@ impl Tensor {
        if header.fortran_order {
            return Err(Error::Npy("fortran order not supported".to_string()));
        }
        let mut data: Vec<u8> = vec![];
        reader.read_to_end(&mut data)?;
        Self::from_reader(header.shape(), header.descr, &mut reader)
    }

@ -1,5 +1,5 @@
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
@ -58,13 +58,8 @@ pub enum UnaryOp {
|
||||
Sqr,
|
||||
Sqrt,
|
||||
Gelu,
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
Round,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -136,7 +131,6 @@ pub enum Op {
|
||||
Copy(Tensor),
|
||||
Broadcast(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
SliceScatter0(Tensor, Tensor, usize),
|
||||
Reshape(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
@ -174,18 +168,6 @@ pub trait CustomOp1 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
@@ -221,20 +203,6 @@ pub trait CustomOp2 {
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
@@ -277,22 +245,6 @@ pub trait CustomOp3 {
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
@@ -373,13 +325,8 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;

macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -578,6 +525,7 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@@ -673,210 +621,6 @@ impl UnaryOpT for Gelu {
    }
}

impl UnaryOpT for Erf {
    const NAME: &'static str = "erf";
    const KERNEL: &'static str = "uerf";
    const V: Self = Erf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        crate::cpu::erf::erf(v)
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}

impl UnaryOpT for Abs {
    const NAME: &'static str = "abs";
    const KERNEL: &'static str = "uabs";
    const V: Self = Abs;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.abs()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.abs()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.abs()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.abs()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v.abs()
    }
}

impl UnaryOpT for Ceil {
    const NAME: &'static str = "ceil";
    const KERNEL: &'static str = "uceil";
    const V: Self = Ceil;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.ceil()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.ceil()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.ceil()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.ceil()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for Floor {
    const NAME: &'static str = "floor";
    const KERNEL: &'static str = "ufloor";
    const V: Self = Floor;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.floor()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.floor()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.floor()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.floor()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for Round {
    const NAME: &'static str = "round";
    const KERNEL: &'static str = "uround";
    const V: Self = Round;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.round()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.round()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.round()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.round()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for GeluErf {
    const NAME: &'static str = "gelu_erf";
    const KERNEL: &'static str = "ugelu_erf";
    const V: Self = GeluErf;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from_f64(Self::f64(v.to_f64()))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        Self::f64(v as f64) as f32
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
}

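// A quick sanity check of the gelu_erf formula above, written as a standalone
// sketch: gelu_erf(v) = 0.5 * v * (1 + erf(v / sqrt(2))), so gelu_erf(0) = 0
// and gelu_erf(v) approaches v for large positive v, matching the usual GELU
// limits.
fn gelu_erf_ref(v: f64) -> f64 {
    (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
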
impl UnaryOpT for Relu {
    const NAME: &'static str = "relu";
    const KERNEL: &'static str = "urelu";

@@ -193,50 +193,6 @@ impl Object {
            _ => Err(self),
        }
    }

    pub fn into_tensor_info(
        self,
        name: Self,
        dir_name: &std::path::Path,
    ) -> Result<Option<TensorInfo>> {
        let name = match name.unicode() {
            Ok(name) => name,
            Err(_) => return Ok(None),
        };
        let (callable, args) = match self.reduce() {
            Ok(callable_args) => callable_args,
            _ => return Ok(None),
        };
        let (callable, args) = match callable {
            Object::Class {
                module_name,
                class_name,
            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
                let mut args = args.tuple()?;
                let callable = args.remove(0);
                let args = args.remove(1);
                (callable, args)
            }
            _ => (callable, args),
        };
        match callable {
            Object::Class {
                module_name,
                class_name,
            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
            _ => return Ok(None),
        };
        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
        let mut path = dir_name.to_path_buf();
        path.push(file_path);
        Ok(Some(TensorInfo {
            name,
            dtype,
            layout,
            path: path.to_string_lossy().into_owned(),
            storage_size,
        }))
    }
}

impl TryFrom<Object> for String {
@@ -609,7 +565,6 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
        "HalfStorage" => DType::F16,
        "BFloat16Storage" => DType::BF16,
        "ByteStorage" => DType::U8,
        "LongStorage" => DType::I64,
        other => {
            crate::bail!("unsupported storage type {other}")
        }
@@ -668,10 +623,50 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
    };
    if let Object::Dict(key_values) = obj {
        for (name, value) in key_values.into_iter() {
            match value.into_tensor_info(name, &dir_name) {
                Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
                Ok(None) => {}
                Err(err) => eprintln!("skipping: {err:?}"),
            let name = match name.unicode() {
                Ok(name) => name,
                Err(_) => continue,
            };
            let (callable, args) = match value.reduce() {
                Ok(callable_args) => callable_args,
                _ => continue,
            };
            let (callable, args) = match callable {
                Object::Class {
                    module_name,
                    class_name,
                } if module_name == "torch._tensor"
                    && class_name == "_rebuild_from_type_v2" =>
                {
                    let mut args = args.tuple()?;
                    let callable = args.remove(0);
                    let args = args.remove(1);
                    (callable, args)
                }
                _ => (callable, args),
            };
            match callable {
                Object::Class {
                    module_name,
                    class_name,
                } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
                _ => continue,
            };
            match rebuild_args(args) {
                Ok((layout, dtype, file_path, storage_size)) => {
                    let mut path = dir_name.clone();
                    path.push(file_path);
                    tensor_infos.push(TensorInfo {
                        name,
                        dtype,
                        layout,
                        path: path.to_string_lossy().into_owned(),
                        storage_size,
                    })
                }
                Err(err) => {
                    eprintln!("skipping {name}: {err:?}")
                }
            }
        }
    }
@@ -728,16 +723,3 @@ impl PthTensors {
        Ok(Some(tensor))
    }
}

/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    let pth = PthTensors::new(path)?;
    let tensor_names = pth.tensor_infos.keys();
    let mut tensors = Vec::with_capacity(tensor_names.len());
    for name in tensor_names {
        if let Some(tensor) = pth.get(name)? {
            tensors.push((name.to_string(), tensor))
        }
    }
    Ok(tensors)
}

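// A minimal usage sketch of read_all above; the file name is hypothetical and
// this assumes the function is called from within the same pickle module.
fn list_pth_tensors() -> Result<()> {
    let tensors = read_all("model.pth")?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}
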
@@ -50,9 +50,14 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -633,35 +638,3 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
        Ok(hsum_float_8(acc) + summs)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % qk != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let mut sumi = _mm256_setzero_si256();
            let x_qs = xs.qs.as_ptr();
            let y_qs = ys.qs.as_ptr();
            for j in (0..QK_K).step_by(32) {
                let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
                let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);

                let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
                let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));

                let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
                let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
            }
            let d = _mm256_set1_ps(xs.d * ys.d);
            acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
        }
        Ok(hsum_float_8(acc))
    }
}

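// A scalar sketch of what the AVX2 loop above computes for each pair of Q8_K
// blocks: a QK_K-term i8 dot product scaled by the two per-block f32 scales.
fn vec_dot_q8k_q8k_scalar(xs: &[BlockQ8K], ys: &[BlockQ8K]) -> f32 {
    let mut acc = 0f32;
    for (x, y) in xs.iter().zip(ys.iter()) {
        let sumi: i32 = x
            .qs
            .iter()
            .zip(y.qs.iter())
            .map(|(&a, &b)| a as i32 * b as i32)
            .sum();
        acc += sumi as f32 * x.d * y.d
    }
    acc
}
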
@@ -1,7 +1,7 @@
//! Support for the GGML file format.

use super::{k_quants, GgmlDType};
use crate::{Device, Result};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;

@@ -121,12 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
    raw_data: &[u8],
    size_in_bytes: usize,
    dims: Vec<usize>,
    device: &Device,
) -> Result<super::QTensor> {
    let raw_data_ptr = raw_data.as_ptr();
    let n_blocks = size_in_bytes / std::mem::size_of::<T>();
    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
    super::QTensor::new(data.to_vec(), dims, device)
    super::QTensor::new(data.to_vec(), dims)
}

/// Creates a [Tensor] from a raw GGML tensor.
@@ -134,50 +133,23 @@ pub fn qtensor_from_ggml(
    ggml_dtype: GgmlDType,
    raw_data: &[u8],
    dims: Vec<usize>,
    device: &Device,
) -> Result<super::QTensor> {
    let tensor_elems = dims.iter().product::<usize>();
    let blck_size = ggml_dtype.blck_size();
    if tensor_elems % blck_size != 0 {
        crate::bail!(
            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
        )
    }
    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();

    match ggml_dtype {
        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
        GgmlDType::Q4_0 => {
            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q4_1 => {
            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q5_0 => {
            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q5_1 => {
            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q8_0 => {
            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q2K => {
            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q3K => {
            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q4K => {
            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q5K => {
            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::Q6K => {
            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
        }
        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
    }
}
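
// A worked example of the size computation above for Q4_0, a sketch using the
// block constants from k_quants (32 elements per block, 18 bytes per block):
// a 4096 x 4096 tensor has 16_777_216 elements, so
// size_in_bytes = 16_777_216 / 32 * 18 = 9_437_184 bytes,
// i.e. 4.5 bits per weight versus 32 bits for plain f32 storage.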
@@ -185,7 +157,6 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
    reader: &mut R,
    magic: VersionedMagic,
    device: &Device,
) -> Result<(String, super::QTensor)> {
    let n_dims = reader.read_u32::<LittleEndian>()?;
    let name_len = reader.read_u32::<LittleEndian>()?;
@@ -210,7 +181,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
    // TODO: Mmap version to avoid copying the data around?
    let mut raw_data = vec![0u8; size_in_bytes];
    reader.read_exact(&mut raw_data)?;
    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
        Ok(tensor) => Ok((name, tensor)),
        Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
    }
@@ -224,10 +195,7 @@ pub struct Content {
}

impl Content {
    pub fn read<R: std::io::Seek + std::io::Read>(
        reader: &mut R,
        device: &Device,
    ) -> Result<Content> {
    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
        reader.seek(std::io::SeekFrom::Start(0))?;
@@ -237,7 +205,7 @@ impl Content {
        let mut tensors = HashMap::new();

        while reader.stream_position()? != last_position {
            let (name, tensor) = read_one_tensor(reader, magic, device)?;
            let (name, tensor) = read_one_tensor(reader, magic)?;
            tensors.insert(name, tensor);
        }
        Ok(Self {
@@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;

@@ -57,25 +57,14 @@ impl TensorInfo {
        &self,
        reader: &mut R,
        tensor_data_offset: u64,
        device: &Device,
    ) -> Result<QTensor> {
        let tensor_elems = self.shape.elem_count();
        let blck_size = self.ggml_dtype.blck_size();
        if tensor_elems % blck_size != 0 {
            crate::bail!(
                "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
            )
        }
        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
        let size_in_bytes =
            tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
        let mut raw_data = vec![0u8; size_in_bytes];
        reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
        reader.read_exact(&mut raw_data)?;
        super::ggml_file::qtensor_from_ggml(
            self.ggml_dtype,
            &raw_data,
            self.shape.dims().to_vec(),
            device,
        )
        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
    }
}

@@ -456,13 +445,12 @@ impl Content {
        &self,
        reader: &mut R,
        name: &str,
        device: &Device,
    ) -> Result<QTensor> {
        let tensor_info = match self.tensor_infos.get(name) {
            Some(tensor_info) => tensor_info,
            None => crate::bail!("cannot find tensor-info for {name}"),
        };
        tensor_info.read(reader, self.tensor_data_offset, device)
        tensor_info.read(reader, self.tensor_data_offset)
    }
}

@@ -34,9 +34,6 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
    /// Dot product used as a building block for quantized mat-mul.
    /// n is the number of elements to be considered.
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;

    /// Generic implementation of the dot product without simd optimizations.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
}

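// A minimal sketch of the trait above in use: quantize 64 f32 values into two
// Q8_0 blocks (QK8_0 = 32 values per block) and dot them against themselves.
// Zero-initializing the blocks via mem::zeroed is an assumption that these
// block structs are plain data.
fn q8_0_self_dot() -> Result<f32> {
    let xs: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
    let mut blocks: Vec<BlockQ8_0> = vec![unsafe { std::mem::zeroed() }; 2];
    BlockQ8_0::from_float(&xs, &mut blocks)?;
    BlockQ8_0::vec_dot(64, &blocks, &blocks)
}
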
#[derive(Debug, Clone, PartialEq)]
@@ -228,17 +225,15 @@ impl GgmlType for BlockQ4_0 {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK8_0;
        let nb = n / qk;
        if n % QK8_0 != 0 {
            crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
        }
        if nb % 2 != 0 {
            crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
        }

        // Generic implementation.
        let mut sumf = 0f32;
        for (xs, ys) in xs.iter().zip(ys.iter()) {
@@ -260,10 +255,6 @@ impl GgmlType for BlockQ4_1 {
    type VecDotType = BlockQ8_1;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        // ggml_vec_dot_q4_1_q8_1
        let qk = QK8_1;
        if n % qk != 0 {
@@ -363,10 +354,7 @@ impl GgmlType for BlockQ5_0 {
        if nb % 2 != 0 {
            crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
        }
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        // Generic implementation.
        let mut sumf = 0f32;

@@ -457,10 +445,6 @@ impl GgmlType for BlockQ5_1 {
    type VecDotType = BlockQ8_1;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = Self::BLCK_SIZE;
        if n % Self::BLCK_SIZE != 0 {
            crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
@@ -622,13 +606,6 @@ impl GgmlType for BlockQ8_0 {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK8_0;
        if n % QK8_0 != 0 {
            crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
@@ -654,11 +631,7 @@ impl GgmlType for BlockQ8_1 {
    const BLCK_SIZE: usize = QK8_1;
    type VecDotType = BlockQ8_1;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
        unimplemented!("no support for vec-dot on Q8_1")
    }

@@ -708,13 +681,6 @@ impl GgmlType for BlockQ2K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q2k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -735,17 +701,18 @@ impl GgmlType for BlockQ2K {

        let mut isum = 0;
        let mut is = 0;
        let mut d;
        for _ in 0..(QK_K / 128) {
            let mut shift = 0;
            for _ in 0..4 {
                let d = (sc[is] & 0xF) as i32;
                d = (sc[is] & 0xF) as i32;
                is += 1;
                let mut isuml = 0;
                for l in 0..16 {
                    isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                }
                isum += d * isuml;
                let d = (sc[is] & 0xF) as i32;
                d = (sc[is] & 0xF) as i32;
                is += 1;
                isuml = 0;
                for l in 16..32 {
@@ -884,10 +851,6 @@ impl GgmlType for BlockQ3K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q3k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1114,6 +1077,7 @@ impl GgmlType for BlockQ3K {
        let d_all = block.d.to_f32();
        let mut m = 1;
        let mut is = 0;
        let mut dl;

        // Dequantize both 128 long blocks
        // 32 qs values per 128 long block
@@ -1124,7 +1088,7 @@ impl GgmlType for BlockQ3K {
                for (scale_index, scale_scoped_y) in
                    shift_scoped_y.chunks_exact_mut(16).enumerate()
                {
                    let dl = d_all * (scales[is] as f32 - 32.0);
                    dl = d_all * (scales[is] as f32 - 32.0);
                    for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
                        let new_y = dl
                            * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
@@ -1162,13 +1126,6 @@ impl GgmlType for BlockQ4K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1355,10 +1312,6 @@ impl GgmlType for BlockQ5K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q5k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1576,13 +1529,6 @@ impl GgmlType for BlockQ6K {
        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q6k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
        }
@@ -1751,38 +1697,8 @@ impl GgmlType for BlockQ8K {
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q8k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q8k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q8k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK_K;
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
        }

        // Generic implementation.
        let mut sumf = 0f32;
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let sum_i = xs
                .qs
                .iter()
                .zip(ys.qs.iter())
                .map(|(&x, &y)| x as i32 * y as i32)
                .sum::<i32>();
            sumf += sum_i as f32 * xs.d * ys.d
        }
        Ok(sumf)
    fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
        unreachable!()
    }

    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
@@ -1888,10 +1804,6 @@ impl GgmlType for f32 {
    type VecDotType = f32;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if xs.len() < n {
            crate::bail!("size mismatch {} < {n}", xs.len())
        }
@@ -1926,10 +1838,6 @@ impl GgmlType for f16 {
    type VecDotType = f16;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if xs.len() < n {
            crate::bail!("size mismatch {} < {n}", xs.len())
        }
@@ -1,5 +1,4 @@
use crate::{Device, Result, Shape, Tensor};
use tracing::debug;

#[cfg(target_feature = "avx")]
pub mod avx;
@@ -8,14 +7,11 @@ pub mod gguf_file;
pub mod k_quants;
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;

pub use k_quants::GgmlType;

pub struct QTensor {
    device: Device,
    data: Box<dyn QuantizedType>,
    shape: Shape,
}
@@ -172,20 +168,17 @@ impl QTensor {
    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
        data: Vec<T>,
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let shape = shape.into();
        check_shape::<T>(&shape)?;
        Ok(Self {
            data: Box::new(data),
            shape,
            device: device.clone(),
        })
    }

    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
        let shape = src.shape();
        let device = src.device();
        check_shape::<T>(shape)?;
        let src = src
            .to_dtype(crate::DType::F32)?
@@ -202,7 +195,6 @@ impl QTensor {
        Ok(Self {
            data: Box::new(data),
            shape: shape.clone(),
            device: device.clone(),
        })
    }

@@ -218,12 +210,7 @@ impl QTensor {
        &self.shape
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        // TODO: Skip the CPU part on metal
        let mut f32_data = vec![0f32; self.shape.elem_count()];
        self.data.to_float(&mut f32_data)?;
        Tensor::from_vec(f32_data, &self.shape, device)
@@ -242,40 +229,20 @@ impl QTensor {
    }
}

#[derive(Clone, Debug)]
pub enum QMatMul {
    QTensor(std::sync::Arc<QTensor>),
    Tensor(Tensor),
}

thread_local! {
    static DEQUANTIZE_ALL: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}
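// The thread-local flag above is read once per thread: setting the
// CANDLE_DEQUANTIZE_ALL environment variable to any non-empty value other than
// "0" makes from_arc below dequantize every weight to a plain f32 tensor up
// front (f32/f16 ggml tensors are always dequantized), trading memory for a
// regular tensor matmul instead of the custom quantized op.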
#[derive(Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);

impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
            GgmlDType::F32 | GgmlDType::F16 => true,
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
            let tensor = qtensor.dequantize(&Device::Cpu)?;
            Self::Tensor(tensor)
        } else {
            Self::QTensor(qtensor)
        };
        Ok(t)
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
        Self(qtensor)
    }

    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    pub fn from_qtensor(qtensor: QTensor) -> Self {
        Self(std::sync::Arc::new(qtensor))
    }

    pub fn inner(&self) -> &std::sync::Arc<QTensor> {
        &self.0
    }
}

@@ -316,60 +283,10 @@ impl crate::CustomOp1 for QTensor {
        )?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }

    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        debug!("TODO qmatmul");
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        // let storage = storage.as_slice::<f32>()?;
        // let storage =
        //     &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let dst_storage = vec![0f32; dst_shape.elem_count()];
        // self.matmul_t(
        //     (dst_shape.elem_count() / n, k, n),
        //     storage,
        //     &mut dst_storage,
        // )?;
        let cpu_storage = crate::CpuStorage::F32(dst_storage);
        use crate::backend::{BackendDevice, BackendStorage};
        if let Device::Metal(device) = &self.device {
            Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
        } else {
            crate::bail!("qtensor not on metal device")
        }
    }
}

impl QMatMul {
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.matmul(&w)
            }
        }
        xs.apply_op1_no_bwd(self.0.as_ref())
    }
}
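// A minimal usage sketch of the forward paths above: quantize an f32 weight,
// wrap it, and apply it to a batch of activations. Depending on which side of
// this diff is built, from_qtensor returns Result<Self> or Self, so the `?`
// may need to be dropped.
fn qmatmul_example() -> Result<Tensor> {
    let weight = Tensor::randn(0f32, 1f32, (64, 32), &Device::Cpu)?;
    let q = QTensor::quantize::<k_quants::BlockQ4_0>(&weight)?;
    let mm = QMatMul::from_qtensor(q)?;
    let xs = Tensor::randn(0f32, 1f32, (4, 32), &Device::Cpu)?;
    mm.forward(&xs) // shape (4, 64): the weight acts as its transpose
}
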
@@ -19,29 +19,42 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());
            let v0_1 = vld1q_u8(x1.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);
            let v0_1ls = vsubq_s8(v0_1l, s8b);
            let v0_1hs = vsubq_s8(v0_1h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
            let v1_1l = vld1q_s8(y1.qs.as_ptr());
            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO: Support dotprod when it's available outside of nightly.
            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@@ -49,16 +62,28 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));

            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
}

@@ -69,18 +94,28 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        for i in 0..nb {
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
            let x1_0 = vld1q_s8(x1.qs.as_ptr());
            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
            let y1_0 = vld1q_s8(y1.qs.as_ptr());
            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO: dotprod once the intrinsics are available.
            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@@ -88,48 +123,31 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));

            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(p2, p3)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0))
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    let mut sumf = 0f32;
    for (xs, ys) in xs.iter().zip(ys.iter()) {
        unsafe {
            let mut sum_i = vdupq_n_s32(0);
            let scale = xs.d * ys.d;
            let xs = xs.qs.as_ptr();
            let ys = ys.qs.as_ptr();
            for i in (0..QK_K).step_by(16) {
                let xs = vld1q_s8(xs.add(i));
                let ys = vld1q_s8(ys.add(i));
                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));

                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
                sum_i = vaddq_s32(sum_i, xy)
            }
            sumf += vaddvq_s32(sum_i) as f32 * scale
        }
    }
    Ok(sumf)
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
@@ -1,419 +0,0 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;

use core::arch::wasm32::*;

#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
            let x12 = v128_and(x1234, u8x16_splat(0x0F));
            let x12 = i8x16_sub(x12, i8x16_splat(8));
            let x34 = u8x16_shr(x1234, 4);
            let x34 = i8x16_sub(x34, i8x16_splat(8));

            let x1 = i16x8_extend_low_i8x16(x12);
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_extend_high_i8x16(x12);
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_extend_low_i8x16(x34);
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_extend_high_i8x16(x34);
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
            let sum_xy = i32x4_dot_i16x8(x1, y1);

            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));

            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));

            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));

            let sum_xy = f32x4_convert_i32x4(sum_xy);

            // f32x4_relaxed_madd is nightly only.
            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
            let scaled = f32x4_mul(sum_xy, d);
            acc = f32x4_add(acc, scaled)
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
    }
    unsafe {
        let mut sumf = f32x4_splat(0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q2: &[_] = &x.qs;
            let mut q8: &[_] = &y.qs;
            let sc = &x.scales;

            let mut summs = i32x4_splat(0);
            for i in (0..(QK_K / 16)).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
                let scales = i32x4_shr(
                    i32x4(
                        sc[i] as i32,
                        sc[i + 1] as i32,
                        sc[i + 2] as i32,
                        sc[i + 3] as i32,
                    ),
                    4,
                );
                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
            }
            let summs = f32x4_convert_i32x4(summs);

            let dall = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let mut isum = i32x4_splat(0);
            let mut is = 0;
            for _ in 0..(QK_K / 128) {
                let mut shift = 0;
                for _ in 0..4 {
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (0..16).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = i16x8_splat(0);
                    for l in (16..32).step_by(8) {
                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
                    }
                    let dd = i32x4_splat(d);
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
                    shift += 2;
                    // adjust the indexing
                    q8 = &q8[32..];
                }
                // adjust the indexing
                q2 = &q2[32..];
            }
            let isum = f32x4_convert_i32x4(isum);
            sumf = f32x4_add(
                sumf,
                f32x4_sub(
                    f32x4_mul(isum, f32x4_splat(dall)),
                    f32x4_mul(summs, f32x4_splat(dmin)),
                ),
            );
        }
        let sumf = f32x4_extract_lane::<0>(sumf)
            + f32x4_extract_lane::<1>(sumf)
            + f32x4_extract_lane::<2>(sumf)
            + f32x4_extract_lane::<3>(sumf);
        Ok(sumf)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }

    const KMASK1: u32 = 0x3f3f3f3f;
    const KMASK2: u32 = 0x0f0f0f0f;
    const KMASK3: u32 = 0x03030303;

    let mut utmp: [u32; 4] = [0; 4];
    let mut scales: [u8; 8] = [0; 8];
    let mut mins: [u8; 8] = [0; 8];

    let mut aux8: [u8; QK_K] = [0; QK_K];
    let mut sums = f32x4_splat(0f32);
    unsafe {
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;

            for j in 0..QK_K / 64 {
                let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
                let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
                v128_store(
                    aux8.as_mut_ptr().add(64 * j) as *mut v128,
                    v128_and(q4_1, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
                    v128_and(q4_2, u8x16_splat(0x0F)),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
                    u8x16_shr(q4_1, 4),
                );
                v128_store(
                    aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
                    u8x16_shr(q4_2, 4),
                );
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            // extract scales and mins
            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K / 16).step_by(4) {
                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
                let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
                let mins = i32x4(m1, m1, m2, m2);
                sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
            }

            let mut aux32 = i32x4_splat(0i32);
            for (scale_i, scale) in scales.iter().enumerate() {
                let scale = i32x4_splat(*scale as i32);
                for j in 0..4 {
                    let i = 32 * scale_i + 8 * j;
                    let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
                    let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
                    let aux16 = i16x8_mul(q8, aux8);
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
                }
            }
            let aux32 = f32x4_convert_i32x4(aux32);
            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
            let dmin = x.dmin.to_f32() * y.d;
            let dmin = f32x4_splat(dmin);
            let sumi = f32x4_convert_i32x4(sumi);
            sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
    }

    let mut aux8 = [0i8; QK_K];
    unsafe {
        let mut sums = f32x4_splat(0f32);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let q4 = &x.ql;
            let qh = &x.qh;
            let q8 = &y.qs;
            let mut aux32 = f32x4_splat(0f32);

            for j in (0..QK_K).step_by(128) {
                let aux8 = aux8.as_mut_ptr().add(j);
                let q4 = &q4.as_ptr().add(j / 2);
                let qh = &qh.as_ptr().add(j / 4);
                for l in (0..32).step_by(16) {
                    // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 32] =
                    //     (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 32) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 64) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );

                    // aux8[l + 96] =
                    //     (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
                    let a8 = v128_or(
                        u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
                        u8x16_shl(
                            v128_and(
                                u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
                                u8x16_splat(3),
                            ),
                            4,
                        ),
                    );
                    let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
                    let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
                    v128_store(
                        aux8.add(l + 96) as *mut v128,
                        i8x16_narrow_i16x8(a8_low, a8_high),
                    );
                }
            }

            for (j, &scale) in x.scales.iter().enumerate() {
                let scale = f32x4_splat(scale as f32);
                for offset in [0, 8] {
                    let aux16 = i16x8_mul(
                        i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
                        i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
                    );
                    aux32 = f32x4_add(
                        aux32,
                        f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
                    );
                }
            }

            let d = f32x4_splat(x.d.to_f32() * y.d);
            sums = f32x4_add(sums, f32x4_mul(aux32, d));
        }
        let sums = f32x4_extract_lane::<0>(sums)
            + f32x4_extract_lane::<1>(sums)
            + f32x4_extract_lane::<2>(sums)
            + f32x4_extract_lane::<3>(sums);
        Ok(sums)
    }
}

#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
    let qk = QK_K;
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
    }

    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let x_qs = xs.qs.as_ptr();
            let y_qs = ys.qs.as_ptr();
            let mut sumi = i32x4_splat(0);
            for j in (0..QK_K).step_by(8) {
                let xs = i16x8_load_extend_i8x8(x_qs.add(j));
                let ys = i16x8_load_extend_i8x8(y_qs.add(j));
                let sum_xy = i32x4_dot_i16x8(xs, ys);
                sumi = i32x4_add(sumi, sum_xy)
            }
            let d = f32x4_splat(xs.d * ys.d);
            acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
        }
        let res = f32x4_extract_lane::<0>(acc)
            + f32x4_extract_lane::<1>(acc)
            + f32x4_extract_lane::<2>(acc)
            + f32x4_extract_lane::<3>(acc);
        Ok(res)
    }
}
@@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
    let expected_blocks = xs.len() / block_size;
    let actual_blocks = ys.len();

    // Validate that the input is the right size
    //validate that the input is the right size
    if expected_blocks != actual_blocks {
        crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
    }
@@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(

    let actual_output_len = ys.len();
    let expected_output_len = xs.len() * block_size;
    // Validate that the output is the right size
    //validate that the output is the right size
    if expected_output_len != actual_output_len {
        crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
    }

    // Zip the blocks and outputs together
    //zip the blocks and outputs together
    Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}

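// A sketch of the call pattern the helpers above support, with the pairing
// implied by the zip/chunks_exact_mut logic: each block lines up with one
// block_size-long f32 chunk, so per-block kernels can run independently.
// for (block, chunk) in group_for_dequantization(xs, ys)? {
//     // dequantize `block` into `chunk` here, one block at a time
// }
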
@@ -251,134 +251,6 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
    Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

    /// Creates a wrapper around multiple memory mapped files and deserializes the safetensors headers.
    ///
    /// If a tensor name appears in multiple files, the last entry is returned.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}
|
||||
pub struct BufferedSafetensors {
|
||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||
}
|
||||
|
||||
impl BufferedSafetensors {
|
||||
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||
pub fn new(buffer: Vec<u8>) -> Result<Self> {
|
||||
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
|
||||
buffer,
|
||||
|data: &[u8]| {
|
||||
let st = safetensors::SafeTensors::deserialize(data)?;
|
||||
Ok::<_, Error>(SafeTensors_(st))
|
||||
},
|
||||
)?;
|
||||
Ok(Self { safetensors })
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.get(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
self.safetensors.get().0.tensors()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
Ok(self.safetensors.get().0.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MmapedFile {
|
||||
path: std::path::PathBuf,
|
||||
inner: memmap2::Mmap,
|
||||
|
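A minimal usage sketch for the memory-mapped loader shown in the hunk above (hedged: the file and tensor names are hypothetical, and `new` is unsafe because the mmap contract is inherited from memmap2):

use candle_core::safetensors::MmapedSafetensors;
use candle_core::{Device, Result};

fn load_one_tensor() -> Result<()> {
    // Safety: the mapped file must not be modified while it is in use.
    let weights = unsafe { MmapedSafetensors::new("model.safetensors")? };
    let t = weights.load("some.tensor.name", &Device::Cpu)?;
    println!("loaded {:?}", t.dims());
    Ok(())
}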
@@ -203,7 +203,7 @@ impl Shape {

    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
    /// broadcasted shape. This is to be used for binary pointwise ops.
    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
@@ -511,119 +511,154 @@ impl ShapeWithOneHole for ((),) {
    }
}

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

impl ShapeWithOneHole for ((), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1) = self;
        Ok((hole_size(el_count, d1, &self)?, d1).into())
        if el_count % d1 != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
        }
        Ok((el_count / d1, d1).into())
    }
}

impl ShapeWithOneHole for (usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, ()) = self;
        Ok((d1, hole_size(el_count, d1, &self)?).into())
        if el_count % d1 != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
        }
        Ok((d1, el_count / d1).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2) = self;
        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2) = self;
        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2).into())
    }
}

impl ShapeWithOneHole for (usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, ()) = self;
        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
        let d = d1 * d2;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d, d1, d2, d3).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d, d2, d3).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d2, d, d3).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d, d3).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, ()) = self;
        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
        Ok((d1, d2, d3, d).into())
        let d = d1 * d2 * d3;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, el_count / d).into())
    }
}

impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let ((), d1, d2, d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d, d1, d2, d3, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((el_count / d, d1, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, (), d2, d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d, d2, d3, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, el_count / d, d2, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, (), d3, d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d, d3, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, el_count / d, d3, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, (), d4) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d3, d, d4).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, el_count / d, d4).into())
    }
}

impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, ()) = self;
        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
        Ok((d1, d2, d3, d4, d).into())
        let d = d1 * d2 * d3 * d4;
        if el_count % d != 0 {
            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
        }
        Ok((d1, d2, d3, d4, el_count / d).into())
    }
}
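These impls are what let `reshape` take a `()` hole for one dimension and infer it from the element count. A quick sketch of the behavior, grounded in the arange/reshape calls exercised by the tests later in this diff:

use candle_core::{Device, Result, Tensor};

fn reshape_with_hole() -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, &Device::Cpu)?;
    // The () dimension is inferred from the element count: 12 / 4 = 3.
    let t = t.reshape(((), 4))?;
    assert_eq!(t.dims(), &[3, 4]);
    // A non-divisible hole is an error on both sides of this diff.
    assert!(t.reshape((5, ())).is_err());
    Ok(())
}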
@@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
pub enum Storage {
    Cpu(CpuStorage),
    Cuda(CudaStorage),
    Metal(MetalStorage),
}

impl Storage {
@@ -19,10 +18,6 @@ impl Storage {
                let storage = storage.try_clone(layout)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.try_clone(layout)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -30,7 +25,6 @@ impl Storage {
        match self {
            Self::Cpu(_) => Device::Cpu,
            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
            Self::Metal(storage) => Device::Metal(storage.device().clone()),
        }
    }

@@ -38,7 +32,6 @@ impl Storage {
        match self {
            Self::Cpu(storage) => storage.dtype(),
            Self::Cuda(storage) => storage.dtype(),
            Self::Metal(storage) => storage.dtype(),
        }
    }

@@ -72,10 +65,6 @@ impl Storage {
                let storage = storage.affine(layout, mul, add)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.affine(layout, mul, add)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -89,10 +78,6 @@ impl Storage {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.powf(layout, alpha)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -106,10 +91,6 @@ impl Storage {
                let storage = storage.elu(layout, alpha)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.elu(layout, alpha)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -131,10 +112,6 @@ impl Storage {
                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
@@ -158,10 +135,6 @@ impl Storage {
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -175,10 +148,6 @@ impl Storage {
                let storage = storage.to_dtype(layout, dtype)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.to_dtype(layout, dtype)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -192,10 +161,6 @@ impl Storage {
                let (storage, shape) = c.cuda_fwd(storage, l)?;
                Ok((Self::Cuda(storage), shape))
            }
            Self::Metal(storage) => {
                let (storage, shape) = c.metal_fwd(storage, l)?;
                Ok((Self::Metal(storage), shape))
            }
        }
    }

@@ -216,10 +181,6 @@ impl Storage {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                Ok((Self::Cuda(s), shape))
            }
            (Self::Metal(s1), Self::Metal(s2)) => {
                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
                Ok((Self::Metal(s), shape))
            }
            _ => unreachable!(),
        }
    }
@@ -244,10 +205,6 @@ impl Storage {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Cuda(s), shape))
            }
            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Metal(s), shape))
            }
            _ => unreachable!(),
        }
    }
@@ -262,10 +219,6 @@ impl Storage {
                let storage = storage.unary_impl::<B>(layout)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.unary_impl::<B>(layout)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -286,10 +239,6 @@ impl Storage {
                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
@@ -321,10 +270,6 @@ impl Storage {
                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv1d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -352,10 +297,6 @@ impl Storage {
                let s = inp.conv2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv2d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -383,10 +324,6 @@ impl Storage {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -411,10 +348,6 @@ impl Storage {
                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -433,10 +366,6 @@ impl Storage {
                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -450,10 +379,6 @@ impl Storage {
                let storage = storage.upsample_nearest1d(layout, sz)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.upsample_nearest1d(layout, sz)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -467,10 +392,6 @@ impl Storage {
                let storage = storage.upsample_nearest2d(layout, h, w)?;
                Ok(Self::Cuda(storage))
            }
            Self::Metal(storage) => {
                let storage = storage.upsample_nearest2d(layout, h, w)?;
                Ok(Self::Metal(storage))
            }
        }
    }

@@ -494,10 +415,6 @@ impl Storage {
                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                Ok(Self::Metal(storage))
            }
            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -524,10 +441,6 @@ impl Storage {
                let storage = s.gather(l, indexes, indexes_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(s), Self::Metal(indexes)) => {
                let storage = s.gather(l, indexes, indexes_l, d)?;
                Ok(Self::Metal(storage))
            }
            _ => unreachable!(),
        }
    }
@@ -552,10 +465,6 @@ impl Storage {
                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Metal(storage))
            }
            _ => unreachable!(),
        }
    }
@@ -580,10 +489,6 @@ impl Storage {
                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
                Ok(Self::Metal(storage))
            }
            _ => unreachable!(),
        }
    }
@@ -605,10 +510,6 @@ impl Storage {
                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -636,10 +537,6 @@ impl Storage {
                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                Ok(Self::Cuda(storage))
            }
            (Self::Metal(lhs), Self::Metal(rhs)) => {
                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                Ok(Self::Metal(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -659,9 +556,6 @@ impl Storage {
        match (self, dst) {
            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
            (Self::Metal(src), Self::Metal(dst)) => {
                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
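Every hunk in this file follows the same shape: match on the backend variant, run the backend op, and rewrap the result in the same variant. A stripped-down model of the pattern (hedged: toy stand-in types, not the real storages):

enum Storage {
    Cpu(Vec<f32>),
    Cuda(Vec<f32>), // stand-in; the real variant wraps a device buffer
}

impl Storage {
    fn affine(&self, mul: f32, add: f32) -> Storage {
        // Dispatch on the variant and rewrap in the same variant, mirroring
        // the match arms in the hunks above.
        match self {
            Storage::Cpu(v) => Storage::Cpu(v.iter().map(|x| x * mul + add).collect()),
            Storage::Cuda(v) => Storage::Cuda(v.iter().map(|x| x * mul + add).collect()),
        }
    }
}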
@@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};

/// Unique identifier for tensors.
@@ -177,9 +177,14 @@ impl Tensor {
        is_variable: bool,
    ) -> Result<Self> {
        let none = BackpropOp::none();
        let shape = shape.into();
        let storage = device.ones(&shape, dtype)?;
        Ok(from_storage(storage, shape, none, is_variable))
        if is_variable {
            let shape = shape.into();
            let storage = device.ones(&shape, dtype)?;
            Ok(from_storage(storage, shape, none, is_variable))
        } else {
            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
        }
    }

    /// Creates a new tensor filled with ones.
@@ -217,9 +222,14 @@ impl Tensor {
        is_variable: bool,
    ) -> Result<Self> {
        let none = BackpropOp::none();
        let shape = shape.into();
        let storage = device.zeros(&shape, dtype)?;
        Ok(from_storage(storage, shape, none, is_variable))
        if is_variable {
            let shape = shape.into();
            let storage = device.zeros(&shape, dtype)?;
            Ok(from_storage(storage, shape, none, is_variable))
        } else {
            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
        }
    }

    /// Creates a new tensor filled with zeros.
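The `is_variable` branch on one side of these hunks is a memory optimization: for plain tensors, ones/zeros allocates a single scalar and broadcasts it to the requested shape instead of materializing every element, while variables keep the full allocation so they can be written to. The observable behavior is unchanged (a hedged sketch using the public API):

use candle_core::{DType, Device, Result, Tensor};

fn zeros_demo() -> Result<()> {
    // Potentially backed by one scalar element until a later op forces a
    // contiguous copy, but it reads exactly like a full (2, 3) tensor.
    let z = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    assert_eq!(z.to_vec2::<f32>()?, [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
    Ok(())
}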
@@ -385,21 +395,11 @@ impl Tensor {
        step: D,
        device: &Device,
    ) -> Result<Self> {
        if D::is_zero(&step) {
            crate::bail!("step cannot be zero")
        }
        let mut data = vec![];
        let mut current = start;
        if step >= D::zero() {
            while current < end {
                data.push(current);
                current += step;
            }
        } else {
            while current > end {
                data.push(current);
                current += step;
            }
        while current < end {
            data.push(current);
            current += step;
        }
        let len = data.len();
        Self::from_vec_impl(data, len, device, false)
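One side of this hunk rejects a zero step and walks downward for negative steps; the other only counts up. Usage of both directions, taken from the arange tests later in this diff:

use candle_core::{Device, Result, Tensor};

fn arange_demo() -> Result<()> {
    let up = Tensor::arange_step(0u8, 5u8, 2, &Device::Cpu)?;
    assert_eq!(up.to_vec1::<u8>()?, [0, 2, 4]);
    // A negative step counts down from start toward end (the side of the
    // diff that special-cases step < 0).
    let down = Tensor::arange_step(5i64, 0i64, -1, &Device::Cpu)?;
    assert_eq!(down.to_vec1::<i64>()?, [5, 4, 3, 2, 1]);
    Ok(())
}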
@@ -459,7 +459,7 @@ impl Tensor {

    /// Returns true if the computation graph should track this op, that is if it is
    /// a variable or if it has some variable as dependencies.
    pub fn track_op(&self) -> bool {
    pub(crate) fn track_op(&self) -> bool {
        self.is_variable || self.op.is_some()
    }

@@ -489,21 +489,7 @@ impl Tensor {
    unary_op!(sqr, Sqr);
    unary_op!(sqrt, Sqrt);
    unary_op!(gelu, Gelu);
    unary_op!(gelu_erf, GeluErf);
    unary_op!(erf, Erf);
    unary_op!(relu, Relu);
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);

    /// Round elements of the input tensor to the nearest integer.
    ///
    /// If the number of decimals is negative, it specifies the number of positions to the left of
    /// the decimal point.
    pub fn round_to(&self, decimals: i32) -> Result<Self> {
        let mult = 10f64.powi(decimals);
        (self * mult)?.round()? * (1f64 / mult)
    }

    /// Retrieves the single scalar value held in the tensor. If the tensor contains multiple
    /// dimensions, an error is returned instead.
@@ -523,7 +509,6 @@ impl Tensor {
        match &*self.storage() {
            Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -551,73 +536,6 @@ impl Tensor {
        Ok(inp)
    }

    /// Creates grids of coordinates specified by the 1D inputs.
    ///
    /// # Arguments
    ///
    /// * `args` - A slice of 1D tensors.
    /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
    ///   first dimension corresponds to the cardinality of the second input and the second
    ///   dimension corresponds to the cardinality of the first input. If ij is selected, the
    ///   dimensions are in the same order as the cardinality of the inputs.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device, Shape};
    /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
    ///
    /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
    ///
    /// assert_eq!(grids_xy.len(), 2);
    /// assert_eq!(grids_xy[0].dims(), &[3, 3]);
    ///
    /// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
    /// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
    ///
    /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
    ///
    /// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
    /// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    ///
    /// # Errors
    ///
    /// * Will return `Err` if `args` contains fewer than 2 tensors.
    ///
    pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
        if args.len() <= 1 {
            Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
        }
        let args: Vec<_> = if xy_indexing {
            args.iter().rev().collect()
        } else {
            args.iter().collect()
        };

        let mut shape = Vec::with_capacity(args.len());
        for arg in args.iter() {
            shape.push(arg.as_ref().dims1()?)
        }

        let mut grids = Vec::with_capacity(args.len());
        for idx in 0..args.len() {
            let mut ones = vec![1usize; args.len()];
            ones[idx] = shape[idx];
            let arg = args[idx].as_ref().reshape(ones)?;
            let mut repeats = shape.clone();
            repeats[idx] = 1;
            let repeated_tensor = arg.repeat(repeats)?;
            grids.push(repeated_tensor);
        }
        if xy_indexing {
            grids.reverse();
        }
        Ok(grids)
    }

    /// This operation multiplies the input tensor by `mul` then adds `add` and returns the result.
    /// The input values `mul` and `add` are cast to the appropriate type so some rounding might
    /// be performed.
@@ -693,23 +611,15 @@ impl Tensor {
    pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
        let dims = self.dims();
        let dim = dim.to_index(self.shape(), "narrow")?;
        let err = |msg| {
            Err::<(), _>(
                Error::NarrowInvalidArgs {
                    shape: self.shape().clone(),
                    dim,
                    start,
                    len,
                    msg,
                }
                .bt(),
            )
        };
        if start > dims[dim] {
            err("start > dim_len")?
        }
        if start.saturating_add(len) > dims[dim] {
            err("start + len > dim_len")?
        if start + len > dims[dim] {
            Err(Error::NarrowInvalidArgs {
                shape: self.shape().clone(),
                dim,
                start,
                len,
                msg: "start + len > dim_len",
            }
            .bt())?
        }
        if start == 0 && dims[dim] == len {
            Ok(self.clone())
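For reference, `narrow` keeps `len` elements starting at `start` along `dim`, and both versions above reject out-of-range requests rather than clamping. A small usage sketch:

use candle_core::{Device, Result, Tensor};

fn narrow_demo() -> Result<()> {
    let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    // Keep columns 1..3 of the (2, 3) tensor.
    let cols = t.narrow(1, 1, 2)?;
    assert_eq!(cols.to_vec2::<f32>()?, [[1.0, 2.0], [4.0, 5.0]]);
    // start + len past the end is an error, not a silent clamp.
    assert!(t.narrow(1, 2, 2).is_err());
    Ok(())
}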
@@ -1197,16 +1107,14 @@
                op: "scatter-add (self, src)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        if indexes.dims() != source.dims() {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (indexes, src)",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        let storage = self.storage().scatter_add(
            self.layout(),
@@ -1222,75 +1130,6 @@
        Ok(from_storage(storage, self.shape(), op, false))
    }

    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "slice-scatter")?;
        if dim == 0 {
            self.slice_scatter0(src, start)
        } else {
            // TODO: Maybe we want to add a more efficient implementation at some point.
            self.transpose(0, dim)?
                .slice_scatter0(&src.transpose(0, dim)?, start)?
                .transpose(0, dim)
        }
    }

    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
        if self.dtype() != src.dtype() {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: self.dtype(),
                rhs: src.dtype(),
                op: "slice-scatter",
            }
            .bt())?
        }
        if self.device().location() != src.device.location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: self.device().location(),
                rhs: src.device().location(),
                op: "slice-scatter",
            }
            .bt())?
        }
        if self.rank() != src.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: self.rank(),
                got: src.rank(),
                shape: src.shape().clone(),
            }
            .bt())?
        }
        let shape_ok =
            self.dims()
                .iter()
                .zip(src.dims().iter())
                .enumerate()
                .all(|(dim_idx, (&d1, &d2))| {
                    if 0 == dim_idx {
                        d2 + start <= d1
                    } else {
                        d1 == d2
                    }
                });
        if !shape_ok {
            Err(Error::ShapeMismatchBinaryOp {
                op: "slice-scatter (self, src)",
                lhs: self.shape().clone(),
                rhs: src.shape().clone(),
            }
            .bt())?
        }
        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
        self.storage()
            .copy_strided_src(&mut storage, 0, self.layout())?;
        let offset = start * src.dims()[1..].iter().product::<usize>();
        src.storage()
            .copy_strided_src(&mut storage, offset, src.layout())?;
        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
        Ok(from_storage(storage, self.shape(), op, false))
    }

    /// Accumulate elements from `source` at indexes `indexes` and add them to `self`.
    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "index-add")?;
@@ -1313,8 +1152,7 @@
                op: "index-add (self, source)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        // The number of elements in indexes must match the dimension on which the add is
        // performed on the source tensor (and the index values from `indexes` are taken from
@@ -1325,8 +1163,7 @@
                op: "index-add (ids, source))",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            }
            .bt())?
            })?
        }
        let storage = self.storage().index_add(
            self.layout(),
@@ -1374,8 +1211,7 @@
                op: "gather",
                lhs: self.shape().clone(),
                rhs: indexes.shape().clone(),
            }
            .bt())?
            })?
        }
        let storage =
            self.storage()
@@ -1449,7 +1285,6 @@
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1480,7 +1315,6 @@
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1521,7 +1355,6 @@
        match &*self.storage() {
            Storage::Cpu(storage) => from_cpu_storage(storage),
            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
            Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
        }
    }

@@ -1685,24 +1518,6 @@
        }
    }

    /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
    ///
    /// ```rust
    /// use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let t = tensor.get_on_dim(1, 0)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
    /// let t = tensor.get_on_dim(1, 1)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
    /// let t = tensor.get_on_dim(0, 1)?;
    /// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
        let dim = dim.to_index(self.shape(), "get_on_dim")?;
        self.narrow(dim, index, 1)?.squeeze(dim)
    }

    /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
    /// input are swapped.
    ///
@@ -1731,9 +1546,6 @@
    pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
        let dim1 = dim1.to_index(self.shape(), "transpose")?;
        let dim2 = dim2.to_index(self.shape(), "transpose")?;
        if dim1 == dim2 {
            return Ok(self.clone());
        }
        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
        let tensor_ = Tensor_ {
            id: TensorId::new(),
@@ -1841,9 +1653,6 @@
                Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
            }
            (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
            _ => {
                bail!("not implemented yet")
            }
        };
        let op = BackpropOp::new1(self, Op::ToDevice);
        let tensor_ = Tensor_ {
@@ -2098,34 +1907,6 @@
        for arg in args {
            arg.as_ref().check_dim(dim, "cat")?;
        }
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg0.rank() != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: arg0.rank(),
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        if dim == 0 {
            Self::cat0(args)
        } else {
@@ -2243,56 +2024,11 @@
        }
    }

    /// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the
    /// input tensor values and `right` elements after.
    pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
        if left == 0 && right == 0 {
            Ok(self.clone())
        } else if self.elem_count() == 0 {
            crate::bail!("cannot use pad_with_same on an empty tensor")
        } else if left == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
            let mut v = vec![self];
            for _ in 0..right {
                v.push(&r)
            }
            Tensor::cat(&v, dim)
        } else if right == 0 {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let l = self.narrow(dim, 0, 1)?;
            let mut v = vec![];
            for _ in 0..left {
                v.push(&l)
            }
            v.push(self);
            Tensor::cat(&v, dim)
        } else {
            let dim = dim.to_index(self.shape(), "pad_with_same")?;
            let l = self.narrow(dim, 0, 1)?;
            let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
            let mut v = vec![];
            for _ in 0..left {
                v.push(&l)
            }
            v.push(self);
            for _ in 0..right {
                v.push(&r)
            }
            Tensor::cat(&v, dim)
        }
    }

    /// Run the `forward` method of `m` on `self`.
    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
        m.forward(self)
    }

    /// Run the `forward` method of `m` on `self`.
    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
        m.forward_t(self, train)
    }

    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }
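`apply` just forwards to a `Module`; a toy module makes the contract concrete (hedged: `Scale` is illustrative, any `Module` implementor works the same way):

use candle_core::{Module, Result, Tensor};

struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Arithmetic ops on &Tensor already return Result, so this is the whole body.
        xs * self.0
    }
}

fn scale_demo(t: &Tensor) -> Result<Tensor> {
    t.apply(&Scale(2.0))
}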
@@ -23,10 +23,6 @@ pub fn cuda_is_available() -> bool {
    cfg!(feature = "cuda")
}

pub fn metal_is_available() -> bool {
    cfg!(feature = "metal")
}

pub fn with_avx() -> bool {
    cfg!(target_feature = "avx")
}
@@ -479,71 +479,6 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
            ]
        ]
    );

    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
        [
            [
                [9.29, -7.03, 7.87, 0.0, 0.0],
                [-1.8, -7.82, 5.9, 0.0, 0.0],
                [-3.12, 4.49, 5.52, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [21.73, 3.39, 4.77, 0.0, 0.0],
                [8.25, 3.73, 27.61, 0.0, 0.0],
                [-20.55, -5.61, -2.77, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [-8.98, 9.91, -7.15, 0.0, 0.0],
                [4.93, -0.33, 4.56, 0.0, 0.0],
                [-6.7, -5.76, -8.05, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [23.54, 6.98, -10.0, 0.0, 0.0],
                [9.65, 6.18, 18.72, 0.0, 0.0],
                [3.29, -5.27, 0.79, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ]
        ]
    );
    assert_eq!(
        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
        [
            [
                [-3.47, 7.44, 0.66],
                [12.89, -3.4, -9.29],
                [-14.16, -0.83, 7.14]
            ],
            [
                [-3.23, 5.37, -3.02],
                [-2.12, -11.24, 1.94],
                [6.97, 7.2, 2.99]
            ],
            [
                [-4.04, -3.31, 4.87],
                [-6.68, -5.68, 1.73],
                [-5.54, 4.32, 0.52]
            ],
            [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
        ]
    );

    Ok(())
}
@@ -192,19 +192,6 @@ fn unary_grad(device: &Device) -> Result<()> {
        test_utils::to_vec1_round(grad_x, 2)?,
        [0.01, 0.42, 0.0, 0.98],
    );

    // testing compared to pytorch nn.GELU(approximate = 'tanh')
    let y = x.gelu()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [2.9964, 0.8412, 3.9999, 0.0839]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [1.0116, 1.0830, 1.0003, 0.6188],
    );
    Ok(())
}

@@ -231,22 +218,6 @@ fn binary_grad(device: &Device) -> Result<()> {
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);

    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
    let x = x_var.as_tensor();
    let y_var = Var::new(&[2f32, 7., 1.], device)?;
    let y = y_var.as_tensor();

    let ss = x
        .reshape((2, 3))?
        .slice_scatter0(&y.reshape((1, 3))?, 1)?
        .sqr()?;
    let grads = ss.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    let grad_y = grads.get(y).context("no grad for y")?;
    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
    Ok(())
}
@@ -1,9 +0,0 @@
import numpy as np
x = np.arange(10)

# Write a npy file.
np.save("test.npy", x)

# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)
@@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
@@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
@@ -491,9 +491,6 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
        GgmlDType::Q5_0 => 0.001353,
        GgmlDType::Q5_1 => 0.001363,
        GgmlDType::Q8_0 => 0.000092,

        // Not from the ggml repo.
        GgmlDType::Q8K => 0.00065,
        _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
    };
    Ok(err)
@@ -511,22 +508,17 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    T::VecDotType::from_float(&b, &mut b_quant)?;

    let result = T::vec_dot(length, &a_quant, &b_quant)?;
    let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
    let reference_result = vec_dot_reference(&a, &b);

    if (result - result_unopt).abs() / length as f32 > 1e-6 {
        candle_core::bail!(
            "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
        )
    }

    let error = (result - reference_result).abs() / length as f32;

    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

    if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
    if error > GGML_MAX_DOT_PRODUCT_ERROR {
        candle_core::bail!(
            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
            "Dot product error {} exceeds max error {}",
            error,
            GGML_MAX_DOT_PRODUCT_ERROR
        );
    }

@@ -579,7 +571,7 @@ fn quantized_matmul_q2k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
@@ -605,7 +597,7 @@ fn quantized_matmul_q3k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
@@ -631,7 +623,7 @@ fn quantized_matmul_q4k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
@@ -657,7 +649,7 @@ fn quantized_matmul_q5k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
@@ -684,7 +676,7 @@ fn quantized_matmul_q6k() -> Result<()> {
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs);
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
@@ -695,28 +687,3 @@
    ggml_matmul_error_test::<BlockQ6K>()?;
    Ok(())
}

#[test]
fn quantized_matmul_q8k() -> Result<()> {
    use k_quants::BlockQ8K;

    let cpu = &Device::Cpu;
    let (m, k, n) = (11, 512, 21);
    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;

    assert_eq!(mm.dims(), [m, n]);
    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);

    ggml_matmul_error_test::<BlockQ8K>()?;
    Ok(())
}
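For context, the `vec_dot_reference` used above is just the unquantized f32 dot product, and the test averages the absolute deviation per element before comparing against the per-dtype bound. A sketch (hedged: the actual helpers in the test file may differ in details):

fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn dot_error(result: f32, reference: f32, length: usize) -> f32 {
    // Matches the metric in ggml_matmul_error_test above.
    (result - reference).abs() / length as f32
}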
@@ -1,24 +0,0 @@
use candle_core::{DType, Result, Tensor};

#[test]
fn npy() -> Result<()> {
    let npy = Tensor::read_npy("tests/test.npy")?;
    assert_eq!(
        npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    );
    Ok(())
}

#[test]
fn npz() -> Result<()> {
    let npz = Tensor::read_npz("tests/test.npz")?;
    assert_eq!(npz.len(), 2);
    assert_eq!(npz[0].0, "x");
    assert_eq!(npz[1].0, "x_plus_one");
    assert_eq!(
        npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    );
    Ok(())
}
@@ -1,4 +1,4 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};

fn zeros(device: &Device) -> Result<()> {
    let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -8,50 +8,6 @@ fn zeros(device: &Device) -> Result<()> {
    Ok(())
}

fn ones(device: &Device) -> Result<()> {
    assert_eq!(
        Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
        [[1, 1, 1], [1, 1, 1]],
    );
    assert_eq!(
        Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
        [[1, 1, 1], [1, 1, 1]],
    );
    assert_eq!(
        Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
        [[1, 1, 1], [1, 1, 1]],
    );
    assert_eq!(
        Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
    );
    assert_eq!(
        Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
    );
    Ok(())
}

fn arange(device: &Device) -> Result<()> {
    assert_eq!(
        Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
        [0, 1, 2, 3, 4],
    );
    assert_eq!(
        Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
        [0, 2, 4],
    );
    assert_eq!(
        Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
        [0, 3],
    );
    assert_eq!(
        Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
        [5, 4, 3, 2, 1],
    );
    Ok(())
}

fn add_mul(device: &Device) -> Result<()> {
    let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
    let dim1 = tensor.dims1()?;
@@ -88,54 +44,6 @@ fn clamp(device: &Device) -> Result<()> {
    Ok(())
}

fn unary_op(device: &Device) -> Result<()> {
    let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
    let tensor = Tensor::new(data, device)?;
    assert_eq!(
        test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
        [
            [-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
            [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
        ]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
        [
            [-0.004, 0.8413, 3.9999, -0.046, 0.3457],
            [2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
        ]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.erf()?, 4)?,
        [
            [-1.0, 0.8427, 1.0, -0.1125, 0.5205],
            [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
        ]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.floor()?, 4)?,
        [[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.round()?, 4)?,
        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
    );
    let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
    assert_eq!(
        test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
        [2997.92, 314.16]
    );
    assert_eq!(
        test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
        [3000.0, 300.]
    );
    Ok(())
}

fn binary_op(device: &Device) -> Result<()> {
    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
    let tensor1 = Tensor::new(data, device)?;
@@ -693,30 +601,6 @@ fn index_select(device: &Device) -> Result<()> {
        hs.to_vec2::<f32>()?,
        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
    );
    // Prior to https://github.com/huggingface/candle/pull/1022
    // There would be a bug where the last values in the result tensor would be set to 0.
    let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
    let hs = t.index_select(&ids, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [6.0, 7.0, 8.0],
            [3.0, 4.0, 5.0],
            [0.0, 1.0, 2.0],
            [6.0, 7.0, 8.0],
            [3.0, 4.0, 5.0],
        ]
    );

    // Test when selecting dim > 0 with ids size different from elem count of
    // target dim in source/input.
    let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
    let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
    assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
    let hs = t.index_select(&ids, 1)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);

    Ok(())
}

@@ -763,48 +647,6 @@ fn index_add(device: &Device) -> Result<()> {
    Ok(())
}

fn slice_scatter(device: &Device) -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        t.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );
    let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
    assert_eq!(
        t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
        &[
            [100.0, 101.0, 102.0],
            [103.0, 104.0, 105.0],
            [6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0]
        ]
    );
    assert_eq!(
        t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [100.0, 101.0, 102.0],
            [103.0, 104.0, 105.0],
            [9.0, 10.0, 11.0]
        ]
    );
    assert_eq!(
        t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [3.0, 4.0, 5.0],
            [100.0, 101.0, 102.0],
            [103.0, 104.0, 105.0],
        ]
    );
    Ok(())
}

fn scatter_add(device: &Device) -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
@@ -1055,8 +897,6 @@ fn randn(device: &Device) -> Result<()> {
}

test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@@ -1068,7 +908,6 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
@@ -1079,7 +918,6 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);

@@ -1093,27 +931,3 @@ fn randn_hasneg() -> Result<()> {
    }
    Ok(())
}

#[test]
fn pad_with_same() -> Result<()> {
    let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
    let t0 = t.pad_with_same(0, 1, 2)?;
    assert_eq!(
        t0.to_vec2::<f32>()?,
        [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
    );
    let t1 = t.pad_with_same(1, 1, 2)?;
    assert_eq!(
        t1.to_vec2::<f32>()?,
        [[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
    );
    Ok(())
}

#[test]
fn i64_abs() -> Result<()> {
    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
    let t = t.abs()?;
    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
    Ok(())
}
Binary file not shown.
Binary file not shown.
@ -11,8 +11,8 @@ readme = "README.md"

[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
@ -11,22 +11,20 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.2.3" }
candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }

[dev-dependencies]
anyhow = { workspace = true }
@ -37,6 +35,7 @@ imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
@ -51,16 +50,11 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]

[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]

[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]
@ -5,11 +5,11 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

use anyhow::{Error as E, Result};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Parser, Debug)]
@ -19,6 +19,10 @@ struct Args {
    #[arg(long)]
    cpu: bool,

    /// Run offline (you must have the files already cached)
    #[arg(long)]
    offline: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
@ -34,10 +38,6 @@ struct Args {
    #[arg(long)]
    prompt: Option<String>,

    /// Use the pytorch weights rather than the safetensors ones
    #[arg(long)]
    use_pth: bool,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,
@ -60,27 +60,35 @@ impl Args {
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = {
        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
            let cache = Cache::default().repo(repo);
            (
                cache
                    .get("config.json")
                    .ok_or(anyhow!("Missing config file in cache"))?,
                cache
                    .get("tokenizer.json")
                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
                cache
                    .get("model.safetensors")
                    .ok_or(anyhow!("Missing weights file in cache"))?,
            )
        } else {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
            (
                api.get("config.json")?,
                api.get("tokenizer.json")?,
                api.get("model.safetensors")?,
            )
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        let vb = if self.use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
        } else {
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
        };
        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
        let weights = weights.deserialize()?;
        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
        let model = BertModel::load(vb, &config)?;
        Ok((model, tokenizer))
    }
@ -138,9 +138,18 @@ fn main() -> Result<()> {
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let weights = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
        .collect::<Result<Vec<_>>>()?;
    let weights = weights
        .iter()
        .map(|f| Ok(f.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
    let config = Config::starcoder_1b();
    let model = GPTBigCode::load(vb, config)?;
    println!("loaded the model in {:?}", start.elapsed());
@ -1,19 +0,0 @@
# candle-blip

The
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
model can generate captions for an input image.

## Running an example

```bash
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

```
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded image Tensor[dims 3, 384, 384; f32]
model built
several cyclists are riding down a road with cars behind them%
```

![](assets/bike.jpg)
@ -1,154 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;

use tokenizers::Tokenizer;

enum Model {
    M(blip::BlipForConditionalGeneration),
    Q(quantized_blip::BlipForConditionalGeneration),
}

impl Model {
    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::M(m) => m.text_decoder().forward(xs, img_xs),
            Self::Q(m) => m.text_decoder().forward(xs, img_xs),
        }
    }
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,
}

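// Token id of BERT's [SEP] token, used to detect the end of the generated caption.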
const SEP_TOKEN_ID: u32 = 102;

/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::io::Reader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean =
        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            if args.quantized {
                let api = api.model("lmz/candle-blip".to_string());
                api.get("blip-image-captioning-large-q4k.gguf")?
            } else {
                let api = api.repo(hf_hub::Repo::with_revision(
                    "Salesforce/blip-image-captioning-large".to_string(),
                    hf_hub::RepoType::Model,
                    "refs/pr/18".to_string(),
                ));
                api.get("model.safetensors")?
            }
        }
        Some(model) => model.into(),
    };
    let tokenizer = match args.tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("Salesforce/blip-image-captioning-large".to_string());
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
    let mut tokenizer = TokenOutputStream::new(tokenizer);
    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let config = blip::Config::image_captioning_large();

    let (image_embeds, device, mut model) = if args.quantized {
        let device = Device::Cpu;
        let image = load_image(args.image)?.to_device(&device)?;
        println!("loaded image {image:?}");

        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
        let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        (image_embeds, device, Model::Q(model))
    } else {
        let device = candle_examples::device(args.cpu)?;
        let image = load_image(args.image)?.to_device(&device)?;
        println!("loaded image {image:?}");

        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
        let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
        (image_embeds, device, Model::M(model))
    };

    let mut token_ids = vec![30522u32];
    for index in 0..1000 {
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == SEP_TOKEN_ID {
            break;
        }
        token_ids.push(token);
        if let Some(t) = tokenizer.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
@ -1,59 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-convmixer".into());
            api.get("convmixer_1024_20_ks9_p14.safetensors")?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = convmixer::c1024_20(1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@ -42,7 +42,9 @@ pub fn main() -> anyhow::Result<()> {
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let model = dinov2::vit_small(vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
@ -68,7 +68,9 @@ pub fn main() -> anyhow::Result<()> {
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let cfg = match args.which {
        Which::B0 => MBConvConfig::b0(),
        Which::B1 => MBConvConfig::b1(),
@ -177,12 +177,21 @@ fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let weights = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
        .collect::<Result<Vec<_>>>()?;
    let weights = weights
        .iter()
        .map(|f| Ok(f.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let dtype = if args.use_f32 {
        DType::F32
    } else {
        DType::BF16
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
    let config = Config::falcon7b();
    config.validate()?;
    let model = Falcon::load(vb, config)?;
@ -1,45 +0,0 @@
# candle-jina-bert

Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.


## Sentence embeddings

Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"

> [[[ 0.1595, -0.9885,  0.6494, ...,  0.3003, -0.6901, -1.2355],
>   [ 0.0374, -0.1798,  1.3359, ...,  0.6731,  0.2133, -1.6807],
>   [ 0.1700, -0.8534,  0.8924, ..., -0.1785, -0.0727, -1.5087],
>   ...
>   [-0.3113, -1.3665,  0.2027, ..., -0.2519,  0.1711, -1.5811],
>   [ 0.0907, -1.0492,  0.5382, ...,  0.0242, -0.7077, -1.0830],
>   [ 0.0369, -0.6343,  0.6105, ...,  0.0671,  0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```

## Similarities

In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.

```bash
cargo run --example jina-bert --release

> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```
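The pooling-and-scoring procedure described above reduces to a handful of tensor ops. A minimal sketch, assuming a `(n_sentences, n_tokens, hidden_size)` embeddings tensor as produced by the model (the `cosine_pair` helper is illustrative, not part of the example):

```rust
use candle::Tensor;

// Mean-pool the per-token embeddings into one vector per sentence, then score
// the sentence pair (i, j) with a cosine similarity.
fn cosine_pair(embeddings: &Tensor, i: usize, j: usize) -> candle::Result<f32> {
    let (_n_sentences, n_tokens, _hidden_size) = embeddings.dims3()?;
    let pooled = (embeddings.sum(1)? / (n_tokens as f64))?;
    let e_i = pooled.get(i)?;
    let e_j = pooled.get(j)?;
    let dot = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
    let norm_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
    let norm_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
    Ok(dot / (norm_i2 * norm_j2).sqrt())
}
```

The full example below computes exactly this, batched over every sentence pair.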
@ -1,180 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_transformers::models::jina_bert::{BertModel, Config};

use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    model: Option<String>,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        use hf_hub::{api::sync::Api, Repo, RepoType};
        let model = match &self.model {
            Some(model_file) => std::path::PathBuf::from(model_file),
            None => Api::new()?
                .repo(Repo::new(
                    "jinaai/jina-embeddings-v2-base-en".to_string(),
                    RepoType::Model,
                ))
                .get("model.safetensors")?,
        };
        let tokenizer = match &self.tokenizer {
            Some(file) => std::path::PathBuf::from(file),
            None => Api::new()?
                .repo(Repo::new(
                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
                    RepoType::Model,
                ))
                .get("tokenizer.json")?,
        };
        let device = candle_examples::device(self.cpu)?;
        let config = Config::v2_base();
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        println!("Loaded and encoded {:?}", start.elapsed());
        for idx in 0..args.n {
            let start = std::time::Instant::now();
            let ys = model.forward(&token_ids)?;
            if idx == 0 {
                println!("{ys}");
            }
            println!("Took {:?}", start.elapsed());
        }
    } else {
        let sentences = [
            "The cat sits outside",
            "A man is playing guitar",
            "I love pasta",
            "The new movie is awesome",
            "The cat plays in the garden",
            "A woman watches TV",
            "The new movie is so great",
            "Do you like pizza?",
        ];
        let n_sentences = sentences.len();
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = tokenizers::PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle::Result<Vec<_>>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        println!("running inference on batch {:?}", token_ids.shape());
        let embeddings = model.forward(&token_ids)?;
        println!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        println!("pooled embeddings {:?}", embeddings.shape());

        let mut similarities = vec![];
        for i in 0..n_sentences {
            let e_i = embeddings.get(i)?;
            for j in (i + 1)..n_sentences {
                let e_j = embeddings.get(j)?;
                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                similarities.push((cosine_similarity, i, j))
            }
        }
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in similarities[..5].iter() {
            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
        }
    }
    Ok(())
}

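/// Divide each row by its L2 norm; once normalized this way, the cosine similarity
/// between two embeddings is just their dot product.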
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
@ -172,9 +172,17 @@ fn main() -> Result<()> {
            }

            println!("building the model");
            let handles = filenames
                .iter()
                .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
                .collect::<Result<Vec<_>>>()?;
            let tensors: Vec<_> = handles
                .iter()
                .map(|h| Ok(h.deserialize()?))
                .collect::<Result<Vec<_>>>()?;
            let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
            (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
        }
    };
@ -6,10 +6,9 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};

use anyhow::{Error as E, Result};
@ -20,7 +19,6 @@ use std::io::Write;
use tokenizers::Tokenizer;

use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;

#[derive(Parser, Debug, Clone)]
@ -154,20 +152,6 @@ fn main() -> anyhow::Result<()> {
    Ok(())
}

enum Model {
    Llama(Llama),
    QLlama(QLlama),
}

impl Model {
    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
        match self {
            Self::Llama(l) => Ok(l.forward(xs, pos)?),
            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
        }
    }
}

fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    use std::io::BufRead;

@ -257,66 +241,24 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

    let device = candle_examples::device(common_args.cpu)?;

    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
    let is_safetensors = config_path
        .extension()
        .map_or(false, |v| v == "safetensors");
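    // A gguf checkpoint does not carry the model config, so it is inferred from the
    // embedding dimension stored in the weights (see the `match dim` below).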
    let (model, config) = if is_gguf {
        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
        let (_vocab_size, dim) = vb
            .get_no_shape("model.embed_tokens.weight")?
            .shape()
            .dims2()?;
        let config = match dim {
            64 => Config::tiny_260k(),
            288 => Config::tiny_15m(),
            512 => Config::tiny_42m(),
            768 => Config::tiny_110m(),
            _ => anyhow::bail!("no config for dim {dim}"),
        };
        let freq_cis_real = vb
            .get(
                (config.seq_len, config.head_size() / 2),
                "rot.freq_cis_real",
            )?
            .dequantize(&candle::Device::Cpu)?;
        let freq_cis_imag = vb
            .get(
                (config.seq_len, config.head_size() / 2),
                "rot.freq_cis_imag",
            )?
            .dequantize(&candle::Device::Cpu)?;

        let fake_vb = candle_nn::VarBuilder::from_tensors(
            [
                ("freq_cis_real".to_string(), freq_cis_real),
                ("freq_cis_imag".to_string(), freq_cis_imag),
            ]
            .into_iter()
            .collect(),
            candle::DType::F32,
            &candle::Device::Cpu,
        );
        let cache = model::Cache::new(true, &config, fake_vb)?;
        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
        (model, config)
    } else if is_safetensors {
        let config = Config::tiny_15m();
    let (vb, config) = if is_safetensors {
        let config = Config::tiny();
        let tensors = candle::safetensors::load(config_path, &device)?;
        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
        (vb, config)
    } else {
        let mut file = std::fs::File::open(config_path)?;
        let config = Config::from_reader(&mut file)?;
        println!("{config:?}");
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
        (vb, config)
    };
    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;

    println!("starting the inference loop");
    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@ -331,7 +273,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

    let start_gen = std::time::Instant::now();
    for index in 0.. {
        if tokens.len() >= config.seq_len {
        if tokens.len() >= model.config.seq_len {
            break;
        }
        let context_size = if index > 0 { 1 } else { tokens.len() };
@ -17,20 +17,7 @@ pub struct Config {
}

impl Config {
    pub fn tiny_260k() -> Self {
        Self {
            dim: 64,
            hidden_dim: 768,
            n_layers: 5,
            n_heads: 8,
            n_kv_heads: 4,
            vocab_size: 32000,
            seq_len: 512,
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_15m() -> Self {
    pub fn tiny() -> Self {
        Self {
            dim: 288,
            hidden_dim: 768,
@ -42,32 +29,6 @@ impl Config {
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_42m() -> Self {
        Self {
            dim: 512,
            hidden_dim: 768,
            n_layers: 8,
            n_heads: 8,
            n_kv_heads: 8,
            vocab_size: 32000,
            seq_len: 1024,
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_110m() -> Self {
        Self {
            dim: 768,
            hidden_dim: 768,
            n_layers: 12,
            n_heads: 12,
            n_kv_heads: 12,
            vocab_size: 32000,
            seq_len: 1024,
            norm_eps: 1e-5,
        }
    }
}

#[derive(Clone)]
@ -75,9 +36,9 @@ pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    #[allow(clippy::type_complexity)]
    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    pub cos: Tensor,
    pub sin: Tensor,
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    device: Device,
}

@ -114,7 +75,7 @@ impl Cache {
        })
    }

    pub fn mask(&self, t: usize) -> Result<Tensor> {
    fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    );
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let config = Config::tiny_15m();
    let config = Config::tiny();
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

@ -1,8 +1,9 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle_nn::VarBuilder;

use super::llama2_c::Config;
use crate::model::Config;

pub struct TransformerWeights {
    // token embedding table
@ -89,10 +89,6 @@ struct Args {
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
@ -205,9 +201,16 @@ fn main() -> Result<()> {
    let cache = model::Cache::new(dtype, &config, &device)?;

    println!("building the model");
    let vb = unsafe {
        candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
    };
    let handles = filenames
        .iter()
        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
        .collect::<Result<Vec<_>>>()?;
    let tensors: Vec<_> = handles
        .iter()
        .map(|h| Ok(h.deserialize()?))
        .collect::<Result<Vec<_>>>()?;

    let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
    let llama = Llama::load(vb, &cache, &config, comm)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

@ -219,7 +222,7 @@ fn main() -> Result<()> {
        .to_vec();

    println!("starting the inference loop");
    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
    let mut new_tokens = vec![];
    let start_gen = std::time::Instant::now();
    let mut index_pos = 0;
@ -1,38 +0,0 @@
# candle-marian-mt

`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.

## Running an example

```bash
cargo run --example marian-mt --release -- \
  --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```

```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```

## Generating the tokenizer.json files

You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be installed, and uses the `convert_slow_tokenizer.py` script from
this directory.

```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```
File diff suppressed because it is too large
Load Diff
@ -1,152 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Big,
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    tokenizer_dec: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "big")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated
    #[arg(long)]
    text: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let config = match args.which {
        Which::Base => marian::Config::opus_mt_fr_en(),
        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
    };
    let tokenizer = {
        let tokenizer = match args.tokenizer {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-fr.json",
                    Which::Big => "tokenizer-marian-fr.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let tokenizer_dec = {
        let tokenizer = match args.tokenizer_dec {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-en.json",
                    Which::Big => "tokenizer-marian-en.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;
    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Big => Api::new()?
                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    .get("model.safetensors")?,
            },
        };
        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
    };
    let mut model = marian::MTModel::new(&config, vb)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let encoder_xs = {
        let mut tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.push(config.eos_token_id);
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
@ -1,90 +0,0 @@
# candle-mistral: 7b LLM with Apache 2.0 licensed weights

Mistral-7B-v0.1 is a pretrained generative LLM with 7 billion parameters. It outperforms all the publicly available 13b models
as of 2023-09-28. Weights (and the original Python model code) are released under the permissive Apache 2.0 license.

- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
  HuggingFace Hub.
This example supports the initial model as well as a quantized variant.

## Running the example

```bash
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150

Generated text:
Write helloworld code in Rust
=============================

This is a simple example of how to write "Hello, world!" program in Rust.

## Compile and run

``bash
$ cargo build --release
   Compiling hello-world v0.1.0 (/home/user/rust/hello-world)
    Finished release [optimized] target(s) in 0.26s
$ ./target/release/hello-world
Hello, world!
``

## Source code

``rust
fn main() {
    println!("Hello, world!");
}
``

## License

This example is released under the terms
```

## Running the quantized version of the model

```bash
$ cargo run --example mistral --features accelerate --release -- \
$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 562.292µs
loaded the model in 1.100323667s
Here is a sample quick sort implementation in rust

``rust
fn quick_sort(arr: &mut [i32]) {
    if arr.len() <= 1 {
        return;
    }

    let pivot = arr[0];
    let mut left = vec![];
    let mut right = vec![];

    for i in 1..arr.len() {
        if arr[i] < pivot {
            left.push(arr[i]);
        } else {
            right.push(arr[i]);
        }
    }

    quick_sort(&mut left);
    quick_sort(&mut right);

    let mut i = 0;
    for _ in &left {
        arr[i] = left.pop().unwrap();
        i += 1;
    }

    for _ in &right {
        arr[i] = right.pop().unwrap();
        i += 1;
    }
}
``
226 tokens generated (10.91 token/s)
```
@ -1,271 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_transformers::models::quantized_mistral::Model as QMistral;

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    Mistral(Mistral),
    Quantized(QMistral),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::Mistral(m) => m.forward(&input, start_pos)?,
                Model::Quantized(m) => m.forward(&input, start_pos)?,
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
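            // With repeat_penalty == 1.0 this is a no-op; otherwise the logits of the
            // last `repeat_last_n` tokens are down-weighted before sampling.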
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "lmz/candle-mistral")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                vec![
                    repo.get("pytorch_model-00001-of-00002.safetensors")?,
                    repo.get("pytorch_model-00002-of-00002.safetensors")?,
                ]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
        let model = QMistral::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Mistral::new(&config, vb)?;
        (Model::Mistral(model), device)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;

use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@ -95,7 +95,7 @@ impl ConvNet {
            .flatten_from(1)?
            .apply(&self.fc1)?
            .relu()?;
        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
        self.dropout.forward(&xs, train)?.apply(&self.fc2)
    }
}

@ -1,6 +1,6 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};

// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@ -199,34 +199,25 @@ impl EncodecResidualVectorQuantizer {
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
    layers: Vec<candle_nn::LSTM>,
    layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
}

impl EncodecLSTM {
    fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let vb = &vb.pp("lstm");
        let mut layers = vec![];
        for layer_idx in 0..cfg.num_lstm_layers {
            let config = candle_nn::LSTMConfig {
                layer_idx,
                ..Default::default()
            };
            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
            layers.push(lstm)
        for i in 0..cfg.num_lstm_layers {
            let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
            let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
            let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
            let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
            layers.push((w_hh, w_ih, b_hh, b_ih))
        }
        Ok(Self { layers })
    }
}

impl Module for EncodecLSTM {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            let states = layer.seq(&xs)?;
            xs = layer.states_to_tensor(&states)?;
        }
        Ok(xs)
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}
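The four tensors loaded per layer above follow the PyTorch LSTM parameter layout: `w_ih`/`w_hh` are `(4 * dim, dim)` with the input, forget, cell and output gates stacked along the first dimension. As a hedged sketch of the per-timestep recurrence the `todo!()` forward would need to compute (textbook LSTM gating under that layout assumption; `lstm_step` is illustrative and not part of the diff):

```rust
use candle::Tensor;
use candle_nn::ops::sigmoid;

// One step of a single-layer LSTM: x, h and c are (dim,) vectors, the weights
// use the PyTorch layout with gates stacked as i|f|g|o along dim 0.
fn lstm_step(
    x: &Tensor,
    h: &Tensor,
    c: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> candle::Result<(Tensor, Tensor)> {
    // gates = W_ih x + b_ih + W_hh h + b_hh, shape (4 * dim,)
    let gates = (w_ih.matmul(&x.unsqueeze(1)?)?.squeeze(1)?
        + w_hh.matmul(&h.unsqueeze(1)?)?.squeeze(1)?)?
        .broadcast_add(b_ih)?
        .broadcast_add(b_hh)?;
    let chunks = gates.chunk(4, 0)?;
    let i = sigmoid(&chunks[0])?; // input gate
    let f = sigmoid(&chunks[1])?; // forget gate
    let g = chunks[2].tanh()?; // candidate cell state
    let o = sigmoid(&chunks[3])?; // output gate
    let c_next = ((f * c)? + (i * g)?)?;
    let h_next = (o * c_next.tanh()?)?;
    Ok((h_next, c_next))
}
```

Iterating this step over the time axis, feeding each layer's hidden states into the next, matches what `candle_nn::LSTM::seq` computed in the replaced code.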
|
||||
|
||||
@ -256,9 +247,7 @@ impl EncodecConvTranspose1d {
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
@ -310,9 +299,7 @@ impl EncodecConv1d {
|
||||
conv,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
@ -353,9 +340,7 @@ impl EncodecResnetBlock {
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
@ -454,17 +439,8 @@ impl EncodecEncoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?;
|
||||
for (resnets, conv) in self.sampling_layers.iter() {
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?;
|
||||
}
|
||||
xs = xs.elu(1.0)?.apply(conv)?;
|
||||
}
|
||||
xs.apply(&self.final_lstm)?
|
||||
.elu(1.0)?
|
||||
.apply(&self.final_conv)
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
@ -531,15 +507,8 @@ impl EncodecDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||
for (conv, resnets) in self.sampling_layers.iter() {
|
||||
xs = xs.elu(1.)?.apply(conv)?;
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?
|
||||
}
|
||||
}
|
||||
xs.elu(1.)?.apply(&self.final_conv)
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
|
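In the hunk above, the `candle_nn::LSTM` layers are replaced by raw `(w_hh, w_ih, b_hh, b_ih)` weight tuples and the forward pass is left as `todo!()`. For reference, here is a hedged sketch of the single time step those weights would drive, assuming the PyTorch-style gate layout `[i, f, g, o]` used by the Encodec checkpoints (a hypothetical helper, not part of this diff):

```rust
use candle::{Result, Tensor, D};

// One LSTM step from raw weights: x is (batch, dim), h and c are the
// previous hidden and cell states, w_ih/w_hh are (4 * dim, dim) and
// b_ih/b_hh are (4 * dim). Returns the new (h, c).
fn lstm_step(
    x: &Tensor,
    h: &Tensor,
    c: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<(Tensor, Tensor)> {
    // Project input and hidden state onto the four stacked gates.
    let gates = (x.matmul(&w_ih.t()?)? + h.matmul(&w_hh.t()?)?)?
        .broadcast_add(b_ih)?
        .broadcast_add(b_hh)?;
    let chunks = gates.chunk(4, D::Minus1)?;
    let i = candle_nn::ops::sigmoid(&chunks[0])?; // input gate
    let f = candle_nn::ops::sigmoid(&chunks[1])?; // forget gate
    let g = chunks[2].tanh()?; // cell candidate
    let o = candle_nn::ops::sigmoid(&chunks[3])?; // output gate
    let c = ((&f * c)? + (&i * &g)?)?;
    let h = (&o * &c.tanh()?)?;
    Ok((h, c))
}
```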
@@ -73,7 +73,9 @@ fn main() -> Result<()> {
         ))
         .get("model.safetensors")?,
     };
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
+    let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
+    let model = model.deserialize()?;
+    let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
     let config = GenConfig::small();
     let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
@@ -40,7 +40,7 @@ impl Default for Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: Activation::Gelu,
+            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -66,7 +66,7 @@ impl Config {
             num_attention_heads: 16,
             layerdrop: 0.0,
             use_cache: true,
-            activation_function: Activation::Gelu,
+            activation_function: Activation::Gelu, // TODO: Handle old style gelu.
             hidden_size: 1024,
             dropout: 0.1,
             attention_dropout: 0.0,
@@ -1,56 +0,0 @@
# candle-phi: 1.3b LLM with state-of-the-art performance for <10b models.

[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
only 1.3 billion parameters but with state-of-the-art performance compared to
models with up to 10 billion parameters.

The candle implementation provides both the standard version as well as a
quantized variant.

## Running some examples

```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "

def print_prime(n):
    print("Printing prime numbers")
    for i in range(2, n+1):
        if is_prime(i):
            print(i)

def is_prime(n):
    if n <= 1:
        return False
    for i in range(2, int(math.sqrt(n))+1):
        if n % i == 0:
            return False
    return True

$ cargo run --example phi --release -- \
  --prompt "Explain how to find the median of an array and write the corresponding python function.\nAnswer:" \
  --quantized --sample-len 200

Explain how to find the median of an array and write the corresponding python function.
Answer: The median is the middle value in an array. If the array has an even number of elements, the median is the average of the two middle values.

def median(arr):
    arr.sort()
    n = len(arr)
    if n % 2 == 0:
        return (arr[n//2 - 1] + arr[n//2]) / 2
    else:
        return arr[n//2]
```

This also supports the [Puffin Phi v2
model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction.
```bash
$ cargo run --example phi --release -- \
  --prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" \
  --sample-len 200 --model puffin-phi-v2 --quantized
USER: What would you do on a sunny day in Paris?
ASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous
painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées
and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café
for a cup of coffee and to soak up the sun!"
```
@@ -1,313 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    MixFormer(MixFormer),
    Quantized(QMixFormer),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the phi model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        print!("{prompt}");
        std::io::stdout().flush()?;
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = match &mut self.model {
                Model::MixFormer(m) => m.forward(&input)?,
                Model::Quantized(m) => m.forward(&input)?,
            };
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
    #[value(name = "1")]
    V1,
    #[value(name = "1.5")]
    V1_5,
    PuffinPhiV2,
    PhiHermes,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the tokens for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "1.5")]
    model: WhichModel,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => {
            if args.quantized {
                "lmz/candle-quantized-phi".to_string()
            } else {
                match args.model {
                    WhichModel::V1 => "microsoft/phi-1".to_string(),
                    WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                        "lmz/candle-quantized-phi".to_string()
                    }
                }
            }
        }
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => {
            if args.quantized {
                "main".to_string()
            } else {
                match args.model {
                    WhichModel::V1 => "refs/pr/2".to_string(),
                    WhichModel::V1_5 => "refs/pr/18".to_string(),
                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
                }
            }
        }
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => match args.model {
            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                repo.get("tokenizer-puffin-phi-v2.json")?
            }
        },
    };
    let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
        None => {
            if args.quantized {
                match args.model {
                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
                }
            } else {
                match args.model {
                    WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
                }
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = match args.model {
        WhichModel::V1 => Config::v1(),
        WhichModel::V1_5 => Config::v1_5(),
        WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
    };
    let (model, device) = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
        let model = QMixFormer::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
        let model = MixFormer::new(&config, vb)?;
        (Model::MixFormer(model), device)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -1,42 +0,0 @@
# candle-quantized-t5

This example uses a quantized version of the t5 model.

```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```

The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).

```bash
$ cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/candle-coedit-quantized" \
  --prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
  --temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
```

By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:

```bash
cargo run --example quantized-t5 --release -- \
  --model-id "jbochi/candle-coedit-quantized" \
  --weight-file "model-xl.gguf" \
  --config-file "config-xl.json" \
  --prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
  --temperature 0
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
@@ -1,228 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;

use candle_transformers::models::quantized_t5 as t5;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Small,
    FlanT5Small,
    FlanT5Base,
    FlanT5Large,
    FlanT5Xl,
    FlanT5Xxl,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model repository to use on the HuggingFace hub.
    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Disable the key-value cache used during decoding.
    #[arg(long, default_value = "false")]
    disable_cache: bool,

    /// The prompt to be used for the generation.
    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "t5-small")]
    which: Which,
}

struct T5ModelBuilder {
    device: Device,
    config: t5::Config,
    weights_filename: PathBuf,
}

impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = Device::Cpu;
        let default_model = "lmz/candle-quantized-t5".to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
            (None, Some(revision)) => (default_model, revision),
            (None, None) => (default_model, "main".to_string()),
        };

        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = match &args.config_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("config.json")?,
                Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
                Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
                Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
                Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
                Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
            },
        };
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = match &args.weight_file {
            Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
            None => match args.which {
                Which::T5Small => api.get("model.gguf")?,
                Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
                Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
                Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
                Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
                Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
            },
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
        config.use_cache = !args.disable_cache;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        Ok((
            Self {
                device,
                config,
                weights_filename,
            },
            tokenizer,
        ))
    }

    pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
    }

    fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
        let local_filename = std::path::PathBuf::from(filename);
        if local_filename.exists() {
            Ok(local_filename)
        } else {
            Ok(api.get(filename)?)
        }
    }
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
    let device = &builder.device;
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    let mut model = builder.build_model()?;
    let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
    let temperature = if args.temperature <= 0. {
        None
    } else {
        Some(args.temperature)
    };
    let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
    let encoder_output = model.encode(&input_token_ids)?;
    let start = std::time::Instant::now();

    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
        let decoder_token_ids = if index == 0 || !builder.config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model
            .decode(&decoder_token_ids, &encoder_output)?
            .squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &output_token_ids[start_at..],
            )?
        };

        let next_token_id = logits_processor.sample(&logits)?;
        if next_token_id as usize == builder.config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
        if let Some(text) = tokenizer.id_to_token(next_token_id) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    let dt = start.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        output_token_ids.len(),
        output_token_ids.len() as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;

 use candle::quantized::{ggml_file, gguf_file};
-use candle::Tensor;
+use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;

 use candle_transformers::models::quantized_llama as model;
@@ -44,29 +44,6 @@ enum Which {
     L13bCode,
     #[value(name = "34b-code")]
     L34bCode,
-    #[value(name = "7b-mistral")]
-    Mistral7b,
-    #[value(name = "7b-mistral-instruct")]
-    Mistral7bInstruct,
-    #[value(name = "7b-zephyr")]
-    Zephyr7b,
 }
-
-impl Which {
-    fn is_mistral(&self) -> bool {
-        match self {
-            Self::L7b
-            | Self::L13b
-            | Self::L70b
-            | Self::L7bChat
-            | Self::L13bChat
-            | Self::L70bChat
-            | Self::L7bCode
-            | Self::L13bCode
-            | Self::L34bCode => false,
-            Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
-        }
-    }
-}

 #[derive(Parser, Debug)]
@@ -133,12 +110,7 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_mistral() {
-                    "mistralai/Mistral-7B-v0.1"
-                } else {
-                    "hf-internal-testing/llama-tokenizer"
-                };
-                let api = api.model(repo.to_string());
+                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
                 api.get("tokenizer.json")?
             }
         };
@@ -168,18 +140,6 @@ impl Args {
                     Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
                     Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
                     Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
-                    Which::Mistral7b => (
-                        "TheBloke/Mistral-7B-v0.1-GGUF",
-                        "mistral-7b-v0.1.Q4_K_S.gguf",
-                    ),
-                    Which::Mistral7bInstruct => (
-                        "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
-                        "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
-                    ),
-                    Which::Zephyr7b => (
-                        "TheBloke/zephyr-7B-alpha-GGUF",
-                        "zephyr-7b-alpha.Q4_K_M.gguf",
-                    ),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -232,13 +192,11 @@ fn main() -> anyhow::Result<()> {
     use tracing_subscriber::prelude::*;

     let args = Args::parse();
-    let device = candle_examples::device(false)?;
     let temperature = if args.temperature == 0. {
         None
     } else {
         Some(args.temperature)
     };
+    tracing_subscriber::fmt::init();
     let _guard = if args.tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -278,10 +236,10 @@ fn main() -> anyhow::Result<()> {
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file, &device)?
+            ModelWeights::from_gguf(model, &mut file)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file, &device)?;
+            let model = ggml_file::Content::read(&mut file)?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
@@ -303,13 +261,9 @@ fn main() -> anyhow::Result<()> {
                 | Which::L7bCode
                 | Which::L13bCode
                 | Which::L34bCode => 1,
-                Which::Mistral7b
-                | Which::Mistral7bInstruct
-                | Which::Zephyr7b
-                | Which::L70b
-                | Which::L70bChat => 8,
+                Which::L70b | Which::L70bChat => 8,
             };
-            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
+            ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
         }
     };
     println!("model built");
@@ -337,11 +291,7 @@ fn main() -> anyhow::Result<()> {
                     prompt.pop();
                 }
             }
-            if args.which.is_mistral() {
-                format!("[INST] {prompt} [/INST]")
-            } else {
-                prompt
-            }
+            prompt
        }
    };
    print!("{}", &prompt_str);
@@ -368,23 +318,18 @@ fn main() -> anyhow::Result<()> {

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
+        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
-        // TODO Remove this once implementation is finished.
-        let logits = logits.ones_like()?;
-        // logits_processor.sample(&logits)?
-        15043
+        logits_processor.sample(&logits)?
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    print_token(next_token, &tokenizer);

    let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();

    let start_post_prompt = std::time::Instant::now();
    for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
        let logits = model.forward(&input, prompt_tokens.len() + index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
@@ -397,15 +342,9 @@ fn main() -> anyhow::Result<()> {
                &all_tokens[start_at..],
            )?
        };
-        // TODO Remove this once implementation is finished.
-        // let logits = logits.ones_like()?;
-        // next_token = logits_processor.sample(&logits)?;
-        let next_token = 15043;
+        next_token = logits_processor.sample(&logits)?;
        all_tokens.push(next_token);
        print_token(next_token, &tokenizer);
        if next_token == eos_token {
            break;
        };
    }
    let dt = start_post_prompt.elapsed();
    println!(
@@ -1,16 +0,0 @@
# candle-reinforcement-learning

Reinforcement Learning examples for candle.

This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash
pip install "gymnasium[accept-rom-license]"
```

In order to run the example, use the following command. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
```
@@ -1,308 +0,0 @@
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe

# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking a random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)  # pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs

class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs

class ImageSaver(gym.Wrapper):
    def __init__(self, env, img_path, rank):
        gym.Wrapper.__init__(self, env)
        self._cnt = 0
        self._img_path = img_path
        self._rank = rank

    def step(self, action):
        step_result = self.env.step(action)
        obs, _, _, _ = step_result
        img = Image.fromarray(obs, 'RGB')
        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
        self._cnt += 1
        return step_result

class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert we sometimes stay in the lives == 0 condition for a few
            # frames, so it's important to keep lives > 0 so that we only reset
            # once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')

    def observation(self, obs):
        frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
        frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
            resample=Image.BILINEAR), dtype=np.uint8)
        return frame.reshape((self.res, self.res, 1))

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        assert shp[2] == 1  # can only stack 1-channel frames
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')

    def reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self.observation()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self.observation(), reward, done, info

    def observation(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)

def wrap_deepmind(env, episode_life=True, clip_rewards=True):
    """Configure environment for DeepMind-style Atari.

    Note: this does not include frame stacking!"""
    assert 'NoFrameskip' in env.spec.id  # required for DeepMind-style skip
    if episode_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env

# envs.py
def make_env(env_id, img_dir, seed, rank):
    def _thunk():
        env = gym.make(env_id)
        env.reset(seed=(seed + rank))
        if img_dir is not None:
            env = ImageSaver(env, img_dir, rank)
        env = wrap_deepmind(env)
        env = WrapPyTorch(env)
        return env

    return _thunk

class WrapPyTorch(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')

    def observation(self, observation):
        return observation.transpose(2, 0, 1)

# vecenv.py
class VecEnv(object):
    """
    Vectorized environment base class
    """
    def step(self, vac):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)

        where 'news' is a boolean vector indicating whether each element is new.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset all environments
        """
        raise NotImplementedError

    def close(self):
        pass

# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        else:
            raise NotImplementedError

class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)

class SubprocVecEnv(VecEnv):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    @property
    def num_envs(self):
        return len(self.remotes)

# Create the environment.
def make(env_name, img_dir, num_processes):
    envs = SubprocVecEnv([
        make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
    ])
    return envs
@@ -1,451 +0,0 @@
use std::collections::VecDeque;
use std::fmt::Display;

use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
    VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};

pub struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}
impl OuNoise {
    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
        Ok(Self {
            mu,
            theta,
            sigma,
            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
        })
    }

    pub fn sample(&mut self) -> Result<Tensor> {
        let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
        let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
        self.state = (&self.state + dx)?;
        Ok(self.state.clone())
    }
}

#[derive(Clone)]
struct Transition {
    state: Tensor,
    action: Tensor,
    reward: Tensor,
    next_state: Tensor,
    terminated: bool,
    truncated: bool,
}
impl Transition {
    fn new(
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) -> Self {
        Self {
            state: state.clone(),
            action: action.clone(),
            reward: reward.clone(),
            next_state: next_state.clone(),
            terminated,
            truncated,
        }
    }
}

pub struct ReplayBuffer {
    buffer: VecDeque<Transition>,
    capacity: usize,
    size: usize,
}
impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
            size: 0,
        }
    }

    pub fn push(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        if self.size == self.capacity {
            self.buffer.pop_front();
        } else {
            self.size += 1;
        }
        self.buffer.push_back(Transition::new(
            state, action, reward, next_state, terminated, truncated,
        ));
    }

    #[allow(clippy::type_complexity)]
    pub fn random_batch(
        &self,
        batch_size: usize,
    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
        if self.size < batch_size {
            Ok(None)
        } else {
            let transitions: Vec<&Transition> = thread_rng()
                .sample_iter(Uniform::from(0..self.size))
                .take(batch_size)
                .map(|i| self.buffer.get(i).unwrap())
                .collect();

            let states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let actions: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.action.unsqueeze(0))
                .collect::<Result<_>>()?;
            let rewards: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.reward.unsqueeze(0))
                .collect::<Result<_>>()?;
            let next_states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.next_state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();

            Ok(Some((
                Tensor::cat(&states, 0)?,
                Tensor::cat(&actions, 0)?,
                Tensor::cat(&rewards, 0)?,
                Tensor::cat(&next_states, 0)?,
                terminateds,
                truncateds,
            )))
        }
    }
}

fn track(
    varmap: &mut VarMap,
    vb: &VarBuilder,
    target_prefix: &str,
    network_prefix: &str,
    dims: &[(usize, usize)],
    tau: f64,
) -> Result<()> {
    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.weight"),
            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
        )?;

        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.bias"),
            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
        )?;
    }
    Ok(())
}

struct Actor<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Actor<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?)
                .add(func(|xs| xs.tanh()));
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("actor")?;
        let target_network = make_network("target-actor")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor) -> Result<Tensor> {
        self.network.forward(state)
    }

    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
        self.target_network.forward(state)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-actor",
            "actor",
            &self.dims,
            tau,
        )
    }
}

struct Critic<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Critic<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?);
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("critic")?;
        let target_network = make_network("target-critic")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.network.forward(&xs)
    }

    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.target_network.forward(&xs)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-critic",
            "critic",
            &self.dims,
            tau,
        )
    }
}

#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
    actor: Actor<'a>,
    actor_optim: AdamW,
    critic: Critic<'a>,
    critic_optim: AdamW,
    gamma: f64,
    tau: f64,
    replay_buffer: ReplayBuffer,
    ou_noise: OuNoise,

    size_state: usize,
    size_action: usize,
    pub train: bool,
}

impl DDPG<'_> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &Device,
        size_state: usize,
        size_action: usize,
        train: bool,
        actor_lr: f64,
        critic_lr: f64,
        gamma: f64,
        tau: f64,
        buffer_capacity: usize,
        ou_noise: OuNoise,
    ) -> Result<Self> {
        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
            varmap
                .data()
                .lock()
                .unwrap()
                .iter()
                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
                .collect::<Vec<Var>>()
        };

        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
        let actor_optim = AdamW::new(
            filter_by_prefix(&actor.varmap, "actor"),
            ParamsAdamW {
                lr: actor_lr,
                ..Default::default()
            },
        )?;

        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
        let critic_optim = AdamW::new(
            filter_by_prefix(&critic.varmap, "critic"),
            ParamsAdamW {
                lr: critic_lr,
                ..Default::default()
            },
        )?;

        Ok(Self {
            actor,
            actor_optim,
            critic,
            critic_optim,
            gamma,
            tau,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            ou_noise,
            size_state,
            size_action,
            train,
        })
    }

    pub fn remember(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        self.replay_buffer
            .push(state, action, reward, next_state, terminated, truncated)
    }

    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
        let actions = self
            .actor
            .forward(&state.detach()?.unsqueeze(0)?)?
            .squeeze(0)?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
        } else {
            actions
        };
        actions.squeeze(0)?.to_scalar::<f32>()
    }

    pub fn train(&mut self, batch_size: usize) -> Result<()> {
        let (states, actions, rewards, next_states, _, _) =
            match self.replay_buffer.random_batch(batch_size)? {
                Some(v) => v,
                _ => return Ok(()),
            };

        let q_target = self
            .critic
            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
        let q = self.critic.forward(&states, &actions)?;
        let diff = (q_target - q)?;

        let critic_loss = diff.sqr()?.mean_all()?;
        self.critic_optim.backward_step(&critic_loss)?;

        let actor_loss = self
            .critic
            .forward(&states, &self.actor.forward(&states)?)?
            .mean_all()?
            .neg()?;
        self.actor_optim.backward_step(&actor_loss)?;

        self.critic.track(self.tau)?;
        self.actor.track(self.tau)?;

        Ok(())
    }
}
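The `track` helper in the file above implements the Polyak soft update that DDPG applies to both the actor and critic targets after each training step. In the usual notation (a reference formula, not text from the repository):

```latex
\theta_{\mathrm{target}} \leftarrow \tau\,\theta + (1 - \tau)\,\theta_{\mathrm{target}}
```

With `tau = 1.0`, as in the constructors above, the target network is simply initialized to a copy of the online network.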
@@ -1,112 +0,0 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
    pub state: Tensor,
    pub action: A,
    pub reward: f64,
    pub terminated: bool,
    pub truncated: bool,
}

impl<A: Copy> Step<A> {
    /// Returns a copy of this step changing the observation tensor.
    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
        Step {
            state: state.clone(),
            action: self.action,
            reward: self.reward,
            terminated: self.terminated,
            truncated: self.truncated,
        }
    }
}

/// An OpenAI Gym session.
pub struct GymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl GymEnv {
    /// Creates a new session of the specified OpenAI Gym environment.
    pub fn new(name: &str) -> Result<GymEnv> {
        Python::with_gil(|py| {
            let gym = py.import("gymnasium")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name,))?;
            let action_space = env.getattr("action_space")?;
            let action_space = if let Ok(val) = action_space.getattr("n") {
                val.extract()?
            } else {
                let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
                action_space[0]
            };
            let observation_space = env.getattr("observation_space")?;
            let observation_space = observation_space.getattr("shape")?.extract()?;
            Ok(GymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    /// Resets the environment, returning the observation tensor.
    pub fn reset(&self, seed: u64) -> Result<Tensor> {
        let state: Vec<f32> = Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            kwargs.set_item("seed", seed)?;
            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
            state.as_ref(py).get_item(0)?.extract()
        })
        .map_err(w)?;
        Tensor::new(state, &Device::Cpu)
    }

    /// Applies an environment step using the specified action.
    pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
        &self,
        action: A,
    ) -> Result<Step<A>> {
        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
            let step = step.as_ref(py);
            let state: Vec<f32> = step.get_item(0)?.extract()?;
            let reward: f64 = step.get_item(1)?.extract()?;
            let terminated: bool = step.get_item(2)?.extract()?;
            let truncated: bool = step.get_item(3)?.extract()?;
            Ok((state, reward, terminated, truncated))
        })
        .map_err(w)?;
        let state = Tensor::new(state, &Device::Cpu)?;
        Ok(Step {
            state,
            action,
            reward,
            terminated,
            truncated,
        })
    }

    /// Returns the number of allowed actions for this environment.
    pub fn action_space(&self) -> usize {
        self.action_space
    }

    /// Returns the shape of the observation tensors.
    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}
@ -1,144 +0,0 @@
#![allow(unused)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod gym_env;
mod vec_gym_env;

mod ddpg;

use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;

// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;

// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;

const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let env = gym_env::GymEnv::new("Pendulum-v1")?;
    println!("action space: {}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let size_state = env.observation_space().iter().product::<usize>();
    let size_action = env.action_space();

    let mut agent = ddpg::DDPG::new(
        &Device::Cpu,
        size_state,
        size_action,
        true,
        ACTOR_LEARNING_RATE,
        CRITIC_LEARNING_RATE,
        GAMMA,
        TAU,
        REPLAY_BUFFER_CAPACITY,
        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
    )?;

    let mut rng = rand::thread_rng();

    for episode in 0..MAX_EPISODES {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;

        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            agent.remember(
                &state,
                &Tensor::new(vec![action], &Device::Cpu)?,
                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
                &step.state,
                step.terminated,
                step.truncated,
            );

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }

        println!("episode {episode} with total reward of {total_reward}");

        for _ in 0..TRAINING_ITERATIONS {
            agent.train(TRAINING_BATCH_SIZE)?;
        }
    }

    println!("Testing...");
    agent.train = false;
    for episode in 0..10 {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
            action = action.clamp(-2.0, 2.0);

            let step = env.step(vec![action])?;
            total_reward += step.reward;

            if step.terminated || step.truncated {
                break;
            }
            state = step.state;
        }
        println!("episode {episode} with total reward of {total_reward}");
    }
    Ok(())
}
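The `MU`/`THETA`/`SIGMA` constants above parameterize the Ornstein-Uhlenbeck exploration noise, a mean-reverting random walk following `x_{t+1} = x_t + THETA * (MU - x_t) + SIGMA * eps` with `eps` drawn from a standard normal. A standalone scalar sketch, assuming the `rand_distr` crate is available (the example's `ddpg::OuNoise` operates on action-sized tensors instead):

```rust
use rand_distr::{Distribution, Normal};

/// Scalar Ornstein-Uhlenbeck process: mean-reverting noise around `mu`.
struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: f64,
}

impl OuNoise {
    fn new(mu: f64, theta: f64, sigma: f64) -> Self {
        Self { mu, theta, sigma, state: mu }
    }

    fn sample(&mut self, rng: &mut impl rand::Rng) -> f64 {
        let normal = Normal::new(0.0, 1.0).unwrap();
        // x_{t+1} = x_t + theta * (mu - x_t) + sigma * eps
        self.state += self.theta * (self.mu - self.state) + self.sigma * normal.sample(rng);
        self.state
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    let mut noise = OuNoise::new(0.0, 0.15, 0.1);
    for _ in 0..5 {
        println!("{:.4}", noise.sample(&mut rng));
    }
}
```

The mean reversion (`theta`) keeps the noise temporally correlated rather than white, which tends to explore continuous action spaces like Pendulum's torque more smoothly.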
@ -1,91 +0,0 @@
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;

#[derive(Debug)]
pub struct Step {
    pub obs: Tensor,
    pub reward: Tensor,
    pub is_done: Tensor,
}

pub struct VecGymEnv {
    env: PyObject,
    action_space: usize,
    observation_space: Vec<usize>,
}

fn w(res: PyErr) -> candle::Error {
    candle::Error::wrap(res)
}

impl VecGymEnv {
    pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
        Python::with_gil(|py| {
            let sys = py.import("sys")?;
            let path = sys.getattr("path")?;
            let _ = path.call_method1(
                "append",
                ("candle-examples/examples/reinforcement-learning",),
            )?;
            let gym = py.import("atari_wrappers")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name, img_dir, nprocesses))?;
            let action_space = env.getattr("action_space")?;
            let action_space = action_space.getattr("n")?.extract()?;
            let observation_space = env.getattr("observation_space")?;
            let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
            // Prepend the number of worker processes to the per-env shape.
            let observation_space =
                [vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
            Ok(VecGymEnv {
                env: env.into(),
                action_space,
                observation_space,
            })
        })
        .map_err(w)
    }

    pub fn reset(&self) -> Result<Tensor> {
        let obs = Python::with_gil(|py| {
            let obs = self.env.call_method0(py, "reset")?;
            let obs = obs.call_method0(py, "flatten")?;
            obs.extract::<Vec<f32>>(py)
        })
        .map_err(w)?;
        Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
    }

    pub fn step(&self, action: Vec<usize>) -> Result<Step> {
        let (obs, reward, is_done) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action,), None)?;
            let step = step.as_ref(py);
            let obs = step.get_item(0)?.call_method("flatten", (), None)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
            // The observation buffer holds u8 pixels; it is converted to f32
            // after the tensor has been built below.
            let obs: Vec<u8> = obs_buffer.to_vec(py)?;
            let reward: Vec<f32> = step.get_item(1)?.extract()?;
            let is_done: Vec<f32> = step.get_item(2)?.extract()?;
            Ok((obs, reward, is_done))
        })
        .map_err(w)?;
        let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
            .to_dtype(DType::F32)?;
        let reward = Tensor::new(reward, &Device::Cpu)?;
        let is_done = Tensor::new(is_done, &Device::Cpu)?;
        Ok(Step {
            obs,
            reward,
            is_done,
        })
    }

    pub fn action_space(&self) -> usize {
        self.action_space
    }

    pub fn observation_space(&self) -> &[usize] {
        &self.observation_space
    }
}
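A rollout sketch for `VecGymEnv`; the environment name and process count below are illustrative placeholders, and the sketch assumes the `atari_wrappers` helper that `new` imports is on the Python path:

```rust
use candle::Result;

// Illustrative only: steps all worker processes with a constant action and
// prints the per-process rewards.
fn rollout() -> Result<()> {
    let nprocesses = 8;
    let env = VecGymEnv::new("PongNoFrameskip-v4", None, nprocesses)?;
    let mut obs = env.reset()?;
    for _ in 0..100 {
        // One action per worker process (always action 0 here).
        let step = env.step(vec![0; nprocesses])?;
        println!("rewards: {:?}", step.reward.to_vec1::<f32>()?);
        obs = step.obs;
    }
    println!("final obs shape: {:?}", obs.shape());
    Ok(())
}
```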
@ -1,40 +0,0 @@
# candle-replit-code: code completion specialized model.

[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. This model uses 3.3B parameters
in `bfloat16` (so the GPU version will only work on recent nvidia cards).

## Running an example

```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```
This produces the following output.

```
def fibonacci(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    a, b = 0, 1
    while a < n:
        print(a, end=' ')
        a, b = b, a+b
    print()


def fibonacci_loop(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    result = []
    a, b = 0, 1
    while a < n:
        result.append(a)
        a, b = b, a+b
    return result


def fibonacci_generator(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""
    a, b = 0, 1
    while a < n:
        yield a
        a, b = b, a+b
```
@ -1,265 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mpt::{Config, Model as M};
use candle_transformers::models::quantized_mpt::Model as Q;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    M(M),
    Q(Q),
}

impl Model {
    fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::M(model) => model.forward(xs),
            Self::Q(model) => model.forward(xs),
        }
    }
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the replit model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        print!("{prompt}");
        std::io::stdout().flush()?;
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            // Feed the full prompt on the first step, then one token at a time.
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 1000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    quantized: bool,

    #[arg(long)]
    weight_file: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "lmz/candle-replit-code".to_string(),
    };
    let revision = match args.revision {
        Some(rev) => rev.to_string(),
        None => "main".to_string(),
    };
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
        None => {
            if args.quantized {
                repo.get("model-replit-code-v1_5-q4k.gguf")?
            } else {
                repo.get("model.safetensors")?
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::replit_code_v1_5_3b();
    let (model, device) = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
        (model, Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
        (model, device)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
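The repeat penalty applied in the sampling loop above follows the common CTRL-style scheme: the logit of each recently generated token is divided by the penalty when positive and multiplied by it when negative. A minimal sketch of that logic over raw logits (the actual `candle_transformers::utils::apply_repeat_penalty` operates on tensors and may differ in details):

```rust
/// CTRL-style repetition penalty sketch over a raw logits slice.
fn repeat_penalty(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
    for &token in recent_tokens {
        if let Some(logit) = logits.get_mut(token as usize) {
            // Shrink positive logits, push negative ones further down.
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}
```

With `penalty = 1.0` the transformation is the identity, which is why the loop skips it entirely in that case.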
@ -1,19 +0,0 @@
# candle-resnet

A candle implementation of inference using a pre-trained [ResNet](https://arxiv.org/abs/1512.03385).
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example resnet --release -- --image tiger.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 90.21%
tiger cat : 8.93%
lion, king of beasts, Panthera leo: 0.35%
leopard, Panthera pardus: 0.16%
jaguar, panther, Panthera onca, Felis onca: 0.09%
```
@ -1,12 +0,0 @@
# This script exports pre-trained model weights in the safetensors format.
import torch
import torchvision
from safetensors import torch as stt

m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')
m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')
m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')
@ -1,90 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::resnet;
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "18")]
    Resnet18,
    #[value(name = "34")]
    Resnet34,
    #[value(name = "50")]
    Resnet50,
    #[value(name = "101")]
    Resnet101,
    #[value(name = "152")]
    Resnet152,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::Resnet18)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-resnet".into());
            let filename = match args.which {
                Which::Resnet18 => "resnet18.safetensors",
                Which::Resnet34 => "resnet34.safetensors",
                Which::Resnet50 => "resnet50.safetensors",
                Which::Resnet101 => "resnet101.safetensors",
                Which::Resnet152 => "resnet152.safetensors",
            };
            api.get(filename)?
        }
        Some(model) => model.into(),
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
    let model = match args.which {
        Which::Resnet18 => resnet::resnet18(class_count, vb)?,
        Which::Resnet34 => resnet::resnet34(class_count, vb)?,
        Which::Resnet50 => resnet::resnet50(class_count, vb)?,
        Which::Resnet101 => resnet::resnet101(class_count, vb)?,
        Which::Resnet152 => resnet::resnet152(class_count, vb)?,
    };
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@ -16,29 +16,25 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
cargo run --example segment-anything --release -- \
  --image candle-examples/examples/yolo-v8/assets/bike.jpg
  --use-tiny
  --point 0.6,0.6 --point 0.6,0.55
  --point-x 0.4
  --point-y 0.3
```

Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dots represent the prompt
specified by `--point 0.6,0.6 --point 0.6,0.55`; this prompt is assumed to be part
image with a blue overlay of the selected mask. The red dot represents the prompt
specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
of the target mask.

The values used for `--point` should be a comma-delimited pair of float values.
They are proportional to the image dimension, i.e. use 0.5 for the image center.
The values used for `--point-x` and `--point-y` should be between 0 and 1 and
are proportional to the image dimension, i.e. use 0.5 for the image center.
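For illustration, proportional coordinates in this form can be parsed and scaled to pixel positions roughly as follows (a sketch, not the parsing code the example actually uses):

```rust
// Parse "0.6,0.55" into proportional coordinates and scale them to pixels,
// e.g. parse_point("0.6,0.55", 512, 512) == Some((307, 281)).
fn parse_point(s: &str, width: usize, height: usize) -> Option<(usize, usize)> {
    let (x, y) = s.split_once(',')?;
    let x: f64 = x.trim().parse().ok()?;
    let y: f64 = y.trim().parse().ok()?;
    // The values are relative to the image size, so 0.5 is the center.
    Some(((x * width as f64) as usize, (y * height as f64) as usize))
}
```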

Original image:


Segment results by prompting with a single point `--point 0.6,0.55`:


Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:


### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
  one.
- `--point`: specifies the location of the target points.
- `--point-x`, `--point-y`: specifies the location of the target point.
- `--threshold`: sets the threshold value to be part of the mask, a negative
  value results in a larger mask and can be specified via `--threshold=-1.2`.
Binary file not shown.
Before: 158 KiB
Some files were not shown because too many files have changed in this diff.