mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Compare commits
27 Commits
Author | SHA1 | Date | |
---|---|---|---|
7e49e0af96 | |||
181d2299b2 | |||
2801541e5f | |||
4289984d32 | |||
1471f98f0b | |||
dd4a40f1c0 | |||
79845bd93b | |||
6071797450 | |||
b58b247323 | |||
3900091e75 | |||
54355ff997 | |||
e02f1912bb | |||
a52b71686b | |||
7adfb70dff | |||
3ad02147e4 | |||
4f39695465 | |||
4cf4844c9d | |||
d840838e95 | |||
61a070fdd1 | |||
e35669647d | |||
53e8b7ee3e | |||
cc26cce23c | |||
02c2ec2c71 | |||
9a2784b8ab | |||
0f652f0e3d | |||
ddee9dc1dd | |||
fc9bb7784a |
7
.github/dependabot.yml
vendored
7
.github/dependabot.yml
vendored
@ -1,7 +0,0 @@
|
|||||||
version: 2
|
|
||||||
updates:
|
|
||||||
- package-ecosystem: "cargo"
|
|
||||||
directory: "/"
|
|
||||||
schedule:
|
|
||||||
interval: "weekly"
|
|
||||||
open-pull-requests-limit: 5
|
|
4
.github/workflows/ci_cuda.yaml
vendored
4
.github/workflows/ci_cuda.yaml
vendored
@ -8,8 +8,6 @@ jobs:
|
|||||||
start-runner:
|
start-runner:
|
||||||
name: Start self-hosted EC2 runner
|
name: Start self-hosted EC2 runner
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# Don't run on forks, they won't have access to secrets anyway.
|
|
||||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
|
||||||
env:
|
env:
|
||||||
AWS_REGION: us-east-1
|
AWS_REGION: us-east-1
|
||||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||||
@ -72,7 +70,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
AWS_REGION: us-east-1
|
AWS_REGION: us-east-1
|
||||||
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
|
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
|
||||||
steps:
|
steps:
|
||||||
- name: Configure AWS credentials
|
- name: Configure AWS credentials
|
||||||
uses: aws-actions/configure-aws-credentials@v1
|
uses: aws-actions/configure-aws-credentials@v1
|
||||||
|
@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
|
|||||||
[760](https://github.com/huggingface/candle/pull/760).
|
[760](https://github.com/huggingface/candle/pull/760).
|
||||||
- Add the Segment-Anything Model (SAM) as an example
|
- Add the Segment-Anything Model (SAM) as an example
|
||||||
[773](https://github.com/huggingface/candle/pull/773).
|
[773](https://github.com/huggingface/candle/pull/773).
|
||||||
- TinyViT backbone for the segment anything example
|
- TinyViT backbone for the segemnt anything example
|
||||||
[787](https://github.com/huggingface/candle/pull/787).
|
[787](https://github.com/huggingface/candle/pull/787).
|
||||||
- Shape with holes support
|
- Shape with holes support
|
||||||
[770](https://github.com/huggingface/candle/pull/770).
|
[770](https://github.com/huggingface/candle/pull/770).
|
||||||
|
24
Cargo.toml
24
Cargo.toml
@ -19,7 +19,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.3.3"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -31,18 +31,9 @@ license = "MIT OR Apache-2.0"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core" }
|
|
||||||
candle-datasets = { path = "./candle-datasets" }
|
|
||||||
candle-flash-attn = { path = "./candle-flash-attn" }
|
|
||||||
candle-kernels = { path = "./candle-kernels" }
|
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
|
||||||
candle-nn = { path = "./candle-nn" }
|
|
||||||
candle-onnx = { path = "./candle-onnx" }
|
|
||||||
candle-transformers = { path = "./candle-transformers" }
|
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
|
||||||
hf-hub = "0.3.0"
|
hf-hub = "0.3.0"
|
||||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||||
@ -50,7 +41,7 @@ imageproc = { version = "0.23.0", default-features = false }
|
|||||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||||
libc = { version = "0.2.147" }
|
libc = { version = "0.2.147" }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||||
num_cpus = "1.15.0"
|
num_cpus = "1.15.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
parquet = { version = "45.0.0" }
|
parquet = { version = "45.0.0" }
|
||||||
@ -63,14 +54,17 @@ serde = { version = "1.0.171", features = ["derive"] }
|
|||||||
serde_plain = "1.0.2"
|
serde_plain = "1.0.2"
|
||||||
serde_json = "1.0.99"
|
serde_json = "1.0.99"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tokenizers = { version = "0.15.0", default-features = false }
|
tokenizers = { version = "0.13.4", default-features = false }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
metal = { path = "../metal-rs", features = ["mps"] }
|
||||||
|
dispatch = "0.2.0"
|
||||||
|
rustc-hash = "1.1"
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
50
README.md
50
README.md
@ -54,25 +54,19 @@ These online demos run entirely in your browser:
|
|||||||
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
|
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
|
||||||
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
|
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
|
||||||
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
|
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
|
||||||
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
|
||||||
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
|
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
|
||||||
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
|
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
|
||||||
|
|
||||||
We also provide a some command line based examples using state of the art models:
|
We also provide a some command line based examples using state of the art models:
|
||||||
|
|
||||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
|
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
|
||||||
the SOLAR-10.7B variant.
|
|
||||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
|
||||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||||
pre-trained on 1T tokens of English and code datasets.
|
pre-trained on 1T tokens of English and code datasets.
|
||||||
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
|
||||||
implementation of the Mamba state space model.
|
|
||||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||||
better performance than all publicly available 13b models as of 2023-09-28.
|
performance larger than all publicly available 13b models as of 2023-09-28.
|
||||||
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
|
||||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
|
||||||
much faster inference.
|
|
||||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||||
@ -84,7 +78,7 @@ We also provide a some command line based examples using state of the art models
|
|||||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
|
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
|
||||||
|
|
||||||
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
|
||||||
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
|
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
|
||||||
|
|
||||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
|
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
|
||||||
|
|
||||||
@ -109,9 +103,6 @@ We also provide a some command line based examples using state of the art models
|
|||||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||||
using self-supervision (can be used for imagenet classification, depth
|
using self-supervision (can be used for imagenet classification, depth
|
||||||
evaluation, segmentation).
|
evaluation, segmentation).
|
||||||
- [VGG](./candle-examples/examples/vgg/),
|
|
||||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
generate captions for an image.
|
generate captions for an image.
|
||||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||||
@ -131,7 +122,7 @@ There are also some wasm examples for whisper and
|
|||||||
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
|
||||||
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
|
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
|
||||||
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
|
||||||
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
|
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
|
||||||
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
|
||||||
|
|
||||||
For LLaMA2, run the following command to retrieve the weight files and start a
|
For LLaMA2, run the following command to retrieve the weight files and start a
|
||||||
@ -148,20 +139,17 @@ And then head over to
|
|||||||
<!--- ANCHOR: useful_libraries --->
|
<!--- ANCHOR: useful_libraries --->
|
||||||
|
|
||||||
## Useful External Resources
|
## Useful External Resources
|
||||||
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
|
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
|
||||||
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
very detailed tutorial showing how to convert a PyTorch model to Candle.
|
||||||
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
|
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
|
||||||
ergonomic LoRA implementation for Candle. `candle-lora` has
|
|
||||||
out-of-the-box LoRA support for many models from Candle, which can be found
|
|
||||||
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
|
|
||||||
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
|
|
||||||
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
|
||||||
|
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
|
||||||
|
that conforms to the official `peft` implementation.
|
||||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||||
serving local LLMs including an OpenAI compatible API server.
|
serving local LLMs including an OpenAI compatible API server.
|
||||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||||
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
|
||||||
|
|
||||||
If you have an addition to this list, please submit a pull request.
|
If you have an addition to this list, please submit a pull request.
|
||||||
|
|
||||||
@ -180,23 +168,15 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- WASM support, run your models in a browser.
|
- WASM support, run your models in a browser.
|
||||||
- Included models.
|
- Included models.
|
||||||
- Language Models.
|
- Language Models.
|
||||||
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
|
- LLaMA v1 and v2.
|
||||||
- Falcon.
|
- Falcon.
|
||||||
- StarCoder.
|
- StarCoder.
|
||||||
- Phi 1, 1.5, and 2.
|
- Phi v1.5.
|
||||||
- Minimal Mamba
|
|
||||||
- Mistral 7b v0.1.
|
- Mistral 7b v0.1.
|
||||||
- Mixtral 8x7b v0.1.
|
|
||||||
- StableLM-3B-4E1T.
|
- StableLM-3B-4E1T.
|
||||||
- Replit-code-v1.5-3B.
|
- Replit-code-v1.5-3B.
|
||||||
- Bert.
|
- Bert.
|
||||||
- Yi-6B and Yi-34B.
|
- Yi-6B and Yi-34B.
|
||||||
- Quantized LLMs.
|
|
||||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
|
||||||
- Mistral 7b, and 7b instruct.
|
|
||||||
- Mixtral 8x7b.
|
|
||||||
- Zephyr 7b a and b (Mistral-7b based).
|
|
||||||
- OpenChat 3.5 (Mistral-7b based).
|
|
||||||
- Text to text.
|
- Text to text.
|
||||||
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||||
- Marian MT (Machine Translation).
|
- Marian MT (Machine Translation).
|
||||||
@ -207,7 +187,7 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Image to text.
|
- Image to text.
|
||||||
- BLIP.
|
- BLIP.
|
||||||
- Computer Vision Models.
|
- Computer Vision Models.
|
||||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||||
- yolo-v3, yolo-v8.
|
- yolo-v3, yolo-v8.
|
||||||
- Segment-Anything Model (SAM).
|
- Segment-Anything Model (SAM).
|
||||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { workspace = true }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-datasets = { workspace = true }
|
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||||
candle-nn = { workspace = true }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { workspace = true }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||||
candle-flash-attn = { workspace = true, optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
@ -28,7 +28,6 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
|
|||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
#[test]
|
#[test]
|
||||||
fn book_hub_2() {
|
fn book_hub_2() {
|
||||||
{
|
|
||||||
// ANCHOR: book_hub_2
|
// ANCHOR: book_hub_2
|
||||||
use candle::Device;
|
use candle::Device;
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
@ -46,10 +45,9 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
|
|||||||
assert_eq!(weights.len(), 206);
|
assert_eq!(weights.len(), 206);
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
// #[test]
|
#[test]
|
||||||
// fn book_hub_3() {
|
fn book_hub_3() {
|
||||||
{
|
|
||||||
// ANCHOR: book_hub_3
|
// ANCHOR: book_hub_3
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
@ -104,7 +102,6 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
|||||||
assert_eq!(view.shape(), &[768, 768]);
|
assert_eq!(view.shape(), &[768, 768]);
|
||||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -12,8 +12,8 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { workspace = true, optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||||
candle-metal-kernels = { workspace = true, optional = true }
|
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||||
metal = { workspace = true, optional = true}
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
@ -30,12 +30,12 @@ safetensors = { workspace = true }
|
|||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
yoke = { workspace = true }
|
yoke = { workspace = true }
|
||||||
zip = { workspace = true }
|
zip = { workspace = true }
|
||||||
|
dispatch = { workspace = true, optional = true }
|
||||||
|
rustc-hash = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
criterion = { workspace = true }
|
|
||||||
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
@ -43,8 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
|
||||||
|
|
||||||
[[bench]]
|
|
||||||
name = "bench_main"
|
|
||||||
harness = false
|
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
mod benchmarks;
|
|
||||||
|
|
||||||
use criterion::criterion_main;
|
|
||||||
criterion_main!(
|
|
||||||
benchmarks::matmul::benches,
|
|
||||||
benchmarks::affine::benches,
|
|
||||||
benchmarks::where_cond::benches
|
|
||||||
);
|
|
@ -1,43 +0,0 @@
|
|||||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
|
||||||
use candle_core::{DType, Device, Tensor};
|
|
||||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
fn run(a: &Tensor) {
|
|
||||||
a.affine(12.34, 56.78).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
|
||||||
let b = 1;
|
|
||||||
let m = 1024;
|
|
||||||
let k = 1024;
|
|
||||||
|
|
||||||
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
|
||||||
|
|
||||||
let flops = b * m * k * dtype.size_in_bytes();
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name(name));
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |b| {
|
|
||||||
b.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
run(black_box(&tensor));
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
let handler = BenchDeviceHandler::new().unwrap();
|
|
||||||
for device in handler.devices {
|
|
||||||
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
|
|
||||||
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
|
|
||||||
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -1,44 +0,0 @@
|
|||||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
|
||||||
use candle_core::{DType, Device, Tensor};
|
|
||||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
fn run(a: &Tensor, b: &Tensor) {
|
|
||||||
a.matmul(&b.t().unwrap()).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
|
||||||
let b = 1;
|
|
||||||
let m = 1;
|
|
||||||
let n = 2048;
|
|
||||||
let k = 2048;
|
|
||||||
|
|
||||||
let dtype = DType::F32;
|
|
||||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
|
||||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
|
||||||
|
|
||||||
let flops = b * m * n * k;
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |b| {
|
|
||||||
b.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
run(black_box(&lhs), black_box(&rhs));
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
let handler = BenchDeviceHandler::new().unwrap();
|
|
||||||
for device in handler.devices {
|
|
||||||
run_bench(c, &device);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -1,65 +0,0 @@
|
|||||||
pub(crate) mod affine;
|
|
||||||
pub(crate) mod matmul;
|
|
||||||
pub(crate) mod where_cond;
|
|
||||||
|
|
||||||
use candle_core::{Device, Result};
|
|
||||||
|
|
||||||
pub(crate) trait BenchDevice {
|
|
||||||
fn sync(&self) -> Result<()>;
|
|
||||||
|
|
||||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BenchDevice for Device {
|
|
||||||
fn sync(&self) -> Result<()> {
|
|
||||||
match self {
|
|
||||||
Device::Cpu => Ok(()),
|
|
||||||
Device::Cuda(device) => {
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
return Ok(device.synchronize()?);
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
|
||||||
}
|
|
||||||
Device::Metal(device) => {
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
return Ok(device.wait_until_completed()?);
|
|
||||||
#[cfg(not(feature = "metal"))]
|
|
||||||
panic!("Metal device without metal feature enabled: {:?}", device)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
|
||||||
match self {
|
|
||||||
Device::Cpu => {
|
|
||||||
let cpu_type = if cfg!(feature = "accelerate") {
|
|
||||||
"accelerate"
|
|
||||||
} else if cfg!(feature = "mkl") {
|
|
||||||
"mkl"
|
|
||||||
} else {
|
|
||||||
"cpu"
|
|
||||||
};
|
|
||||||
format!("{}_{}", cpu_type, name.into())
|
|
||||||
}
|
|
||||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
|
||||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct BenchDeviceHandler {
|
|
||||||
devices: Vec<Device>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BenchDeviceHandler {
|
|
||||||
pub fn new() -> Result<Self> {
|
|
||||||
let mut devices = Vec::new();
|
|
||||||
if cfg!(feature = "metal") {
|
|
||||||
devices.push(Device::new_metal(0)?);
|
|
||||||
} else if cfg!(feature = "cuda") {
|
|
||||||
devices.push(Device::new_cuda(0)?);
|
|
||||||
}
|
|
||||||
devices.push(Device::Cpu);
|
|
||||||
Ok(Self { devices })
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,64 +0,0 @@
|
|||||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
|
||||||
use candle_core::{DType, Device, Tensor};
|
|
||||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
|
||||||
a.where_cond(b, c).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
|
||||||
let mut arr = [0u8; N];
|
|
||||||
let mut i = 0;
|
|
||||||
while i < N {
|
|
||||||
arr[i] = (i % 2) as u8;
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
arr
|
|
||||||
}
|
|
||||||
|
|
||||||
const B: usize = 1;
|
|
||||||
const M: usize = 1024;
|
|
||||||
const K: usize = 1024;
|
|
||||||
const SIZE: usize = B * M * K;
|
|
||||||
|
|
||||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
|
||||||
|
|
||||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
|
||||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
|
||||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
|
||||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
|
||||||
|
|
||||||
let elements = B * M * K;
|
|
||||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
|
||||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name(name));
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |b| {
|
|
||||||
b.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
run(
|
|
||||||
black_box(&tensor),
|
|
||||||
black_box(&on_true),
|
|
||||||
black_box(&on_false),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
let device = BenchDeviceHandler::new().unwrap();
|
|
||||||
for d in device.devices {
|
|
||||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
|
||||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
|
||||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -8,10 +8,11 @@ use anyhow::Result;
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||||
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||||
let new_a = a.slice_scatter(&b, 1, 2)?;
|
let start = std::time::Instant::now();
|
||||||
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
|
||||||
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
println!("{:?}", start.elapsed());
|
||||||
|
println!("{res:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
use candle_core::quantized::{gguf_file, k_quants, QTensor};
|
||||||
use candle_core::{Device, Result};
|
use candle_core::{Device, Result, Tensor};
|
||||||
use clap::{Parser, Subcommand, ValueEnum};
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
@ -11,7 +11,12 @@ enum QuantizationMode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QuantizationMode {
|
impl QuantizationMode {
|
||||||
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
|
fn quantize(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
tensor: QTensor,
|
||||||
|
default: fn(&Tensor) -> Result<QTensor>,
|
||||||
|
) -> Result<QTensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::Llama => {
|
Self::Llama => {
|
||||||
// Same behavior as the llama.cpp quantization.
|
// Same behavior as the llama.cpp quantization.
|
||||||
@ -19,9 +24,9 @@ impl QuantizationMode {
|
|||||||
if should_quantize {
|
if should_quantize {
|
||||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||||
if name == "output.weight" {
|
if name == "output.weight" {
|
||||||
QTensor::quantize(&tensor, GgmlDType::Q6K)
|
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
|
||||||
} else {
|
} else {
|
||||||
QTensor::quantize(&tensor, dtype)
|
default(&tensor)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
@ -55,27 +60,6 @@ enum Quantization {
|
|||||||
F32,
|
F32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Quantization {
|
|
||||||
fn dtype(&self) -> GgmlDType {
|
|
||||||
match self {
|
|
||||||
Quantization::Q4_0 => GgmlDType::Q4_0,
|
|
||||||
Quantization::Q4_1 => GgmlDType::Q4_1,
|
|
||||||
Quantization::Q5_0 => GgmlDType::Q5_0,
|
|
||||||
Quantization::Q5_1 => GgmlDType::Q5_1,
|
|
||||||
Quantization::Q8_0 => GgmlDType::Q8_0,
|
|
||||||
Quantization::Q8_1 => GgmlDType::Q8_1,
|
|
||||||
Quantization::Q2k => GgmlDType::Q2K,
|
|
||||||
Quantization::Q3k => GgmlDType::Q3K,
|
|
||||||
Quantization::Q4k => GgmlDType::Q4K,
|
|
||||||
Quantization::Q5k => GgmlDType::Q5K,
|
|
||||||
Quantization::Q6k => GgmlDType::Q6K,
|
|
||||||
Quantization::Q8k => GgmlDType::Q8K,
|
|
||||||
Quantization::F16 => GgmlDType::F16,
|
|
||||||
Quantization::F32 => GgmlDType::F32,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(ValueEnum, Debug, Clone)]
|
#[derive(ValueEnum, Debug, Clone)]
|
||||||
enum Format {
|
enum Format {
|
||||||
Safetensors,
|
Safetensors,
|
||||||
@ -118,7 +102,7 @@ enum Command {
|
|||||||
},
|
},
|
||||||
|
|
||||||
Quantize {
|
Quantize {
|
||||||
/// The input file(s), in safetensors format.
|
/// The input file, in gguf format.
|
||||||
in_file: Vec<std::path::PathBuf>,
|
in_file: Vec<std::path::PathBuf>,
|
||||||
|
|
||||||
/// The output file, in gguf format.
|
/// The output file, in gguf format.
|
||||||
@ -133,15 +117,6 @@ enum Command {
|
|||||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||||
mode: QuantizationMode,
|
mode: QuantizationMode,
|
||||||
},
|
},
|
||||||
|
|
||||||
Dequantize {
|
|
||||||
/// The input file, in gguf format.
|
|
||||||
in_file: std::path::PathBuf,
|
|
||||||
|
|
||||||
/// The output file, in safetensors format.
|
|
||||||
#[arg(long)]
|
|
||||||
out_file: std::path::PathBuf,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
@ -150,12 +125,7 @@ struct Args {
|
|||||||
command: Command,
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_ls(
|
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
||||||
file: &std::path::PathBuf,
|
|
||||||
format: Option<Format>,
|
|
||||||
verbose: bool,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<()> {
|
|
||||||
let format = match format {
|
let format = match format {
|
||||||
Some(format) => format,
|
Some(format) => format,
|
||||||
None => match Format::infer(file) {
|
None => match Format::infer(file) {
|
||||||
@ -221,7 +191,7 @@ fn run_ls(
|
|||||||
}
|
}
|
||||||
Format::Ggml => {
|
Format::Ggml => {
|
||||||
let mut file = std::fs::File::open(file)?;
|
let mut file = std::fs::File::open(file)?;
|
||||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
||||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, qtensor) in tensors.iter() {
|
for (name, qtensor) in tensors.iter() {
|
||||||
@ -262,8 +232,37 @@ fn run_quantize_safetensors(
|
|||||||
}
|
}
|
||||||
println!("tensors: {}", tensors.len());
|
println!("tensors: {}", tensors.len());
|
||||||
|
|
||||||
let dtype = q.dtype();
|
let quantize_fn = match q {
|
||||||
let block_size = dtype.block_size();
|
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||||
|
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||||
|
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||||
|
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||||
|
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||||
|
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||||
|
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||||
|
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||||
|
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||||
|
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||||
|
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||||
|
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||||
|
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||||
|
Quantization::F32 => QTensor::quantize::<f32>,
|
||||||
|
};
|
||||||
|
let block_size = match q {
|
||||||
|
Quantization::Q4_0 => k_quants::QK4_0,
|
||||||
|
Quantization::Q4_1 => k_quants::QK4_1,
|
||||||
|
Quantization::Q5_0 => k_quants::QK5_0,
|
||||||
|
Quantization::Q5_1 => k_quants::QK5_1,
|
||||||
|
Quantization::Q8_0 => k_quants::QK8_0,
|
||||||
|
Quantization::Q8_1 => k_quants::QK8_1,
|
||||||
|
Quantization::Q2k
|
||||||
|
| Quantization::Q3k
|
||||||
|
| Quantization::Q4k
|
||||||
|
| Quantization::Q5k
|
||||||
|
| Quantization::Q6k
|
||||||
|
| Quantization::Q8k => k_quants::QK_K,
|
||||||
|
Quantization::F16 | Quantization::F32 => 1,
|
||||||
|
};
|
||||||
|
|
||||||
let qtensors = tensors
|
let qtensors = tensors
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
@ -271,9 +270,9 @@ fn run_quantize_safetensors(
|
|||||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||||
let tensor = if should_quantize {
|
let tensor = if should_quantize {
|
||||||
QTensor::quantize(&tensor, dtype)?
|
quantize_fn(&tensor)?
|
||||||
} else {
|
} else {
|
||||||
QTensor::quantize(&tensor, GgmlDType::F32)?
|
QTensor::quantize::<f32>(&tensor)?
|
||||||
};
|
};
|
||||||
Ok((name, tensor))
|
Ok((name, tensor))
|
||||||
})
|
})
|
||||||
@ -286,29 +285,11 @@ fn run_quantize_safetensors(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_dequantize(
|
|
||||||
in_file: std::path::PathBuf,
|
|
||||||
out_file: std::path::PathBuf,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut in_file = std::fs::File::open(in_file)?;
|
|
||||||
let content = gguf_file::Content::read(&mut in_file)?;
|
|
||||||
let mut tensors = std::collections::HashMap::new();
|
|
||||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
|
||||||
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
|
||||||
let tensor = tensor.dequantize(device)?;
|
|
||||||
tensors.insert(tensor_name.to_string(), tensor);
|
|
||||||
}
|
|
||||||
candle_core::safetensors::save(&tensors, out_file)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_quantize(
|
fn run_quantize(
|
||||||
in_files: &[std::path::PathBuf],
|
in_files: &[std::path::PathBuf],
|
||||||
out_file: std::path::PathBuf,
|
out_file: std::path::PathBuf,
|
||||||
q: Quantization,
|
q: Quantization,
|
||||||
qmode: QuantizationMode,
|
qmode: QuantizationMode,
|
||||||
device: &Device,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if in_files.is_empty() {
|
if in_files.is_empty() {
|
||||||
candle_core::bail!("no specified input files")
|
candle_core::bail!("no specified input files")
|
||||||
@ -334,15 +315,31 @@ fn run_quantize(
|
|||||||
let content = gguf_file::Content::read(&mut in_)?;
|
let content = gguf_file::Content::read(&mut in_)?;
|
||||||
println!("tensors: {}", content.tensor_infos.len());
|
println!("tensors: {}", content.tensor_infos.len());
|
||||||
|
|
||||||
let dtype = q.dtype();
|
let quantize_fn = match q {
|
||||||
|
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||||
|
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||||
|
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||||
|
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||||
|
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||||
|
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||||
|
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||||
|
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||||
|
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||||
|
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||||
|
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||||
|
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||||
|
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||||
|
Quantization::F32 => QTensor::quantize::<f32>,
|
||||||
|
};
|
||||||
|
|
||||||
let qtensors = content
|
let qtensors = content
|
||||||
.tensor_infos
|
.tensor_infos
|
||||||
.par_iter()
|
.par_iter()
|
||||||
.map(|(name, _)| {
|
.map(|(name, _)| {
|
||||||
println!(" quantizing {name}");
|
println!(" quantizing {name}");
|
||||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||||
let tensor = content.tensor(&mut in_file, name, device)?;
|
let tensor = content.tensor(&mut in_file, name)?;
|
||||||
let tensor = qmode.quantize(name, tensor, dtype)?;
|
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
||||||
Ok((name, tensor))
|
Ok((name, tensor))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
@ -362,7 +359,6 @@ fn run_quantize(
|
|||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let device = Device::Cpu;
|
|
||||||
match args.command {
|
match args.command {
|
||||||
Command::Ls {
|
Command::Ls {
|
||||||
files,
|
files,
|
||||||
@ -374,7 +370,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
if multiple_files {
|
if multiple_files {
|
||||||
println!("--- {file:?} ---");
|
println!("--- {file:?} ---");
|
||||||
}
|
}
|
||||||
run_ls(file, format.clone(), verbose, &device)?
|
run_ls(file, format.clone(), verbose)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Command::Quantize {
|
Command::Quantize {
|
||||||
@ -382,8 +378,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
out_file,
|
out_file,
|
||||||
quantization,
|
quantization,
|
||||||
mode,
|
mode,
|
||||||
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
||||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,7 @@ impl Tensor {
|
|||||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||||
Op::Reshape(node)
|
Op::Reshape(node)
|
||||||
| Op::UpsampleNearest1D(node)
|
| Op::UpsampleNearest1D(node)
|
||||||
| Op::UpsampleNearest2D { arg: node, .. }
|
| Op::UpsampleNearest2D(node)
|
||||||
| Op::AvgPool2D { arg: node, .. }
|
| Op::AvgPool2D { arg: node, .. }
|
||||||
| Op::MaxPool2D { arg: node, .. }
|
| Op::MaxPool2D { arg: node, .. }
|
||||||
| Op::Copy(node)
|
| Op::Copy(node)
|
||||||
@ -350,27 +350,9 @@ impl Tensor {
|
|||||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
||||||
op: "upsample-nearest1d",
|
op: "upsample-nearest1d",
|
||||||
})?,
|
})?,
|
||||||
Op::UpsampleNearest2D {
|
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||||
arg,
|
op: "upsample-nearest2d",
|
||||||
target_h,
|
})?,
|
||||||
target_w,
|
|
||||||
} => {
|
|
||||||
let (_n, c, h, w) = arg.dims4()?;
|
|
||||||
if target_h % h != 0 || target_w % w != 0 {
|
|
||||||
crate::bail!("backward not supported for non integer upscaling factors")
|
|
||||||
}
|
|
||||||
let scale_h = target_h / h;
|
|
||||||
let scale_w = target_w / w;
|
|
||||||
|
|
||||||
if scale_h != scale_w {
|
|
||||||
crate::bail!("backward not supported for non uniform upscaling factors")
|
|
||||||
};
|
|
||||||
let kernel =
|
|
||||||
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
|
|
||||||
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
|
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
|
||||||
*sum_grad = conv_sum;
|
|
||||||
}
|
|
||||||
Op::SliceScatter0(lhs, rhs, start_rhs) => {
|
Op::SliceScatter0(lhs, rhs, start_rhs) => {
|
||||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
|
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
|
||||||
|
@ -201,9 +201,10 @@ impl Device {
|
|||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Device::Metal(device) => {
|
Device::Metal(_device) => {
|
||||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||||
Ok(Storage::Metal(storage))
|
// Ok(Storage::Metal(storage))
|
||||||
|
crate::bail!("Metal rand_uniform not implemented")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ impl Tensor {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// Generic structure used to index a slice of the tensor
|
/// Generic structure used to index a slice of the tensor
|
||||||
pub enum TensorIndexer {
|
pub enum TensorIndexer {
|
||||||
/// This selects the elements for which an index has some specific value.
|
/// This selects the elemnts for which an index has some specific value.
|
||||||
Select(usize),
|
Select(usize),
|
||||||
/// This is a regular slice, purely indexing a chunk of the tensor
|
/// This is a regular slice, purely indexing a chunk of the tensor
|
||||||
Narrow(Bound<usize>, Bound<usize>),
|
Narrow(Bound<usize>, Bound<usize>),
|
||||||
@ -104,31 +104,37 @@ impl From<&Tensor> for TensorIndexer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait RB: RangeBounds<usize> {}
|
macro_rules! impl_from_range {
|
||||||
impl RB for Range<usize> {}
|
($range_type:ty) => {
|
||||||
impl RB for RangeFrom<usize> {}
|
impl From<$range_type> for TensorIndexer {
|
||||||
impl RB for RangeFull {}
|
fn from(range: $range_type) -> Self {
|
||||||
impl RB for RangeInclusive<usize> {}
|
|
||||||
impl RB for RangeTo<usize> {}
|
|
||||||
impl RB for RangeToInclusive<usize> {}
|
|
||||||
|
|
||||||
impl<T: RB> From<T> for TensorIndexer {
|
|
||||||
fn from(range: T) -> Self {
|
|
||||||
use std::ops::Bound::*;
|
use std::ops::Bound::*;
|
||||||
|
|
||||||
let start = match range.start_bound() {
|
let start = match range.start_bound() {
|
||||||
Included(idx) => Included(*idx),
|
Included(idx) => Included(*idx),
|
||||||
Excluded(idx) => Excluded(*idx),
|
Excluded(idx) => Excluded(*idx),
|
||||||
Unbounded => Unbounded,
|
Unbounded => Unbounded,
|
||||||
};
|
};
|
||||||
|
|
||||||
let end = match range.end_bound() {
|
let end = match range.end_bound() {
|
||||||
Included(idx) => Included(*idx),
|
Included(idx) => Included(*idx),
|
||||||
Excluded(idx) => Excluded(*idx),
|
Excluded(idx) => Excluded(*idx),
|
||||||
Unbounded => Unbounded,
|
Unbounded => Unbounded,
|
||||||
};
|
};
|
||||||
|
|
||||||
TensorIndexer::Narrow(start, end)
|
TensorIndexer::Narrow(start, end)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl_from_range!(Range<usize>);
|
||||||
|
impl_from_range!(RangeFrom<usize>);
|
||||||
|
impl_from_range!(RangeFull);
|
||||||
|
impl_from_range!(RangeInclusive<usize>);
|
||||||
|
impl_from_range!(RangeTo<usize>);
|
||||||
|
impl_from_range!(RangeToInclusive<usize>);
|
||||||
|
|
||||||
/// Trait used to implement multiple signatures for ease of use of the slicing
|
/// Trait used to implement multiple signatures for ease of use of the slicing
|
||||||
/// of a tensor
|
/// of a tensor
|
||||||
pub trait IndexOp<T> {
|
pub trait IndexOp<T> {
|
||||||
|
@ -72,7 +72,7 @@ pub mod utils;
|
|||||||
mod variable;
|
mod variable;
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::CpuStorage;
|
||||||
pub use device::{Device, DeviceLocation, NdArray};
|
pub use device::{Device, DeviceLocation};
|
||||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use indexer::IndexOp;
|
pub use indexer::IndexOp;
|
||||||
@ -123,6 +123,12 @@ pub trait Module {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Module for quantized::QMatMul {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
self.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
self(xs)
|
self(xs)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -132,11 +132,7 @@ pub enum Op {
|
|||||||
},
|
},
|
||||||
|
|
||||||
UpsampleNearest1D(Tensor),
|
UpsampleNearest1D(Tensor),
|
||||||
UpsampleNearest2D {
|
UpsampleNearest2D(Tensor),
|
||||||
arg: Tensor,
|
|
||||||
target_h: usize,
|
|
||||||
target_w: usize,
|
|
||||||
},
|
|
||||||
|
|
||||||
Cat(Vec<Tensor>, usize),
|
Cat(Vec<Tensor>, usize),
|
||||||
|
|
||||||
|
@ -703,7 +703,6 @@ impl PthTensors {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||||
use std::io::Read;
|
|
||||||
let tensor_info = match self.tensor_infos.get(name) {
|
let tensor_info = match self.tensor_infos.get(name) {
|
||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
Some(tensor_info) => tensor_info,
|
Some(tensor_info) => tensor_info,
|
||||||
@ -713,21 +712,14 @@ impl PthTensors {
|
|||||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||||
|
|
||||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||||
// case.
|
// For now only support the basic case.
|
||||||
if !tensor_info.layout.is_contiguous() {
|
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"cannot retrieve non-contiguous tensors {:?}",
|
"cannot retrieve non-contiguous tensors {:?}",
|
||||||
tensor_info.layout
|
tensor_info.layout
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
let start_offset = tensor_info.layout.start_offset();
|
|
||||||
if start_offset > 0 {
|
|
||||||
std::io::copy(
|
|
||||||
&mut reader.by_ref().take(start_offset as u64),
|
|
||||||
&mut std::io::sink(),
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
let tensor = Tensor::from_reader(
|
let tensor = Tensor::from_reader(
|
||||||
tensor_info.layout.shape().clone(),
|
tensor_info.layout.shape().clone(),
|
||||||
tensor_info.dtype,
|
tensor_info.dtype,
|
||||||
|
@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
|||||||
q3 = q3.add(32);
|
q3 = q3.add(32);
|
||||||
|
|
||||||
// Prepare low and high bits
|
// Prepare low and high bits
|
||||||
// We hardcode the shifts here to avoid loading them into a separate register
|
// We hardcode the shifts here to avoid loading them into a seperate register
|
||||||
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
||||||
let q3h_0 = if j == 0 {
|
let q3h_0 = if j == 0 {
|
||||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
||||||
@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
|||||||
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
||||||
q5 = q5.add(32);
|
q5 = q5.add(32);
|
||||||
|
|
||||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
|
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
|
||||||
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
||||||
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
||||||
let q5l_0_right_shift = match j {
|
let q5l_0_right_shift = match j {
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
//! Support for the GGML file format.
|
//! Support for the GGML file format.
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
use super::{k_quants, GgmlDType};
|
||||||
use super::metal::load_quantized_metal;
|
use crate::Result;
|
||||||
use super::{k_quants, GgmlDType, QStorage};
|
|
||||||
use crate::{Device, Result};
|
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@ -123,22 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
raw_data: &[u8],
|
raw_data: &[u8],
|
||||||
size_in_bytes: usize,
|
size_in_bytes: usize,
|
||||||
dims: Vec<usize>,
|
dims: Vec<usize>,
|
||||||
device: &Device,
|
|
||||||
) -> Result<super::QTensor> {
|
) -> Result<super::QTensor> {
|
||||||
let raw_data_ptr = raw_data.as_ptr();
|
let raw_data_ptr = raw_data.as_ptr();
|
||||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||||
let data: QStorage = match device {
|
super::QTensor::new(data.to_vec(), dims)
|
||||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
|
||||||
#[cfg(not(feature = "metal"))]
|
|
||||||
Device::Metal(_metal) => {
|
|
||||||
crate::bail!("Metal backend requires `metal` feature")
|
|
||||||
}
|
|
||||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
|
||||||
};
|
|
||||||
super::QTensor::new(data, dims)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a [Tensor] from a raw GGML tensor.
|
/// Creates a [Tensor] from a raw GGML tensor.
|
||||||
@ -146,50 +133,29 @@ pub fn qtensor_from_ggml(
|
|||||||
ggml_dtype: GgmlDType,
|
ggml_dtype: GgmlDType,
|
||||||
raw_data: &[u8],
|
raw_data: &[u8],
|
||||||
dims: Vec<usize>,
|
dims: Vec<usize>,
|
||||||
device: &Device,
|
|
||||||
) -> Result<super::QTensor> {
|
) -> Result<super::QTensor> {
|
||||||
let tensor_elems = dims.iter().product::<usize>();
|
let tensor_elems = dims.iter().product::<usize>();
|
||||||
let block_size = ggml_dtype.block_size();
|
let blck_size = ggml_dtype.blck_size();
|
||||||
if tensor_elems % block_size != 0 {
|
if tensor_elems % blck_size != 0 {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
|
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||||
|
|
||||||
match ggml_dtype {
|
match ggml_dtype {
|
||||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||||
GgmlDType::Q4_0 => {
|
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||||
}
|
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||||
GgmlDType::Q4_1 => {
|
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||||
}
|
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||||
GgmlDType::Q5_0 => {
|
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||||
}
|
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||||
GgmlDType::Q5_1 => {
|
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
GgmlDType::Q8_0 => {
|
|
||||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
GgmlDType::Q2K => {
|
|
||||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
GgmlDType::Q3K => {
|
|
||||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
GgmlDType::Q4K => {
|
|
||||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
GgmlDType::Q5K => {
|
|
||||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
GgmlDType::Q6K => {
|
|
||||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
|
||||||
}
|
|
||||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -197,7 +163,6 @@ pub fn qtensor_from_ggml(
|
|||||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||||
reader: &mut R,
|
reader: &mut R,
|
||||||
magic: VersionedMagic,
|
magic: VersionedMagic,
|
||||||
device: &Device,
|
|
||||||
) -> Result<(String, super::QTensor)> {
|
) -> Result<(String, super::QTensor)> {
|
||||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||||
@ -218,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
|||||||
}
|
}
|
||||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||||
let tensor_elems = dims.iter().product::<usize>();
|
let tensor_elems = dims.iter().product::<usize>();
|
||||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
|
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||||
// TODO: Mmap version to avoid copying the data around?
|
// TODO: Mmap version to avoid copying the data around?
|
||||||
let mut raw_data = vec![0u8; size_in_bytes];
|
let mut raw_data = vec![0u8; size_in_bytes];
|
||||||
reader.read_exact(&mut raw_data)?;
|
reader.read_exact(&mut raw_data)?;
|
||||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||||
Ok(tensor) => Ok((name, tensor)),
|
Ok(tensor) => Ok((name, tensor)),
|
||||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||||
}
|
}
|
||||||
@ -236,10 +201,7 @@ pub struct Content {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Content {
|
impl Content {
|
||||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||||
reader: &mut R,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<Content> {
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||||
@ -249,7 +211,7 @@ impl Content {
|
|||||||
let mut tensors = HashMap::new();
|
let mut tensors = HashMap::new();
|
||||||
|
|
||||||
while reader.stream_position()? != last_position {
|
while reader.stream_position()? != last_position {
|
||||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||||
tensors.insert(name, tensor);
|
tensors.insert(name, tensor);
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||||
|
|
||||||
use super::{GgmlDType, QTensor};
|
use super::{GgmlDType, QTensor};
|
||||||
use crate::{Device, Result};
|
use crate::Result;
|
||||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ impl VersionedMagic {
|
|||||||
(Magic::Gguf, 1) => Self::GgufV1,
|
(Magic::Gguf, 1) => Self::GgufV1,
|
||||||
(Magic::Gguf, 2) => Self::GgufV2,
|
(Magic::Gguf, 2) => Self::GgufV2,
|
||||||
(Magic::Gguf, 3) => Self::GgufV3,
|
(Magic::Gguf, 3) => Self::GgufV3,
|
||||||
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
|
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
|
||||||
};
|
};
|
||||||
Ok(versioned_magic)
|
Ok(versioned_magic)
|
||||||
}
|
}
|
||||||
@ -59,25 +59,19 @@ impl TensorInfo {
|
|||||||
&self,
|
&self,
|
||||||
reader: &mut R,
|
reader: &mut R,
|
||||||
tensor_data_offset: u64,
|
tensor_data_offset: u64,
|
||||||
device: &Device,
|
|
||||||
) -> Result<QTensor> {
|
) -> Result<QTensor> {
|
||||||
let tensor_elems = self.shape.elem_count();
|
let tensor_elems = self.shape.elem_count();
|
||||||
let block_size = self.ggml_dtype.block_size();
|
let blck_size = self.ggml_dtype.blck_size();
|
||||||
if tensor_elems % block_size != 0 {
|
if tensor_elems % blck_size != 0 {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
|
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
|
||||||
let mut raw_data = vec![0u8; size_in_bytes];
|
let mut raw_data = vec![0u8; size_in_bytes];
|
||||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||||
reader.read_exact(&mut raw_data)?;
|
reader.read_exact(&mut raw_data)?;
|
||||||
super::ggml_file::qtensor_from_ggml(
|
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||||
self.ggml_dtype,
|
|
||||||
&raw_data,
|
|
||||||
self.shape.dims().to_vec(),
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -466,13 +460,12 @@ impl Content {
|
|||||||
&self,
|
&self,
|
||||||
reader: &mut R,
|
reader: &mut R,
|
||||||
name: &str,
|
name: &str,
|
||||||
device: &Device,
|
|
||||||
) -> Result<QTensor> {
|
) -> Result<QTensor> {
|
||||||
let tensor_info = match self.tensor_infos.get(name) {
|
let tensor_info = match self.tensor_infos.get(name) {
|
||||||
Some(tensor_info) => tensor_info,
|
Some(tensor_info) => tensor_info,
|
||||||
None => crate::bail!("cannot find tensor info for {name}"),
|
None => crate::bail!("cannot find tensor-infor for {name}"),
|
||||||
};
|
};
|
||||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
tensor_info.read(reader, self.tensor_data_offset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -524,9 +517,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
|
|||||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
let data = tensor.data()?;
|
let data_ptr = tensor.as_ptr();
|
||||||
let size_in_bytes = data.len();
|
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||||
w.write_all(&data)?;
|
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||||
|
w.write_all(data)?;
|
||||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||||
w.write_all(&vec![0u8; padding])?;
|
w.write_all(&vec![0u8; padding])?;
|
||||||
}
|
}
|
||||||
|
@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
|
|||||||
let d2 = d * sc as f32;
|
let d2 = d * sc as f32;
|
||||||
let m2 = min * m as f32;
|
let m2 = min * m as f32;
|
||||||
for (ql, qh) in ql.iter().zip(qh) {
|
for (ql, qh) in ql.iter().zip(qh) {
|
||||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||||
ys_index += 1;
|
ys_index += 1;
|
||||||
}
|
}
|
||||||
for (ql, qh) in ql.iter().zip(qh) {
|
for (ql, qh) in ql.iter().zip(qh) {
|
||||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||||
ys_index += 1;
|
ys_index += 1;
|
||||||
}
|
}
|
||||||
is += 2;
|
is += 2;
|
||||||
|
@ -1,153 +0,0 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
|
||||||
use crate::{DType, MetalDevice, MetalStorage, Result};
|
|
||||||
use metal::Buffer;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
pub struct QMetalStorage {
|
|
||||||
dtype: GgmlDType,
|
|
||||||
device: MetalDevice,
|
|
||||||
buffer: Arc<Buffer>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QMetalStorage {
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
|
||||||
self.dtype
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn buffer(&self) -> &Buffer {
|
|
||||||
&self.buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
|
||||||
Self {
|
|
||||||
device,
|
|
||||||
buffer,
|
|
||||||
dtype,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
|
||||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
|
||||||
let command_buffer = self.device.command_buffer()?;
|
|
||||||
command_buffer.set_label("to_cpu");
|
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
|
||||||
blit.set_label("blit_to_cpu");
|
|
||||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
|
||||||
blit.end_encoding();
|
|
||||||
self.device.wait_until_completed()?;
|
|
||||||
let mut out = vec![0.0; elem_count];
|
|
||||||
match self.dtype {
|
|
||||||
GgmlDType::F32 => {
|
|
||||||
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
f32::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::F16 => {
|
|
||||||
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
half::f16::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q4_0 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q4_1 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q5_0 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q5_1 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q8_0 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q8_1 => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q2K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ2K> =
|
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q3K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ3K> =
|
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q4K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ4K> =
|
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q5K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ5K> =
|
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q6K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ6K> =
|
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
GgmlDType::Q8K => {
|
|
||||||
let vec: Vec<crate::quantized::BlockQ8K> =
|
|
||||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
|
||||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
|
||||||
// Quantization only happens on CPU for now.
|
|
||||||
let src = src.to_cpu::<f32>()?;
|
|
||||||
let elem_count = src.len();
|
|
||||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
|
||||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
|
||||||
qcpu_storage.quantize(&src)?;
|
|
||||||
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
|
||||||
self.buffer = buffer;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
|
||||||
device: &MetalDevice,
|
|
||||||
data: &[T],
|
|
||||||
) -> Result<QStorage> {
|
|
||||||
let buffer = device.new_buffer_with_data(data)?;
|
|
||||||
let device = device.clone();
|
|
||||||
Ok(QStorage::Metal(QMetalStorage {
|
|
||||||
dtype: T::DTYPE,
|
|
||||||
device,
|
|
||||||
buffer,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|
||||||
let ptr = buffer.contents() as *const T;
|
|
||||||
assert!(!ptr.is_null());
|
|
||||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
|
||||||
slice.to_vec()
|
|
||||||
}
|
|
@ -1,125 +1,23 @@
|
|||||||
#[cfg(feature = "metal")]
|
use crate::{Device, Result, Shape, Tensor};
|
||||||
use crate::{backend::BackendStorage, DType};
|
|
||||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
|
||||||
use k_quants::*;
|
|
||||||
use std::borrow::Cow;
|
|
||||||
|
|
||||||
#[cfg(target_feature = "avx")]
|
#[cfg(target_feature = "avx")]
|
||||||
pub mod avx;
|
pub mod avx;
|
||||||
pub mod ggml_file;
|
pub mod ggml_file;
|
||||||
pub mod gguf_file;
|
pub mod gguf_file;
|
||||||
pub mod k_quants;
|
pub mod k_quants;
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
pub mod metal;
|
|
||||||
#[cfg(target_feature = "neon")]
|
#[cfg(target_feature = "neon")]
|
||||||
pub mod neon;
|
pub mod neon;
|
||||||
#[cfg(target_feature = "simd128")]
|
#[cfg(target_feature = "simd128")]
|
||||||
pub mod simd128;
|
pub mod simd128;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
use half::f16;
|
|
||||||
|
|
||||||
pub use k_quants::GgmlType;
|
pub use k_quants::GgmlType;
|
||||||
|
|
||||||
pub struct QTensor {
|
pub struct QTensor {
|
||||||
storage: QStorage,
|
data: Box<dyn QuantizedType>,
|
||||||
shape: Shape,
|
shape: Shape,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Device {
|
|
||||||
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
|
|
||||||
match self {
|
|
||||||
Device::Cpu => {
|
|
||||||
let storage = dtype.cpu_zeros(elem_count);
|
|
||||||
Ok(QStorage::Cpu(storage))
|
|
||||||
}
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Device::Metal(metal) => {
|
|
||||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
|
||||||
let buffer = metal.allocate_zeros(size)?;
|
|
||||||
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
|
||||||
buffer,
|
|
||||||
metal.clone(),
|
|
||||||
dtype,
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
#[cfg(not(feature = "metal"))]
|
|
||||||
Device::Metal(_metal) => {
|
|
||||||
crate::bail!("Metal feature not activated");
|
|
||||||
}
|
|
||||||
Device::Cuda(_cuda) => {
|
|
||||||
crate::bail!("Cuda ggml quantization not supported");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum QStorage {
|
|
||||||
Cpu(Box<dyn QuantizedType>),
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
Metal(metal::QMetalStorage),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QStorage {
|
|
||||||
fn block_size(&self) -> usize {
|
|
||||||
match self {
|
|
||||||
QStorage::Cpu(storage) => storage.block_size(),
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dtype(&self) -> GgmlDType {
|
|
||||||
match self {
|
|
||||||
QStorage::Cpu(storage) => storage.dtype(),
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.dtype(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size_in_bytes(&self) -> usize {
|
|
||||||
match self {
|
|
||||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn quantize(&mut self, src: &Storage) -> Result<()> {
|
|
||||||
match (self, src) {
|
|
||||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
|
||||||
storage.from_float(src.as_slice::<f32>()?)?;
|
|
||||||
}
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
|
||||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
|
||||||
match self {
|
|
||||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn data(&self) -> Result<Cow<[u8]>> {
|
|
||||||
match self {
|
|
||||||
QStorage::Cpu(storage) => {
|
|
||||||
let data_ptr = storage.as_ptr();
|
|
||||||
let size_in_bytes = storage.storage_size_in_bytes();
|
|
||||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
|
||||||
Ok(Cow::from(data))
|
|
||||||
}
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
QStorage::Metal(_storage) => {
|
|
||||||
crate::bail!("not implemented");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum GgmlDType {
|
pub enum GgmlDType {
|
||||||
F32,
|
F32,
|
||||||
@ -179,25 +77,6 @@ impl GgmlDType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The block dtype
|
|
||||||
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
|
|
||||||
match self {
|
|
||||||
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
|
|
||||||
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
|
|
||||||
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
|
|
||||||
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
|
|
||||||
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
|
|
||||||
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
|
|
||||||
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
|
|
||||||
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
|
|
||||||
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
|
|
||||||
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
|
|
||||||
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
|
|
||||||
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
|
|
||||||
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
|
|
||||||
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// The type size for blocks in bytes.
|
/// The type size for blocks in bytes.
|
||||||
pub fn type_size(&self) -> usize {
|
pub fn type_size(&self) -> usize {
|
||||||
use k_quants::*;
|
use k_quants::*;
|
||||||
@ -221,7 +100,7 @@ impl GgmlDType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The block size, i.e. the number of elements stored in each block.
|
/// The block size, i.e. the number of elements stored in each block.
|
||||||
pub fn block_size(&self) -> usize {
|
pub fn blck_size(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
Self::F32 => 1,
|
Self::F32 => 1,
|
||||||
Self::F16 => 1,
|
Self::F16 => 1,
|
||||||
@ -240,13 +119,9 @@ impl GgmlDType {
|
|||||||
pub trait QuantizedType: Send + Sync {
|
pub trait QuantizedType: Send + Sync {
|
||||||
fn dtype(&self) -> GgmlDType;
|
fn dtype(&self) -> GgmlDType;
|
||||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
|
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||||
fn storage_size_in_bytes(&self) -> usize;
|
fn storage_size_in_bytes(&self) -> usize;
|
||||||
fn as_ptr(&self) -> *const u8;
|
fn as_ptr(&self) -> *const u8;
|
||||||
fn block_size(&self) -> usize;
|
|
||||||
#[allow(clippy::wrong_self_convention)]
|
|
||||||
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
|
|
||||||
fn size(&self) -> usize;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||||
@ -254,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
|||||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size(&self) -> usize {
|
|
||||||
self.len() * core::mem::size_of::<T>()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
|
|
||||||
T::from_float(xs, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dtype(&self) -> GgmlDType {
|
fn dtype(&self) -> GgmlDType {
|
||||||
T::DTYPE
|
T::DTYPE
|
||||||
}
|
}
|
||||||
|
|
||||||
fn block_size(&self) -> usize {
|
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||||
T::BLCK_SIZE
|
T::to_float(self.as_slice(), ys)
|
||||||
}
|
|
||||||
|
|
||||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
|
|
||||||
let mut ys = vec![0.0f32; elem_count];
|
|
||||||
T::to_float(self.as_slice(), &mut ys)?;
|
|
||||||
Ok(CpuStorage::F32(ys))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn storage_size_in_bytes(&self) -> usize {
|
fn storage_size_in_bytes(&self) -> usize {
|
||||||
@ -291,49 +152,56 @@ impl std::fmt::Debug for QTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
|
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
if dims.is_empty() {
|
if dims.is_empty() {
|
||||||
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||||
}
|
}
|
||||||
if dims[dims.len() - 1] % block_size != 0 {
|
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||||
block_size
|
T::BLCK_SIZE
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QTensor {
|
impl QTensor {
|
||||||
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
|
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||||
|
data: Vec<T>,
|
||||||
|
shape: S,
|
||||||
|
) -> Result<Self> {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
check_shape(&shape, storage.block_size())?;
|
check_shape::<T>(&shape)?;
|
||||||
Ok(Self { storage, shape })
|
Ok(Self {
|
||||||
|
data: Box::new(data),
|
||||||
|
shape,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
|
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||||
let shape = src.shape();
|
let shape = src.shape();
|
||||||
let block_size = dtype.block_size();
|
check_shape::<T>(shape)?;
|
||||||
check_shape(shape, block_size)?;
|
let src = src
|
||||||
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
|
.to_dtype(crate::DType::F32)?
|
||||||
let elem_count = shape.elem_count();
|
.flatten_all()?
|
||||||
if elem_count % block_size != 0 {
|
.to_vec1::<f32>()?;
|
||||||
|
if src.len() % T::BLCK_SIZE != 0 {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"tensor size ({shape:?}) is not divisible by block size {}",
|
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||||
block_size
|
T::BLCK_SIZE
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
let mut storage = src.device().qzeros(elem_count, dtype)?;
|
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||||
storage.quantize(&src.storage())?;
|
T::from_float(&src, &mut data)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
storage,
|
data: Box::new(data),
|
||||||
shape: shape.clone(),
|
shape: shape.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
self.storage.dtype()
|
self.data.dtype()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
@ -345,19 +213,21 @@ impl QTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||||
let none = crate::op::BackpropOp::none();
|
self.data.to_float(&mut f32_data)?;
|
||||||
let is_variable = false;
|
Tensor::from_vec(f32_data, &self.shape, device)
|
||||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
}
|
||||||
.to_device(device)
|
|
||||||
|
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||||
|
self.data.matmul_t(mkn, lhs, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
pub fn storage_size_in_bytes(&self) -> usize {
|
||||||
self.storage.size_in_bytes()
|
self.data.storage_size_in_bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
|
pub fn as_ptr(&self) -> *const u8 {
|
||||||
self.storage.data()
|
self.data.as_ptr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -424,97 +294,21 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
}
|
}
|
||||||
dst_shape.push(n);
|
dst_shape.push(n);
|
||||||
let dst_shape = Shape::from(dst_shape);
|
let dst_shape = Shape::from(dst_shape);
|
||||||
#[allow(clippy::infallible_destructuring_match)]
|
let storage = storage.as_slice::<f32>()?;
|
||||||
let self_storage = match &self.storage {
|
let storage =
|
||||||
QStorage::Cpu(storage) => storage,
|
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
_ => crate::bail!("Invalid storage"),
|
|
||||||
};
|
|
||||||
let slice = storage.as_slice::<f32>()?;
|
|
||||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
|
||||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||||
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
|
self.matmul_t(
|
||||||
|
(dst_shape.elem_count() / n, k, n),
|
||||||
|
storage,
|
||||||
|
&mut dst_storage,
|
||||||
|
)?;
|
||||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
storage: &crate::MetalStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(crate::MetalStorage, Shape)> {
|
|
||||||
use crate::MetalError;
|
|
||||||
|
|
||||||
if !layout.is_contiguous() {
|
|
||||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
|
||||||
}
|
|
||||||
let src_shape = layout.shape();
|
|
||||||
// self is transposed so n is first then k.
|
|
||||||
if src_shape.rank() < 2 {
|
|
||||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
|
||||||
}
|
|
||||||
let (n, k) = self.shape.dims2()?;
|
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
|
||||||
|
|
||||||
let (b, m) = match dst_shape.len() {
|
|
||||||
3 => (dst_shape[0], dst_shape[1]),
|
|
||||||
2 => (1, dst_shape[0]),
|
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
|
||||||
};
|
|
||||||
let last_k = dst_shape.pop().unwrap();
|
|
||||||
if last_k != k {
|
|
||||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
|
||||||
}
|
|
||||||
dst_shape.push(n);
|
|
||||||
let dst_shape = Shape::from(dst_shape);
|
|
||||||
let device = storage.device().clone();
|
|
||||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
|
||||||
let (buffer, dtype) = match &self.storage {
|
|
||||||
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
|
||||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
|
||||||
};
|
|
||||||
let command_buffer = device.command_buffer()?;
|
|
||||||
candle_metal_kernels::call_quantized_matmul_t(
|
|
||||||
device.device(),
|
|
||||||
&command_buffer,
|
|
||||||
device.kernels(),
|
|
||||||
dtype.into(),
|
|
||||||
(b, m, n, k),
|
|
||||||
storage.buffer(),
|
|
||||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
|
||||||
buffer,
|
|
||||||
&dst,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
|
||||||
Ok((dst_storage, dst_shape))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "metal")]
|
impl QMatMul {
|
||||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
fn from(value: GgmlDType) -> Self {
|
|
||||||
match value {
|
|
||||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
|
||||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
|
||||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
|
||||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
|
||||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
|
||||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
|
||||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
|
||||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
|
||||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
|
||||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
|
||||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
|
||||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
|
||||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
|
||||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl crate::Module for QMatMul {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
match self {
|
match self {
|
||||||
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||||
Self::Tensor(w) => {
|
Self::Tensor(w) => {
|
||||||
|
@ -12,14 +12,6 @@ use core::arch::arm::*;
|
|||||||
#[cfg(target_arch = "aarch64")]
|
#[cfg(target_arch = "aarch64")]
|
||||||
use core::arch::aarch64::*;
|
use core::arch::aarch64::*;
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
|
|
||||||
// TODO: dotprod
|
|
||||||
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
|
|
||||||
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
|
||||||
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||||
let qk = QK8_0;
|
let qk = QK8_0;
|
||||||
@ -51,8 +43,15 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
|||||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||||
|
|
||||||
let pl0 = vdotq_s32(v0_0ls, v1_0l);
|
// TODO: Support dotprod when it's available outside of nightly.
|
||||||
let ph0 = vdotq_s32(v0_0hs, v1_0h);
|
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||||
|
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||||
|
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||||
|
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||||
|
|
||||||
|
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||||
|
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||||
|
|
||||||
sumv0 = vmlaq_n_f32(
|
sumv0 = vmlaq_n_f32(
|
||||||
sumv0,
|
sumv0,
|
||||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||||
@ -83,8 +82,14 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
|||||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||||
|
|
||||||
let p0 = vdotq_s32(x0_0, y0_0);
|
// TODO dotprod once this is the intrinsics are.
|
||||||
let p1 = vdotq_s32(x0_1, y0_1);
|
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||||
|
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||||
|
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||||
|
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||||
|
|
||||||
|
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||||
|
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||||
|
|
||||||
sumv0 = vmlaq_n_f32(
|
sumv0 = vmlaq_n_f32(
|
||||||
sumv0,
|
sumv0,
|
||||||
@ -113,7 +118,10 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
|
|||||||
for i in (0..QK_K).step_by(16) {
|
for i in (0..QK_K).step_by(16) {
|
||||||
let xs = vld1q_s8(xs.add(i));
|
let xs = vld1q_s8(xs.add(i));
|
||||||
let ys = vld1q_s8(ys.add(i));
|
let ys = vld1q_s8(ys.add(i));
|
||||||
let xy = vdotq_s32(xs, ys);
|
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||||
|
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||||
|
|
||||||
|
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||||
sum_i = vaddq_s32(sum_i, xy)
|
sum_i = vaddq_s32(sum_i, xy)
|
||||||
}
|
}
|
||||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||||
@ -183,16 +191,30 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
|||||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||||
|
|
||||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
// TODO: dotprod
|
||||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
|
||||||
|
let p0 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||||
|
);
|
||||||
|
let p1 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||||
|
);
|
||||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||||
scale = scale.add(2);
|
scale = scale.add(2);
|
||||||
|
|
||||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
let p2 = vaddq_s16(
|
||||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||||
|
);
|
||||||
|
let p3 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||||
|
);
|
||||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||||
scale = scale.add(2);
|
scale = scale.add(2);
|
||||||
|
|
||||||
let q8bytes = vld1q_s8_x4(q8);
|
let q8bytes = vld1q_s8_x4(q8);
|
||||||
@ -212,16 +234,29 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
|||||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||||
|
|
||||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
// TODO: dotprod case.
|
||||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
let p0 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||||
|
);
|
||||||
|
let p1 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||||
|
);
|
||||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||||
scale = scale.add(2);
|
scale = scale.add(2);
|
||||||
|
|
||||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
let p2 = vaddq_s16(
|
||||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||||
|
);
|
||||||
|
let p3 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||||
|
);
|
||||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||||
scale = scale.add(2);
|
scale = scale.add(2);
|
||||||
}
|
}
|
||||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||||
@ -298,14 +333,28 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
|||||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||||
|
|
||||||
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
|
// TODO: dotprod
|
||||||
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
|
|
||||||
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
|
let p0 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||||
|
);
|
||||||
|
let p1 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||||
|
);
|
||||||
|
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||||
scales = scales.add(1);
|
scales = scales.add(1);
|
||||||
|
|
||||||
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
|
let p2 = vaddq_s16(
|
||||||
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
|
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||||
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
|
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||||
|
);
|
||||||
|
let p3 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||||
|
);
|
||||||
|
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||||
scales = scales.add(1);
|
scales = scales.add(1);
|
||||||
}
|
}
|
||||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||||
@ -368,15 +417,22 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
|||||||
for j in 0..QK_K / 64 {
|
for j in 0..QK_K / 64 {
|
||||||
let q4bits = vld1q_u8_x2(q4);
|
let q4bits = vld1q_u8_x2(q4);
|
||||||
q4 = q4.add(32);
|
q4 = q4.add(32);
|
||||||
|
// TODO: dotprod
|
||||||
let q8bytes = vld1q_s8_x2(q8);
|
let q8bytes = vld1q_s8_x2(q8);
|
||||||
q8 = q8.add(32);
|
q8 = q8.add(32);
|
||||||
let q4bytes = int8x16x2_t(
|
let q4bytes = int8x16x2_t(
|
||||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||||
);
|
);
|
||||||
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
|
let p0 = vaddq_s16(
|
||||||
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
|
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||||
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
|
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||||
|
);
|
||||||
|
let p1 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||||
|
);
|
||||||
|
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||||
|
|
||||||
let q8bytes = vld1q_s8_x2(q8);
|
let q8bytes = vld1q_s8_x2(q8);
|
||||||
q8 = q8.add(32);
|
q8 = q8.add(32);
|
||||||
@ -384,9 +440,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
|||||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||||
);
|
);
|
||||||
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
|
let p2 = vaddq_s16(
|
||||||
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
|
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||||
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
|
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||||
|
);
|
||||||
|
let p3 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||||
|
);
|
||||||
|
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||||
}
|
}
|
||||||
sumf += d * (sumi1 + sumi2) as f32;
|
sumf += d * (sumi1 + sumi2) as f32;
|
||||||
}
|
}
|
||||||
@ -464,14 +526,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
|||||||
vreinterpretq_s8_u8(q3h_3),
|
vreinterpretq_s8_u8(q3h_3),
|
||||||
);
|
);
|
||||||
|
|
||||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
|
// TODO: dotprod
|
||||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
|
let p0 = vaddq_s16(
|
||||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
|
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
|
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||||
isum += vaddvq_s32(p0) * *scale as i32
|
);
|
||||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
let p1 = vaddq_s16(
|
||||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||||
|
);
|
||||||
|
let p2 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||||
|
);
|
||||||
|
let p3 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||||
|
);
|
||||||
|
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||||
|
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||||
|
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||||
|
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||||
scale = scale.add(4);
|
scale = scale.add(4);
|
||||||
|
|
||||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||||
@ -496,14 +571,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
|||||||
vreinterpretq_s8_u8(q3h_3),
|
vreinterpretq_s8_u8(q3h_3),
|
||||||
);
|
);
|
||||||
|
|
||||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
|
// TODO: dotprod
|
||||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
|
let p0 = vaddq_s16(
|
||||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
|
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
|
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||||
isum += vaddvq_s32(p0) * *scale as i32
|
);
|
||||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
let p1 = vaddq_s16(
|
||||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||||
|
);
|
||||||
|
let p2 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||||
|
);
|
||||||
|
let p3 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||||
|
);
|
||||||
|
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||||
|
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||||
|
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||||
|
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||||
scale = scale.add(4);
|
scale = scale.add(4);
|
||||||
|
|
||||||
if j == 0 {
|
if j == 0 {
|
||||||
@ -561,6 +649,7 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
|
|||||||
let mut is = 0usize;
|
let mut is = 0usize;
|
||||||
|
|
||||||
// TODO: dotprod
|
// TODO: dotprod
|
||||||
|
|
||||||
for _j in 0..QK_K / 128 {
|
for _j in 0..QK_K / 128 {
|
||||||
let q2bits = vld1q_u8_x2(q2);
|
let q2bits = vld1q_u8_x2(q2);
|
||||||
q2 = q2.add(32);
|
q2 = q2.add(32);
|
||||||
@ -607,7 +696,14 @@ unsafe fn multiply_accum_with_scale(
|
|||||||
q2bytes: int8x16x2_t,
|
q2bytes: int8x16x2_t,
|
||||||
q8bytes: int8x16x2_t,
|
q8bytes: int8x16x2_t,
|
||||||
) -> i32 {
|
) -> i32 {
|
||||||
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
|
let p1 = vaddq_s16(
|
||||||
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
|
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||||
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
|
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||||
|
);
|
||||||
|
let p2 = vaddq_s16(
|
||||||
|
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||||
|
);
|
||||||
|
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||||
|
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||||
}
|
}
|
||||||
|
@ -478,6 +478,23 @@ extract_dims!(
|
|||||||
(usize, usize, usize, usize, usize)
|
(usize, usize, usize, usize, usize)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stride() {
|
||||||
|
let shape = Shape::from(());
|
||||||
|
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||||
|
let shape = Shape::from(42);
|
||||||
|
assert_eq!(shape.stride_contiguous(), [1]);
|
||||||
|
let shape = Shape::from((42, 1337));
|
||||||
|
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||||
|
let shape = Shape::from((299, 792, 458));
|
||||||
|
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait ShapeWithOneHole {
|
pub trait ShapeWithOneHole {
|
||||||
fn into_shape(self, el_count: usize) -> Result<Shape>;
|
fn into_shape(self, el_count: usize) -> Result<Shape>;
|
||||||
}
|
}
|
||||||
@ -610,20 +627,3 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
|
|||||||
Ok((d1, d2, d3, d4, d).into())
|
Ok((d1, d2, d3, d4, d).into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn stride() {
|
|
||||||
let shape = Shape::from(());
|
|
||||||
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
|
||||||
let shape = Shape::from(42);
|
|
||||||
assert_eq!(shape.stride_contiguous(), [1]);
|
|
||||||
let shape = Shape::from((42, 1337));
|
|
||||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
|
||||||
let shape = Shape::from((299, 792, 458));
|
|
||||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
//! Tensors are N-dimenional matrixes of elements using a single data type.
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::op::{
|
use crate::op::{
|
||||||
@ -361,16 +361,6 @@ impl Tensor {
|
|||||||
Self::new_impl(array, shape, device, false)
|
Self::new_impl(array, shape, device, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a new tensor with all the elements having the same specified value. Note that
|
|
||||||
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
|
||||||
pub fn full<D: crate::WithDType, S: Into<Shape>>(
|
|
||||||
value: D,
|
|
||||||
shape: S,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<Self> {
|
|
||||||
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new 1D tensor from an iterator.
|
/// Creates a new 1D tensor from an iterator.
|
||||||
pub fn from_iter<D: crate::WithDType>(
|
pub fn from_iter<D: crate::WithDType>(
|
||||||
iter: impl IntoIterator<Item = D>,
|
iter: impl IntoIterator<Item = D>,
|
||||||
@ -396,7 +386,7 @@ impl Tensor {
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
if D::is_zero(&step) {
|
if D::is_zero(&step) {
|
||||||
bail!("step cannot be zero")
|
crate::bail!("step cannot be zero")
|
||||||
}
|
}
|
||||||
let mut data = vec![];
|
let mut data = vec![];
|
||||||
let mut current = start;
|
let mut current = start;
|
||||||
@ -679,7 +669,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Split a tensor into the specified number of chunks, this may return less chunks than
|
/// Split a tensor into the specified number of chunks, this may return less chunks than
|
||||||
/// specified.
|
/// specificed.
|
||||||
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
|
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
|
||||||
let dim = dim.to_index(self.shape(), "chunk")?;
|
let dim = dim.to_index(self.shape(), "chunk")?;
|
||||||
let size = self.dim(dim)?;
|
let size = self.dim(dim)?;
|
||||||
@ -1004,11 +994,7 @@ impl Tensor {
|
|||||||
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
|
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
|
||||||
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||||
let (n, c, _h, _w) = self.dims4()?;
|
let (n, c, _h, _w) = self.dims4()?;
|
||||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
|
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||||
arg,
|
|
||||||
target_h,
|
|
||||||
target_w,
|
|
||||||
});
|
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.upsample_nearest2d(self.layout(), target_h, target_w)?;
|
.upsample_nearest2d(self.layout(), target_h, target_w)?;
|
||||||
@ -1041,9 +1027,6 @@ impl Tensor {
|
|||||||
let kernel_size = kernel_size.to_usize2();
|
let kernel_size = kernel_size.to_usize2();
|
||||||
let stride = stride.to_usize2();
|
let stride = stride.to_usize2();
|
||||||
let (n, c, h, w) = self.dims4()?;
|
let (n, c, h, w) = self.dims4()?;
|
||||||
if h < kernel_size.0 || w < kernel_size.1 {
|
|
||||||
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
|
||||||
}
|
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||||
@ -1079,9 +1062,6 @@ impl Tensor {
|
|||||||
let kernel_size = kernel_size.to_usize2();
|
let kernel_size = kernel_size.to_usize2();
|
||||||
let stride = stride.to_usize2();
|
let stride = stride.to_usize2();
|
||||||
let (n, c, h, w) = self.dims4()?;
|
let (n, c, h, w) = self.dims4()?;
|
||||||
if h < kernel_size.0 || w < kernel_size.1 {
|
|
||||||
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
|
||||||
}
|
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||||
@ -1804,7 +1784,7 @@ impl Tensor {
|
|||||||
let is_permutation =
|
let is_permutation =
|
||||||
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
||||||
if !is_permutation {
|
if !is_permutation {
|
||||||
bail!(
|
crate::bail!(
|
||||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||||
self.dims(),
|
self.dims(),
|
||||||
dims
|
dims
|
||||||
@ -1883,7 +1863,10 @@ impl Tensor {
|
|||||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||||
}
|
}
|
||||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||||
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
(Storage::Metal(storage), Device::Cpu) => {
|
||||||
|
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
|
||||||
|
Storage::Cpu(storage.to_cpu_storage()?)
|
||||||
|
}
|
||||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||||
// are the same.
|
// are the same.
|
||||||
@ -2299,7 +2282,7 @@ impl Tensor {
|
|||||||
if left == 0 && right == 0 {
|
if left == 0 && right == 0 {
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
} else if self.elem_count() == 0 {
|
} else if self.elem_count() == 0 {
|
||||||
bail!("cannot use pad_with_same on an empty tensor")
|
crate::bail!("cannot use pad_with_same on an empty tensor")
|
||||||
} else if left == 0 {
|
} else if left == 0 {
|
||||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||||
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
||||||
@ -2463,136 +2446,17 @@ impl Tensor {
|
|||||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||||
let rank = self.rank() as i64;
|
let rank = self.rank() as i64;
|
||||||
if rank <= axis {
|
if rank <= axis {
|
||||||
bail!("axis {axis} is too large, tensor rank {rank}")
|
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
||||||
} else if 0 <= axis {
|
} else if 0 <= axis {
|
||||||
Ok(axis as usize)
|
Ok(axis as usize)
|
||||||
} else {
|
} else {
|
||||||
let naxis = rank + axis;
|
let naxis = rank + axis;
|
||||||
if naxis < 0 {
|
if naxis < 0 {
|
||||||
bail!("axis {axis} is too small, tensor rank {rank}")
|
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
||||||
}
|
}
|
||||||
Ok(naxis as usize)
|
Ok(naxis as usize)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a lower triangular matrix of ones of size n by n.
|
|
||||||
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
|
||||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
|
||||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
|
||||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
|
||||||
t1.le(&t2)?.to_dtype(dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an upper triangular matrix of ones of size n by n.
|
|
||||||
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
|
||||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
|
||||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
|
||||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
|
||||||
t1.ge(&t2)?.to_dtype(dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a matrix with a diagonal of ones of size n by n.
|
|
||||||
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
|
|
||||||
let t = Tensor::arange(0u32, n as u32, device)?;
|
|
||||||
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
|
|
||||||
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
|
|
||||||
t1.eq(&t2)?.to_dtype(dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the cumulative sum of elements of the input tensor summed over the specified
|
|
||||||
/// dimension.
|
|
||||||
///
|
|
||||||
/// This operation is most efficient when dim is the last dimension of the tensor.
|
|
||||||
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
|
|
||||||
let dim = dim.to_index(self.shape(), "cumsum")?;
|
|
||||||
let rank = self.rank();
|
|
||||||
if rank == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let n_axis = self.dim(dim)?;
|
|
||||||
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
|
|
||||||
if rank == 1 {
|
|
||||||
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
|
|
||||||
} else {
|
|
||||||
let last = rank - 1;
|
|
||||||
let t = self.transpose(dim, last)?;
|
|
||||||
let t = t.broadcast_matmul(&triu)?;
|
|
||||||
t.transpose(dim, last)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
|
|
||||||
/// content of `src`.
|
|
||||||
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
|
|
||||||
&self,
|
|
||||||
ranges: &[D],
|
|
||||||
src: &Tensor,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let src_dims = src.dims();
|
|
||||||
let self_dims = self.dims();
|
|
||||||
if self_dims.len() != src_dims.len() {
|
|
||||||
bail!(
|
|
||||||
"slice-assign requires input with the same rank {} <> {}",
|
|
||||||
self_dims.len(),
|
|
||||||
src_dims.len()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if self_dims.len() != ranges.len() {
|
|
||||||
bail!(
|
|
||||||
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
|
||||||
self_dims.len(),
|
|
||||||
ranges.len()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
let mut src = src.clone();
|
|
||||||
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
|
|
||||||
for (i, range) in ranges.iter().enumerate() {
|
|
||||||
let start_included = match range.start_bound() {
|
|
||||||
std::ops::Bound::Unbounded => 0,
|
|
||||||
std::ops::Bound::Included(v) => *v,
|
|
||||||
std::ops::Bound::Excluded(v) => *v + 1,
|
|
||||||
};
|
|
||||||
let end_excluded = match range.end_bound() {
|
|
||||||
std::ops::Bound::Unbounded => self_dims[i],
|
|
||||||
std::ops::Bound::Included(v) => *v + 1,
|
|
||||||
std::ops::Bound::Excluded(v) => *v,
|
|
||||||
};
|
|
||||||
if end_excluded <= start_included {
|
|
||||||
bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
|
|
||||||
}
|
|
||||||
if self_dims[i] < end_excluded {
|
|
||||||
bail!(
|
|
||||||
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
|
||||||
self_dims[i]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if end_excluded - start_included != src_dims[i] {
|
|
||||||
bail!(
|
|
||||||
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
|
|
||||||
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
|
|
||||||
}
|
|
||||||
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns log(sum(exp(tensor), dim)).
|
|
||||||
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
|
||||||
let exp = self.exp()?;
|
|
||||||
let sum = exp.sum(sum_dims)?;
|
|
||||||
sum.log()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Pointwise pow operation.
|
|
||||||
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
|
|
||||||
rhs.mul(&self.log()?)?.exp()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Broadcasting version of `pow`.
|
|
||||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
|
||||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -270,166 +270,6 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||||
);
|
);
|
||||||
|
|
||||||
// manually checked: see comments
|
|
||||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
|
||||||
&[
|
|
||||||
1_f32, 02., 03., 04., 05., 06.,
|
|
||||||
07., 08., 09., 10., 11., 12.,
|
|
||||||
13., 14., 15., 16., 17., 18.,
|
|
||||||
19., 20., 21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28., 29., 30.,
|
|
||||||
31., 32., 33., 34., 35., 36.,
|
|
||||||
],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
// gradient should be
|
|
||||||
// row 1
|
|
||||||
// 1+2+7+8 = 18
|
|
||||||
// 3+4+9+10 = 26
|
|
||||||
// 5+6+11+12 = 34
|
|
||||||
// row 2
|
|
||||||
// 13+14+19+20 = 66
|
|
||||||
// 15+16+21+22 = 74
|
|
||||||
// 17+18+23+24 = 82
|
|
||||||
// row 3
|
|
||||||
// 25+26+31+32 = 114
|
|
||||||
// 27+28+33+34 = 122
|
|
||||||
// 29+30+35+36 = 130
|
|
||||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
|
||||||
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
|
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
|
||||||
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
|
|
||||||
);
|
|
||||||
|
|
||||||
// manually checked: see comments
|
|
||||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
|
||||||
&[
|
|
||||||
1_f32, 02., 03., 04., 05., 06.,
|
|
||||||
07., 08., 09., 10., 11., 12.,
|
|
||||||
13., 14., 15., 16., 17., 18.,
|
|
||||||
19., 20., 21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28., 29., 30.,
|
|
||||||
31., 32., 33., 34., 35., 36.,
|
|
||||||
],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
// gradient should be
|
|
||||||
// row 1
|
|
||||||
// 1+2+3+7+8+9+13+14+15 = 72
|
|
||||||
// 4+5+6+10+11+12+16+17+18 = 99
|
|
||||||
// row 2
|
|
||||||
// 19+20+21+25+26+27+31+32+33 = 234
|
|
||||||
// 22+23+24+28+29+30+34+35+36 = 243
|
|
||||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
|
||||||
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
|
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
|
||||||
[[72_f32, 99.], [234., 261.]]
|
|
||||||
);
|
|
||||||
|
|
||||||
// manually checked: see comments
|
|
||||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
|
|
||||||
|
|
||||||
let y = x.interpolate2d(4, 4)?.reshape(32)?;
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
|
||||||
&[
|
|
||||||
1_f32, 02., 03., 04.,
|
|
||||||
05., 06., 07., 08.,
|
|
||||||
09., 10., 11., 12.,
|
|
||||||
13., 14., 15., 16.,
|
|
||||||
17., 18., 19., 20.,
|
|
||||||
21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28.,
|
|
||||||
29., 30., 31., 32.
|
|
||||||
],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
// gradient should be
|
|
||||||
// m1r1
|
|
||||||
// 1+2+5+6=14
|
|
||||||
// 3+4+7+8=22
|
|
||||||
// m1r2
|
|
||||||
// 9+10+13+14=46
|
|
||||||
// 11+12+15+16=54
|
|
||||||
// m2r1
|
|
||||||
// 17+18+21+22=78
|
|
||||||
// 19+20+23+24=86
|
|
||||||
// m2r2
|
|
||||||
// 25+26+29+30=110
|
|
||||||
// 27+28+31+32=118
|
|
||||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
|
||||||
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
|
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
|
|
||||||
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
|
|
||||||
);
|
|
||||||
|
|
||||||
// manually checked: see comments
|
|
||||||
let x = Var::new(
|
|
||||||
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let y = x.interpolate2d(4, 4)?.reshape(32)?;
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let z = Tensor::new(
|
|
||||||
&[
|
|
||||||
1_f32, 02., 03., 04.,
|
|
||||||
05., 06., 07., 08.,
|
|
||||||
09., 10., 11., 12.,
|
|
||||||
13., 14., 15., 16.,
|
|
||||||
17., 18., 19., 20.,
|
|
||||||
21., 22., 23., 24.,
|
|
||||||
25., 26., 27., 28.,
|
|
||||||
29., 30., 31., 32.
|
|
||||||
],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
// gradient should be
|
|
||||||
// m1r1
|
|
||||||
// 1+2+5+6=14
|
|
||||||
// 3+4+7+8=22
|
|
||||||
// m1r2
|
|
||||||
// 9+10+13+14=46
|
|
||||||
// 11+12+15+16=54
|
|
||||||
// m2r1
|
|
||||||
// 17+18+21+22=78
|
|
||||||
// 19+20+23+24=86
|
|
||||||
// m2r2
|
|
||||||
// 25+26+29+30=110
|
|
||||||
// 27+28+31+32=118
|
|
||||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
|
||||||
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
|
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
|
|
||||||
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
|
|
||||||
);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
|
|||||||
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn slice_assign() -> Result<()> {
|
|
||||||
let dev = Device::Cpu;
|
|
||||||
|
|
||||||
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
|
|
||||||
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
|
|
||||||
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
|
|
||||||
assert_eq!(
|
|
||||||
out.to_vec2::<u32>()?,
|
|
||||||
&[
|
|
||||||
[0, 1, 2, 3, 4],
|
|
||||||
[5, 6, 7, 0, 1],
|
|
||||||
[10, 11, 12, 2, 3],
|
|
||||||
[15, 16, 17, 4, 5]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
|
|
||||||
assert_eq!(
|
|
||||||
out.to_vec2::<u32>()?,
|
|
||||||
&[
|
|
||||||
[0, 1, 2, 3, 4],
|
|
||||||
[2, 3, 7, 8, 9],
|
|
||||||
[4, 5, 12, 13, 14],
|
|
||||||
[15, 16, 17, 18, 19]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
use candle_core::{
|
use candle_core::{
|
||||||
bail,
|
|
||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_device,
|
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Module, Result, Tensor,
|
Device, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
@ -15,48 +13,16 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
|
|||||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
|
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
|
||||||
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
|
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
|
||||||
|
|
||||||
fn test_matmul(
|
#[test]
|
||||||
device: &Device,
|
fn quantized_matmul() -> Result<()> {
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
let cpu = &Device::Cpu;
|
||||||
dtype: GgmlDType,
|
|
||||||
) -> Result<()> {
|
|
||||||
let lhs = (0..(m * k))
|
|
||||||
.map(|v| v as f32 / (m * k) as f32)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let rhs = (0..(k * n))
|
|
||||||
.map(|v| v as f32 / (n * k) as f32)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
|
||||||
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
|
|
||||||
let mm = lhs.matmul(&rhs)?;
|
|
||||||
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
|
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
|
||||||
let res = matmul.forward(&lhs)?;
|
|
||||||
|
|
||||||
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
|
|
||||||
.sum_all()?
|
|
||||||
.to_scalar()?;
|
|
||||||
let error = error / (b * m * n) as f32;
|
|
||||||
assert!(
|
|
||||||
error <= 0.02,
|
|
||||||
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn quantized_matmul(device: &Device) -> Result<()> {
|
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let (m, k, n) = (3, 64, 4);
|
let (m, k, n) = (3, 64, 4);
|
||||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||||
let mut dst = vec![42.; 3 * 4];
|
let mut dst = vec![42.; 3 * 4];
|
||||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -66,7 +32,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
341876.0, 994283.0, 1655709.0, 2301518.0
|
341876.0, 994283.0, 1655709.0, 2301518.0
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
|
||||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
mm.to_vec2::<f32>()?,
|
mm.to_vec2::<f32>()?,
|
||||||
@ -77,49 +42,35 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
match device {
|
assert_eq!(
|
||||||
Device::Metal(_) => assert_eq!(
|
|
||||||
to_vec2_round(&res, 0)?,
|
|
||||||
&[
|
|
||||||
[84946.0, 214126.0, 344757.0, 473798.0],
|
|
||||||
[213458.0, 604350.0, 1000469.0, 1387990.0],
|
|
||||||
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_ => assert_eq!(
|
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
&[
|
&[
|
||||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||||
]
|
]
|
||||||
),
|
);
|
||||||
}
|
|
||||||
|
|
||||||
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantized_matmul_neg() -> Result<()> {
|
||||||
if device.is_cuda() {
|
let cpu = &Device::Cpu;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let (m, k, n) = (3, 64, 4);
|
let (m, k, n) = (3, 64, 4);
|
||||||
let lhs = (0..(m * k))
|
let lhs = (0..(m * k))
|
||||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||||
let mut dst = vec![42.; 3 * 4];
|
let mut dst = vec![42.; 3 * 4];
|
||||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
let rhs = (0..k * n)
|
let rhs = (0..k * n)
|
||||||
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -139,56 +90,32 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&tensor_lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
match device {
|
assert_eq!(
|
||||||
Device::Metal(_) => assert_eq!(
|
|
||||||
to_vec2_round(&res, 0)?,
|
|
||||||
&[
|
|
||||||
[243666.0, -19714.0, -285433.0, -550453.0],
|
|
||||||
[23782.0, 21654.0, 19400.0, 18369.0],
|
|
||||||
[-196102.0, 63022.0, 324233.0, 587191.0]
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_ => assert_eq!(
|
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
&[
|
&[
|
||||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||||
]
|
]
|
||||||
),
|
);
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(
|
#[test]
|
||||||
quantized_matmul,
|
fn quantize_q4_0() -> Result<()> {
|
||||||
quantized_matmul_cpu,
|
use k_quants::BlockQ4_0;
|
||||||
quantized_matmul_cuda,
|
|
||||||
quantized_matmul_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantized_matmul_neg,
|
|
||||||
quantized_matmul_neg_cpu,
|
|
||||||
quantized_matmul_neg_cuda,
|
|
||||||
quantized_matmul_neg_metal
|
|
||||||
);
|
|
||||||
|
|
||||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
|
||||||
// TODO Enable this later when we enable cuda.
|
|
||||||
if device.is_cuda() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
let mut dst = vec![0f32; 32 * 4];
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let mut quant = vec![BlockQ4_0::zeros(); 4];
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
BlockQ4_0::from_float(&src, &mut quant)?;
|
||||||
let dst = quant.dequantize(device)?;
|
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.to_vec1::<f32>()?,
|
dst,
|
||||||
&[
|
&[
|
||||||
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
|
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
|
||||||
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
|
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
|
||||||
@ -204,21 +131,21 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
|||||||
127.0, 127.0
|
127.0, 127.0
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q4_1() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ4_1;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let mut dst = vec![0f32; 32 * 4];
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
let mut quant = vec![BlockQ4_1::zeros(); 4];
|
||||||
let dst = quant.dequantize(device)?;
|
BlockQ4_1::from_float(&src, &mut quant)?;
|
||||||
|
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst),
|
||||||
&[
|
&[
|
||||||
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
|
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
|
||||||
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
|
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
|
||||||
@ -234,21 +161,21 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
|||||||
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
|
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q5_0() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ5_0;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let mut dst = vec![0f32; 32 * 4];
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
let mut quant = vec![BlockQ5_0::zeros(); 4];
|
||||||
let dst = quant.dequantize(device)?;
|
BlockQ5_0::from_float(&src, &mut quant)?;
|
||||||
|
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
round_vector(&dst),
|
||||||
&[
|
&[
|
||||||
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
|
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
|
||||||
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
|
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
|
||||||
@ -264,21 +191,21 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
|||||||
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
|
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q5_1() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ5_1;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let mut dst = vec![0f32; 32 * 4];
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
let mut quant = vec![BlockQ5_1::zeros(); 4];
|
||||||
let dst = quant.dequantize(device)?;
|
BlockQ5_1::from_float(&src, &mut quant)?;
|
||||||
|
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
round_vector(&dst.to_vec1::<f32>()?),
|
dst,
|
||||||
&[
|
&[
|
||||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
||||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
||||||
@ -292,11 +219,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
|
|||||||
124.0, 125.0, 126.0, 127.0
|
124.0, 125.0, 126.0, 127.0
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
|
||||||
|
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
|
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
|
||||||
|
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||||
assert!(
|
assert!(
|
||||||
size % crate::quantized::k_quants::QK_K == 0,
|
size % crate::quantized::k_quants::QK_K == 0,
|
||||||
"size must be a multiple of {}",
|
"size must be a multiple of {}",
|
||||||
@ -306,8 +235,10 @@ fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor>
|
|||||||
let src = (0..size)
|
let src = (0..size)
|
||||||
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
|
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let dst = vec![0f32; size];
|
||||||
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
|
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
|
||||||
Tensor::from_vec(src, (size,), device)
|
(src, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Round a vector
|
/// Round a vector
|
||||||
@ -334,8 +265,7 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a vector similar to the ones used in GGML unit tests:
|
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
|
||||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||||
(0..GGML_TEST_SIZE)
|
(0..GGML_TEST_SIZE)
|
||||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||||
@ -354,16 +284,14 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
|||||||
sum / a.len() as f32
|
sum / a.len() as f32
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Similar to the GGML quantization unit test:
|
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||||
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
|
|
||||||
let src = create_ggml_like_vector(0.0);
|
let src = create_ggml_like_vector(0.0);
|
||||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
let dst = quant.dequantize(device)?;
|
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
|
||||||
if error > max_error {
|
if error > max_error {
|
||||||
bail!(
|
candle_core::bail!(
|
||||||
"Quantization error {} exceeds max error {}",
|
"Quantization error {} exceeds max error {}",
|
||||||
error,
|
error,
|
||||||
max_error
|
max_error
|
||||||
@ -372,19 +300,19 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
|
||||||
// TODO Enable this later when we enable cuda.
|
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||||
if device.is_cuda() {
|
T::from_float(src, &mut quant)?;
|
||||||
return Ok(());
|
T::to_float(&quant, dst)?;
|
||||||
}
|
Ok(quant)
|
||||||
let dtype = GgmlDType::Q2K;
|
}
|
||||||
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
#[test]
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
fn quantize_q2k() -> Result<()> {
|
||||||
let dst = quant.dequantize(device)?;
|
use k_quants::BlockQ2K;
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
|
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
|
||||||
|
|
||||||
// Test some specific values
|
// Test some specific values
|
||||||
@ -398,30 +326,20 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
|||||||
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
|
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
|
||||||
);
|
);
|
||||||
|
|
||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
|
||||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
|
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
|
||||||
|
|
||||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q3k() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ3K;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q3K;
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
|
||||||
let dst = quant.dequantize(device)?;
|
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
|
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
|
||||||
|
|
||||||
// Test some specific values
|
// Test some specific values
|
||||||
@ -435,30 +353,20 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
|||||||
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
|
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
|
||||||
);
|
);
|
||||||
|
|
||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
|
||||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
|
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
|
||||||
|
|
||||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q4k() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ4K;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q4K;
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
|
||||||
let dst = quant.dequantize(device)?;
|
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
|
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
|
||||||
|
|
||||||
// Test some specific values
|
// Test some specific values
|
||||||
@ -472,31 +380,21 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
|||||||
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
|
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
|
||||||
);
|
);
|
||||||
|
|
||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
|
||||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
|
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
|
||||||
|
|
||||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q5k() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ5K;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q5K;
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
|
||||||
let dst = quant.dequantize(device)?;
|
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
|
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||||
|
|
||||||
// Test some specific values
|
// Test some specific values
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -506,33 +404,24 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
|||||||
let dst = round_vector(&dst);
|
let dst = round_vector(&dst);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||||
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
|
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
|
||||||
);
|
);
|
||||||
|
|
||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
|
||||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
|
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
|
||||||
|
|
||||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q6k() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ6K;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q6K;
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
|
||||||
let dst = quant.dequantize(device)?;
|
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||||
|
|
||||||
// Test some specific values
|
// Test some specific values
|
||||||
@ -546,31 +435,22 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
|||||||
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
|
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
|
||||||
);
|
);
|
||||||
|
|
||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
|
||||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
|
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
|
||||||
|
|
||||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
#[test]
|
||||||
// TODO Enable this later when we enable cuda.
|
fn quantize_q8k() -> Result<()> {
|
||||||
if device.is_cuda() {
|
use k_quants::BlockQ8K;
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dtype = GgmlDType::Q8K;
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
|
||||||
let dst = quant.dequantize(device)?;
|
|
||||||
|
|
||||||
let src = src.to_vec1::<f32>()?;
|
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||||
let dst = dst.to_vec1::<f32>()?;
|
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
|
||||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
|
||||||
|
|
||||||
// Test some specific values
|
// Test some specific values
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -583,79 +463,15 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
|||||||
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
|
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
|
||||||
);
|
);
|
||||||
|
|
||||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||||
let dst_big = quant_big.dequantize(device)?;
|
|
||||||
|
|
||||||
let src_big = src_big.to_vec1::<f32>()?;
|
|
||||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
|
||||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
|
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
|
||||||
|
|
||||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
test_device!(
|
|
||||||
quantize_q4_0,
|
|
||||||
quantize_q4_0_cpu,
|
|
||||||
quantize_q4_0_cuda,
|
|
||||||
quantize_q4_0_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q4_1,
|
|
||||||
quantize_q4_1_cpu,
|
|
||||||
quantize_q4_1_cuda,
|
|
||||||
quantize_q4_1_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q5_0,
|
|
||||||
quantize_q5_0_cpu,
|
|
||||||
quantize_q5_0_cuda,
|
|
||||||
quantize_q5_0_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q5_1,
|
|
||||||
quantize_q5_1_cpu,
|
|
||||||
quantize_q5_1_cuda,
|
|
||||||
quantize_q5_1_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q2k,
|
|
||||||
quantize_q2k_cpu,
|
|
||||||
quantize_q2k_cuda,
|
|
||||||
quantize_q2k_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q3k,
|
|
||||||
quantize_q3k_cpu,
|
|
||||||
quantize_q3k_cuda,
|
|
||||||
quantize_q3k_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q4k,
|
|
||||||
quantize_q4k_cpu,
|
|
||||||
quantize_q4k_cuda,
|
|
||||||
quantize_q4k_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q5k,
|
|
||||||
quantize_q5k_cpu,
|
|
||||||
quantize_q5k_cuda,
|
|
||||||
quantize_q5k_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q6k,
|
|
||||||
quantize_q6k_cpu,
|
|
||||||
quantize_q6k_cuda,
|
|
||||||
quantize_q6k_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
|
||||||
quantize_q8k,
|
|
||||||
quantize_q8k_cpu,
|
|
||||||
quantize_q8k_cuda,
|
|
||||||
quantize_q8k_metal
|
|
||||||
);
|
|
||||||
|
|
||||||
/// Very simple dot product implementation
|
/// Very simple dot product implementation
|
||||||
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
|
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
|
||||||
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||||
@ -671,66 +487,54 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
|||||||
GgmlDType::Q5K => 0.000740,
|
GgmlDType::Q5K => 0.000740,
|
||||||
GgmlDType::Q6K => 0.000952,
|
GgmlDType::Q6K => 0.000952,
|
||||||
GgmlDType::Q4_0 => 0.001143,
|
GgmlDType::Q4_0 => 0.001143,
|
||||||
GgmlDType::Q4_1 => 0.008,
|
GgmlDType::Q4_1 => 0.007784,
|
||||||
GgmlDType::Q5_0 => 0.001353,
|
GgmlDType::Q5_0 => 0.001353,
|
||||||
GgmlDType::Q5_1 => 0.00149,
|
GgmlDType::Q5_1 => 0.001363,
|
||||||
GgmlDType::Q8_0 => 0.000092,
|
GgmlDType::Q8_0 => 0.000092,
|
||||||
|
|
||||||
// Not from the ggml repo.
|
// Not from the ggml repo.
|
||||||
GgmlDType::Q8K => 0.00065,
|
GgmlDType::Q8K => 0.00065,
|
||||||
_ => bail!("No GGML results for quantization type {dtype:?}",),
|
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||||
};
|
};
|
||||||
Ok(err)
|
Ok(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Similar to the GGML matmul unit test:
|
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
|
||||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||||
let a = create_ggml_like_vector(0.0);
|
let a = create_ggml_like_vector(0.0);
|
||||||
let b = create_ggml_like_vector(1.0);
|
let b = create_ggml_like_vector(1.0);
|
||||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
|
|
||||||
// Another example that is more likely to trigger the overflow reported in #1526
|
|
||||||
let a = (0..GGML_TEST_SIZE)
|
|
||||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let b = (0..GGML_TEST_SIZE)
|
|
||||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
|
|
||||||
let length = a.len();
|
let length = a.len();
|
||||||
|
|
||||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||||
T::from_float(a, &mut a_quant)?;
|
T::from_float(&a, &mut a_quant)?;
|
||||||
T::VecDotType::from_float(b, &mut b_quant)?;
|
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||||
|
|
||||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||||
let reference_result = vec_dot_reference(a, b);
|
let reference_result = vec_dot_reference(&a, &b);
|
||||||
|
|
||||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||||
bail!(
|
candle_core::bail!(
|
||||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
let error = (result - reference_result).abs() / length as f32;
|
let error = (result - reference_result).abs() / length as f32;
|
||||||
|
|
||||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
|
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||||
|
|
||||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||||
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
|
candle_core::bail!(
|
||||||
|
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||||
// => we use a slightly higher error threshold
|
// => we use a slightly higher error threshold
|
||||||
const ERROR_LENIENCY: f32 = 0.00001;
|
const ERROR_LENIENCY: f32 = 0.00001;
|
||||||
if error - ERROR_LENIENCY > ggml_error {
|
if error - ERROR_LENIENCY > ggml_error {
|
||||||
bail!(
|
candle_core::bail!(
|
||||||
"Dot product error {} exceeds ggml reference error {}",
|
"Dot product error {} exceeds ggml reference error {}",
|
||||||
error,
|
error,
|
||||||
ggml_error
|
ggml_error
|
||||||
@ -739,16 +543,6 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn quantized_mm() -> Result<()> {
|
|
||||||
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
|
|
||||||
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
|
|
||||||
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
|
|
||||||
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
|
|
||||||
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||||
fn get_random_tensors(
|
fn get_random_tensors(
|
||||||
m: usize,
|
m: usize,
|
||||||
@ -772,112 +566,6 @@ fn get_random_tensors(
|
|||||||
Ok((lhs, rhs, mm))
|
Ok((lhs, rhs, mm))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
|
||||||
macro_rules! quantized_matmul {
|
|
||||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
|
||||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
|
||||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
|
||||||
fn $fn_name(device: &Device) -> Result<()> {
|
|
||||||
if device.is_cuda() {
|
|
||||||
// TODO Enable Cuda GGML sometime maybe.
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q4_0_bis,
|
|
||||||
quantized_matmul_q4_0_cpu,
|
|
||||||
quantized_matmul_q4_0_cuda,
|
|
||||||
quantized_matmul_q4_0_metal,
|
|
||||||
GgmlDType::Q4_0
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q4_1_bis,
|
|
||||||
quantized_matmul_q4_1_cpu,
|
|
||||||
quantized_matmul_q4_1_cuda,
|
|
||||||
quantized_matmul_q4_1_metal,
|
|
||||||
GgmlDType::Q4_1
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q5_0_bis,
|
|
||||||
quantized_matmul_q5_0_cpu,
|
|
||||||
quantized_matmul_q5_0_cuda,
|
|
||||||
quantized_matmul_q5_0_metal,
|
|
||||||
GgmlDType::Q5_0
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q5_1_bis,
|
|
||||||
quantized_matmul_q5_1_cpu,
|
|
||||||
quantized_matmul_q5_1_cuda,
|
|
||||||
quantized_matmul_q5_1_metal,
|
|
||||||
GgmlDType::Q5_1
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q8_0_bis,
|
|
||||||
quantized_matmul_q8_0_cpu,
|
|
||||||
quantized_matmul_q8_0_cuda,
|
|
||||||
quantized_matmul_q8_0_metal,
|
|
||||||
GgmlDType::Q8_0
|
|
||||||
);
|
|
||||||
// Not implemented in Ggml
|
|
||||||
// quantized_matmul!(
|
|
||||||
// quantized_matmul_q8_1_bis,
|
|
||||||
// quantized_matmul_q8_1_cpu,
|
|
||||||
// quantized_matmul_q8_1_cuda,
|
|
||||||
// quantized_matmul_q8_1_metal,
|
|
||||||
// GgmlDType::Q8_1
|
|
||||||
// );
|
|
||||||
// TODO This is bugged (also bugged in GGML
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q2k_bis,
|
|
||||||
quantized_matmul_q2k_cpu,
|
|
||||||
quantized_matmul_q2k_cuda,
|
|
||||||
quantized_matmul_q2k_metal,
|
|
||||||
GgmlDType::Q2K
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q3k_bis,
|
|
||||||
quantized_matmul_q3k_cpu,
|
|
||||||
quantized_matmul_q3k_cuda,
|
|
||||||
quantized_matmul_q3k_metal,
|
|
||||||
GgmlDType::Q3K
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q4k_bis,
|
|
||||||
quantized_matmul_q4k_cpu,
|
|
||||||
quantized_matmul_q4k_cuda,
|
|
||||||
quantized_matmul_q4k_metal,
|
|
||||||
GgmlDType::Q4K
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q5k_bis,
|
|
||||||
quantized_matmul_q5k_cpu,
|
|
||||||
quantized_matmul_q5k_cuda,
|
|
||||||
quantized_matmul_q5k_metal,
|
|
||||||
GgmlDType::Q5K
|
|
||||||
);
|
|
||||||
quantized_matmul!(
|
|
||||||
quantized_matmul_q6k_bis,
|
|
||||||
quantized_matmul_q6k_cpu,
|
|
||||||
quantized_matmul_q6k_cuda,
|
|
||||||
quantized_matmul_q6k_metal,
|
|
||||||
GgmlDType::Q6K
|
|
||||||
);
|
|
||||||
// Not implemented on metal
|
|
||||||
// quantized_matmul!(
|
|
||||||
// quantized_matmul_q8k_bis,
|
|
||||||
// quantized_matmul_q8k_cpu,
|
|
||||||
// quantized_matmul_q8k_cuda,
|
|
||||||
// quantized_matmul_q8k_metal,
|
|
||||||
// GgmlDType::Q8K
|
|
||||||
// );
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn quantized_matmul_q2k() -> Result<()> {
|
fn quantized_matmul_q2k() -> Result<()> {
|
||||||
use k_quants::BlockQ2K;
|
use k_quants::BlockQ2K;
|
||||||
@ -890,7 +578,7 @@ fn quantized_matmul_q2k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -916,7 +604,7 @@ fn quantized_matmul_q3k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -942,7 +630,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -968,7 +656,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -995,7 +683,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
@ -1020,7 +708,7 @@ fn quantized_matmul_q8k() -> Result<()> {
|
|||||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||||
|
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
|
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||||
let mm = rhs.forward(&lhs)?;
|
let mm = rhs.forward(&lhs)?;
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
|
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
|
||||||
|
|
||||||
fn zeros(device: &Device) -> Result<()> {
|
fn zeros(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||||
@ -32,14 +32,6 @@ fn ones(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn full(device: &Device) -> Result<()> {
|
|
||||||
assert_eq!(
|
|
||||||
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
|
||||||
[[42, 42, 42], [42, 42, 42]],
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn arange(device: &Device) -> Result<()> {
|
fn arange(device: &Device) -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
|
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
|
||||||
@ -1080,7 +1072,6 @@ fn randn(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
|
||||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||||
@ -1168,100 +1159,3 @@ fn i64_abs() -> Result<()> {
|
|||||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn tril_triu_eye() -> Result<()> {
|
|
||||||
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[1.0, 0.0, 0.0, 0.0],
|
|
||||||
[1.0, 1.0, 0.0, 0.0],
|
|
||||||
[1.0, 1.0, 1.0, 0.0],
|
|
||||||
[1.0, 1.0, 1.0, 1.0]
|
|
||||||
],
|
|
||||||
);
|
|
||||||
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[1.0, 1.0, 1.0, 1.0],
|
|
||||||
[0.0, 1.0, 1.0, 1.0],
|
|
||||||
[0.0, 0.0, 1.0, 1.0],
|
|
||||||
[0.0, 0.0, 0.0, 1.0]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[1.0, 0.0, 0.0, 0.0],
|
|
||||||
[0.0, 1.0, 0.0, 0.0],
|
|
||||||
[0.0, 0.0, 1.0, 0.0],
|
|
||||||
[0.0, 0.0, 0.0, 1.0]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn cumsum() -> Result<()> {
|
|
||||||
let t = &[3f32, 1., 4., 1., 5.];
|
|
||||||
let t = Tensor::new(t, &Device::Cpu)?;
|
|
||||||
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
|
|
||||||
let t = t.unsqueeze(1)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(0)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0], [4.0], [8.0], [9.0], [14.0]]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(1)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0], [1.0], [4.0], [1.0], [5.0]]
|
|
||||||
);
|
|
||||||
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
|
||||||
let t = Tensor::new(t, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(1)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
t.cumsum(0)?.to_vec2::<f32>()?,
|
|
||||||
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
|
|
||||||
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
|
|
||||||
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
|
||||||
let a_vec: Vec<f64> = a.to_vec1()?;
|
|
||||||
let b_vec: Vec<f64> = b.to_vec1()?;
|
|
||||||
|
|
||||||
assert_eq!(a_vec.len(), b_vec.len());
|
|
||||||
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
|
|
||||||
assert!((a - b).abs() < epsilon);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn log_sum_exp() -> Result<()> {
|
|
||||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
|
||||||
let output = input.log_sum_exp(D::Minus1)?;
|
|
||||||
// The expectations obtained from pytorch.
|
|
||||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
|
||||||
assert_close(&output, &expected, 0.00001)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn pow() -> Result<()> {
|
|
||||||
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
|
||||||
let rhs = (&lhs - 2.)?;
|
|
||||||
let res = lhs.pow(&rhs)?;
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec2_round(&res, 4)?,
|
|
||||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle = { workspace = true }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-nn = { workspace = true }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
hf-hub = { workspace = true}
|
hf-hub = { workspace = true}
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
|
@ -11,17 +11,14 @@ readme = "README.md"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { workspace = true }
|
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||||
candle-datasets = { workspace = true }
|
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||||
candle-nn = { workspace = true }
|
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||||
candle-transformers = { workspace = true }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||||
candle-flash-attn = { workspace = true, optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||||
candle-onnx = { workspace = true, optional = true }
|
candle-onnx = { path = "../candle-onnx", version = "0.3.0", optional = true }
|
||||||
|
|
||||||
csv = "1.3.0"
|
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
hf-hub = { workspace = true, features=["tokio"]}
|
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
@ -36,6 +33,7 @@ tokenizers = { workspace = true, features = ["onig"] }
|
|||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
|
hf-hub = { workspace = true, features=["tokio"]}
|
||||||
imageproc = { workspace = true }
|
imageproc = { workspace = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
@ -49,18 +47,16 @@ tokio = "1.29.1"
|
|||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||||
cudnn = ["candle/cudnn"]
|
cudnn = ["candle/cudnn"]
|
||||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
onnx = ["candle-onnx"]
|
onnx = ["candle-onnx"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
|
@ -4,28 +4,235 @@ use std::io::Write;
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
struct KernelDirectories {
|
struct KernelDirectories {
|
||||||
kernel_glob: &'static str,
|
kernel_dir: &'static str,
|
||||||
rust_target: &'static str,
|
rust_target: &'static str,
|
||||||
include_dirs: &'static [&'static str],
|
include_dirs: &'static [&'static str],
|
||||||
}
|
}
|
||||||
|
|
||||||
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||||
kernel_glob: "examples/custom-ops/kernels/*.cu",
|
kernel_dir: "examples/custom-ops/kernels/",
|
||||||
rust_target: "examples/custom-ops/cuda_kernels.rs",
|
rust_target: "examples/custom-ops/cuda_kernels.rs",
|
||||||
include_dirs: &[],
|
include_dirs: &[],
|
||||||
}];
|
}];
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
impl KernelDirectories {
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
fn maybe_build_ptx(
|
||||||
|
&self,
|
||||||
|
cu_file: &std::path::Path,
|
||||||
|
ptx_file: &std::path::Path,
|
||||||
|
compute_cap: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
let should_compile = if ptx_file.exists() {
|
||||||
|
let ptx_modified = ptx_file.metadata()?.modified()?;
|
||||||
|
let cu_modified = cu_file.metadata()?.modified()?;
|
||||||
|
cu_modified.duration_since(ptx_modified).is_ok()
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
};
|
||||||
|
if should_compile {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
{
|
{
|
||||||
for kdir in KERNEL_DIRS.iter() {
|
let mut command = std::process::Command::new("nvcc");
|
||||||
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
|
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
|
||||||
println!("cargo:info={builder:?}");
|
let include_dirs: Vec<String> =
|
||||||
let bindings = builder.build_ptx().unwrap();
|
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
|
||||||
bindings.write(kdir.rust_target).unwrap()
|
command
|
||||||
|
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||||
|
.arg("--ptx")
|
||||||
|
.args(["--default-stream", "per-thread"])
|
||||||
|
.args(["--output-directory", out_dir.to_str().unwrap()])
|
||||||
|
.arg(format!("-I/{}", self.kernel_dir))
|
||||||
|
.args(include_dirs)
|
||||||
|
.arg(cu_file);
|
||||||
|
let output = command
|
||||||
|
.spawn()
|
||||||
|
.context("failed spawning nvcc")?
|
||||||
|
.wait_with_output()?;
|
||||||
|
if !output.status.success() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||||
|
String::from_utf8_lossy(&output.stdout),
|
||||||
|
String::from_utf8_lossy(&output.stderr)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
std::fs::OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.write(true)
|
||||||
|
.open(ptx_file)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
|
||||||
|
println!("cargo:rerun-if-changed={}", self.kernel_dir);
|
||||||
|
let kernel_dir = PathBuf::from(self.kernel_dir);
|
||||||
|
let out_dir = out_dir.join(self.kernel_dir);
|
||||||
|
if !out_dir.exists() {
|
||||||
|
std::fs::create_dir_all(&out_dir)?;
|
||||||
|
}
|
||||||
|
let mut cu_files = vec![];
|
||||||
|
let mut cuh_files = vec![];
|
||||||
|
for file in std::fs::read_dir(kernel_dir)?.flatten() {
|
||||||
|
let file = file.path();
|
||||||
|
match file.extension().and_then(|v| v.to_str()) {
|
||||||
|
Some("cu") => cu_files.push(file),
|
||||||
|
Some("cuh") => cuh_files.push(file),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut ptx_paths = vec![];
|
||||||
|
for cu_file in cu_files.iter() {
|
||||||
|
let file_stem = cu_file
|
||||||
|
.file_stem()
|
||||||
|
.with_context(|| format!("no stem {cu_file:?}"))?;
|
||||||
|
let file_stem = file_stem.to_string_lossy().into_owned();
|
||||||
|
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
|
||||||
|
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
|
||||||
|
ptx_paths.push(ptx_file);
|
||||||
|
}
|
||||||
|
|
||||||
|
let regenerate_rs_file = true;
|
||||||
|
if regenerate_rs_file {
|
||||||
|
let mut file = std::fs::File::create(self.rust_target)?;
|
||||||
|
for ptx_path in ptx_paths {
|
||||||
|
let name = ptx_path
|
||||||
|
.file_stem()
|
||||||
|
.context("empty stem")?
|
||||||
|
.to_string_lossy();
|
||||||
|
file.write_all(b"#[rustfmt::skip]\n")?;
|
||||||
|
let const_definition = format!(
|
||||||
|
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
|
||||||
|
name.to_uppercase().replace('.', "_"),
|
||||||
|
self.kernel_dir,
|
||||||
|
);
|
||||||
|
file.write_all(const_definition.as_bytes())?;
|
||||||
|
file.write_all(b"\n")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
|
||||||
|
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
|
||||||
|
let out_dir = PathBuf::from(out_dir);
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
set_cuda_include_dir()?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
let compute_cap = compute_cap()?;
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
let compute_cap = 0;
|
||||||
|
for d in DIRS {
|
||||||
|
d.process(&out_dir, compute_cap)?
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_cuda_include_dir() -> Result<()> {
|
||||||
|
// NOTE: copied from cudarc build.rs.
|
||||||
|
let env_vars = [
|
||||||
|
"CUDA_PATH",
|
||||||
|
"CUDA_ROOT",
|
||||||
|
"CUDA_TOOLKIT_ROOT_DIR",
|
||||||
|
"CUDNN_LIB",
|
||||||
|
];
|
||||||
|
let env_vars = env_vars
|
||||||
|
.into_iter()
|
||||||
|
.map(std::env::var)
|
||||||
|
.filter_map(Result::ok)
|
||||||
|
.map(Into::<PathBuf>::into);
|
||||||
|
|
||||||
|
let roots = [
|
||||||
|
"/usr",
|
||||||
|
"/usr/local/cuda",
|
||||||
|
"/opt/cuda",
|
||||||
|
"/usr/lib/cuda",
|
||||||
|
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||||
|
"C:/CUDA",
|
||||||
|
];
|
||||||
|
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||||
|
let root = env_vars
|
||||||
|
.chain(roots)
|
||||||
|
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||||
|
.context("cannot find include/cuda.h")?;
|
||||||
|
println!(
|
||||||
|
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||||
|
root.join("include").display()
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
fn compute_cap() -> Result<usize> {
|
||||||
|
// Grab compute code from nvidia-smi
|
||||||
|
let mut compute_cap = {
|
||||||
|
let out = std::process::Command::new("nvidia-smi")
|
||||||
|
.arg("--query-gpu=compute_cap")
|
||||||
|
.arg("--format=csv")
|
||||||
|
.output()
|
||||||
|
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||||
|
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||||
|
let mut lines = out.lines();
|
||||||
|
assert_eq!(
|
||||||
|
lines.next().context("missing line in stdout")?,
|
||||||
|
"compute_cap"
|
||||||
|
);
|
||||||
|
let cap = lines
|
||||||
|
.next()
|
||||||
|
.context("missing line in stdout")?
|
||||||
|
.replace('.', "");
|
||||||
|
cap.parse::<usize>()
|
||||||
|
.with_context(|| format!("cannot parse as int {cap}"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Grab available GPU codes from nvcc and select the highest one
|
||||||
|
let max_nvcc_code = {
|
||||||
|
let out = std::process::Command::new("nvcc")
|
||||||
|
.arg("--list-gpu-code")
|
||||||
|
.output()
|
||||||
|
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||||
|
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||||
|
|
||||||
|
let out = out.lines().collect::<Vec<&str>>();
|
||||||
|
let mut codes = Vec::with_capacity(out.len());
|
||||||
|
for code in out {
|
||||||
|
let code = code.split('_').collect::<Vec<&str>>();
|
||||||
|
if !code.is_empty() && code.contains(&"sm") {
|
||||||
|
if let Ok(num) = code[1].parse::<usize>() {
|
||||||
|
codes.push(num);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
codes.sort();
|
||||||
|
if !codes.contains(&compute_cap) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
*codes.last().unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||||
|
// then choose the highest gpu code in nvcc
|
||||||
|
if compute_cap > max_nvcc_code {
|
||||||
|
println!(
|
||||||
|
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
|
||||||
|
);
|
||||||
|
compute_cap = max_nvcc_code;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||||
|
|
||||||
|
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||||
|
compute_cap = compute_cap_str
|
||||||
|
.parse::<usize>()
|
||||||
|
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
|
||||||
|
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||||
|
}
|
||||||
|
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||||
|
Ok(compute_cap)
|
||||||
}
|
}
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
Bert is a general large language model. In this example it can be used for two
|
Bert is a general large language model. In this example it can be used for two
|
||||||
different tasks:
|
different tasks:
|
||||||
|
|
||||||
- Compute sentence embeddings for a prompt.
|
- Compute sentence embeddings for a prompt.
|
||||||
- Compute similarities between a set of sentences.
|
- Compute similarities between a set of sentences.
|
||||||
|
|
||||||
|
|
||||||
## Sentence embeddings
|
## Sentence embeddings
|
||||||
|
|
||||||
Bert is used to compute the sentence embeddings for a prompt. The model weights
|
Bert is used to compute the sentence embeddings for a prompt. The model weights
|
||||||
@ -24,48 +24,6 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
|
|||||||
> Tensor[[1, 7, 384], f32]
|
> Tensor[[1, 7, 384], f32]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Custom models
|
|
||||||
|
|
||||||
You can specify different models, such as BGE, with the `--model-id` flag:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example bert --release -- \
|
|
||||||
--model-id BAAI/bge-large-zh-v1.5 \
|
|
||||||
--prompt "Here is a test sentence"
|
|
||||||
Loaded and encoded 435.70775ms
|
|
||||||
[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
|
|
||||||
[-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
|
|
||||||
[ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
|
|
||||||
...
|
|
||||||
[ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
|
|
||||||
[ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
|
|
||||||
[ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
|
|
||||||
Tensor[[1, 9, 1024], f32]
|
|
||||||
Took 176.744667ms
|
|
||||||
```
|
|
||||||
|
|
||||||
### Gelu approximation
|
|
||||||
|
|
||||||
You can get a speedup by using an approximation of the gelu activation, with a
|
|
||||||
small loss of precision, by passing the `--approximate-gelu` flag:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example bert --release -- \
|
|
||||||
--model-id BAAI/bge-large-zh-v1.5 \
|
|
||||||
--prompt "Here is a test sentence" \
|
|
||||||
--approximate-gelu
|
|
||||||
Loaded and encoded 244.388042ms
|
|
||||||
[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
|
|
||||||
[-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
|
|
||||||
[ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
|
|
||||||
...
|
|
||||||
[ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
|
|
||||||
[ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
|
|
||||||
[ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
|
|
||||||
Tensor[[1, 9, 1024], f32]
|
|
||||||
Took 116.840791ms
|
|
||||||
```
|
|
||||||
|
|
||||||
## Similarities
|
## Similarities
|
||||||
|
|
||||||
In this example, Bert is used to compute the sentence embeddings for a set of
|
In this example, Bert is used to compute the sentence embeddings for a set of
|
||||||
|
@ -3,7 +3,7 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
|
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::Tensor;
|
use candle::Tensor;
|
||||||
@ -45,10 +45,6 @@ struct Args {
|
|||||||
/// L2 normalization for embeddings.
|
/// L2 normalization for embeddings.
|
||||||
#[arg(long, default_value = "true")]
|
#[arg(long, default_value = "true")]
|
||||||
normalize_embeddings: bool,
|
normalize_embeddings: bool,
|
||||||
|
|
||||||
/// Use tanh based approximation for Gelu instead of erf implementation.
|
|
||||||
#[arg(long, default_value = "false")]
|
|
||||||
approximate_gelu: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -77,7 +73,7 @@ impl Args {
|
|||||||
(config, tokenizer, weights)
|
(config, tokenizer, weights)
|
||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let mut config: Config = serde_json::from_str(&config)?;
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let vb = if self.use_pth {
|
let vb = if self.use_pth {
|
||||||
@ -85,9 +81,6 @@ impl Args {
|
|||||||
} else {
|
} else {
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||||
};
|
};
|
||||||
if self.approximate_gelu {
|
|
||||||
config.hidden_act = HiddenAct::GeluApproximate;
|
|
||||||
}
|
|
||||||
let model = BertModel::load(vb, &config)?;
|
let model = BertModel::load(vb, &config)?;
|
||||||
Ok((model, tokenizer))
|
Ok((model, tokenizer))
|
||||||
}
|
}
|
||||||
|
@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let config = blip::Config::image_captioning_large();
|
let config = blip::Config::image_captioning_large();
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let (image_embeds, device, mut model) = if args.quantized {
|
let (image_embeds, device, mut model) = if args.quantized {
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let image = load_image(args.image)?.to_device(&device)?;
|
let image = load_image(args.image)?.to_device(&device)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
|
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
|
||||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||||
(image_embeds, device, Model::Q(model))
|
(image_embeds, device, Model::Q(model))
|
||||||
} else {
|
} else {
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let image = load_image(args.image)?.to_device(&device)?;
|
let image = load_image(args.image)?.to_device(&device)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
#[rustfmt::skip]
|
||||||
|
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
#[rustfmt::skip]
|
#[allow(unused)]
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
mod cuda_kernels;
|
mod cuda_kernels;
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
# candle-distilbert
|
|
||||||
|
|
||||||
DistilBert is a distiled version of the Bert model.
|
|
||||||
|
|
||||||
## Sentence embeddings
|
|
||||||
|
|
||||||
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
|
|
||||||
are downloaded from the hub on the first run.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
|
||||||
|
|
||||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
|
||||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
|
||||||
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
|
|
||||||
> ...
|
|
||||||
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
|
|
||||||
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
|
|
||||||
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
|
|
||||||
> Tensor[[1, 7, 768], f32]
|
|
||||||
|
|
||||||
```
|
|
@ -1,135 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use candle::{Device, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use clap::Parser;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
/// When set, compute embeddings for this prompt.
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// Use the pytorch weights rather than the safetensors ones
|
|
||||||
#[arg(long)]
|
|
||||||
use_pth: bool,
|
|
||||||
|
|
||||||
/// The number of times to run the prompt.
|
|
||||||
#[arg(long, default_value = "1")]
|
|
||||||
n: usize,
|
|
||||||
|
|
||||||
/// L2 normalization for embeddings.
|
|
||||||
#[arg(long, default_value = "true")]
|
|
||||||
normalize_embeddings: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Args {
|
|
||||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
|
||||||
let device = candle_examples::device(self.cpu)?;
|
|
||||||
let default_model = "distilbert-base-uncased".to_string();
|
|
||||||
let default_revision = "main".to_string();
|
|
||||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
|
||||||
(None, Some(revision)) => (default_model, revision),
|
|
||||||
(None, None) => (default_model, default_revision),
|
|
||||||
};
|
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
|
||||||
let api = Api::new()?;
|
|
||||||
let api = api.repo(repo);
|
|
||||||
let config = api.get("config.json")?;
|
|
||||||
let tokenizer = api.get("tokenizer.json")?;
|
|
||||||
let weights = if self.use_pth {
|
|
||||||
api.get("pytorch_model.bin")?
|
|
||||||
} else {
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
};
|
|
||||||
(config, tokenizer, weights)
|
|
||||||
};
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let vb = if self.use_pth {
|
|
||||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
|
||||||
} else {
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
|
||||||
};
|
|
||||||
let model = DistilBertModel::load(vb, &config)?;
|
|
||||||
Ok((model, tokenizer))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
|
||||||
let mask: Vec<_> = (0..size)
|
|
||||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
|
||||||
.collect();
|
|
||||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
println!("tracing...");
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
|
||||||
let device = &model.device;
|
|
||||||
|
|
||||||
let tokenizer = tokenizer
|
|
||||||
.with_padding(None)
|
|
||||||
.with_truncation(None)
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
let tokens = tokenizer
|
|
||||||
.encode(args.prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
|
||||||
let mask = get_mask(tokens.len(), device);
|
|
||||||
|
|
||||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
|
||||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
|
||||||
|
|
||||||
let ys = model.forward(&token_ids, &mask)?;
|
|
||||||
println!("{ys}");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
|
||||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
|
||||||
}
|
|
@ -165,7 +165,14 @@ fn main() -> Result<()> {
|
|||||||
args.revision,
|
args.revision,
|
||||||
));
|
));
|
||||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||||
let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
|
let mut filenames = vec![];
|
||||||
|
for rfilename in [
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
] {
|
||||||
|
let filename = repo.get(rfilename)?;
|
||||||
|
filenames.push(filename);
|
||||||
|
}
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ extern crate accelerate_src;
|
|||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use anyhow::{bail, Error as E, Result};
|
use anyhow::{bail, Error as E, Result};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::Parser;
|
||||||
|
|
||||||
use candle::{DType, Tensor};
|
use candle::{DType, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
@ -22,21 +22,11 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
|||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
use candle_transformers::models::llama as model;
|
use candle_transformers::models::llama as model;
|
||||||
use model::{Llama, LlamaConfig};
|
use model::{Config, Llama, LlamaConfig};
|
||||||
|
|
||||||
const EOS_TOKEN: &str = "</s>";
|
const EOS_TOKEN: &str = "</s>";
|
||||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
V1,
|
|
||||||
V2,
|
|
||||||
#[value(name = "solar-10.7b")]
|
|
||||||
Solar10_7B,
|
|
||||||
#[value(name = "tiny-llama-1.1b-chat")]
|
|
||||||
TinyLlama1_1BChat,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -44,6 +34,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Use npy instead of safetensors
|
||||||
|
#[arg(long)]
|
||||||
|
npy: Option<String>,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
temperature: Option<f64>,
|
temperature: Option<f64>,
|
||||||
@ -82,13 +76,17 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
/// The model size to use.
|
#[arg(long)]
|
||||||
#[arg(long, default_value = "v2")]
|
v1: bool,
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
|
|
||||||
|
/// The folder name that contains safetensor weights and json files
|
||||||
|
/// (same structure as huggingface online)
|
||||||
|
#[arg(long)]
|
||||||
|
local_weights: Option<String>,
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
#[arg(long, default_value_t = 1.0)]
|
#[arg(long, default_value_t = 1.0)]
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
@ -120,34 +118,65 @@ fn main() -> Result<()> {
|
|||||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||||
None => DType::F16,
|
None => DType::F16,
|
||||||
};
|
};
|
||||||
let (llama, tokenizer_filename, cache) = {
|
let (llama, tokenizer_filename, cache) = match args.npy {
|
||||||
|
Some(filename) => {
|
||||||
|
let config = if args.v1 {
|
||||||
|
Config::config_7b_v1(args.use_flash_attn)
|
||||||
|
} else {
|
||||||
|
Config::config_7b_v2(args.use_flash_attn)
|
||||||
|
};
|
||||||
|
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||||
|
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
|
||||||
|
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
||||||
|
(Llama::load(vb, &cache, &config)?, tokenizer, cache)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
let model_id = args.model_id.unwrap_or_else(|| {
|
||||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
if args.v1 {
|
||||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
"Narsil/amall-7b".to_string()
|
||||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
} else {
|
||||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
"meta-llama/Llama-2-7b-hf".to_string()
|
||||||
|
}
|
||||||
});
|
});
|
||||||
println!("loading the model weights from {model_id}");
|
println!("loading the model weights from {model_id}");
|
||||||
let revision = args.revision.unwrap_or("main".to_string());
|
let revision = args.revision.unwrap_or("main".to_string());
|
||||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
|
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = match &args.local_weights {
|
||||||
let config_filename = api.get("config.json")?;
|
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||||
|
_ => api.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let config_filename = match &args.local_weights {
|
||||||
|
Some(path) => (path.to_owned() + "config.json").into(),
|
||||||
|
_ => api.get("config.json")?,
|
||||||
|
};
|
||||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
let config = config.into_config(args.use_flash_attn);
|
let config = config.into_config(args.use_flash_attn);
|
||||||
|
|
||||||
let filenames = match args.which {
|
let mut filenames = vec![];
|
||||||
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
for rfilename in [
|
||||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
] {
|
||||||
|
match &args.local_weights {
|
||||||
|
Some(path) => {
|
||||||
|
filenames.push((path.to_owned() + rfilename).into());
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let filename = api.get(rfilename)?;
|
||||||
|
filenames.push(filename);
|
||||||
}
|
}
|
||||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
|
||||||
};
|
};
|
||||||
|
}
|
||||||
|
|
||||||
println!("building the model");
|
println!("building the model");
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||||
@ -165,14 +194,14 @@ fn main() -> Result<()> {
|
|||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
let mut token_generated = 0;
|
let mut token_generated = 0;
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
let context_size = if cache.use_kv_cache && index > 0 {
|
||||||
(1, index_pos)
|
1
|
||||||
} else {
|
} else {
|
||||||
(tokens.len(), 0)
|
tokens.len()
|
||||||
};
|
};
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let logits = llama.forward(&input, context_index)?;
|
let logits = llama.forward(&input, index_pos)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let logits = if args.repeat_penalty == 1. {
|
let logits = if args.repeat_penalty == 1. {
|
||||||
logits
|
logits
|
||||||
|
@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
.extension()
|
.extension()
|
||||||
.map_or(false, |v| v == "safetensors");
|
.map_or(false, |v| v == "safetensors");
|
||||||
let (model, config) = if is_gguf {
|
let (model, config) = if is_gguf {
|
||||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
|
||||||
let (_vocab_size, dim) = vb
|
let (_vocab_size, dim) = vb
|
||||||
.get_no_shape("model.embed_tokens.weight")?
|
.get_no_shape("model.embed_tokens.weight")?
|
||||||
.shape()
|
.shape()
|
||||||
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
(config.seq_len, config.head_size() / 2),
|
(config.seq_len, config.head_size() / 2),
|
||||||
"rot.freq_cis_real",
|
"rot.freq_cis_real",
|
||||||
)?
|
)?
|
||||||
.dequantize(&device)?;
|
.dequantize(&candle::Device::Cpu)?;
|
||||||
let freq_cis_imag = vb
|
let freq_cis_imag = vb
|
||||||
.get(
|
.get(
|
||||||
(config.seq_len, config.head_size() / 2),
|
(config.seq_len, config.head_size() / 2),
|
||||||
"rot.freq_cis_imag",
|
"rot.freq_cis_imag",
|
||||||
)?
|
)?
|
||||||
.dequantize(&device)?;
|
.dequantize(&candle::Device::Cpu)?;
|
||||||
|
|
||||||
let fake_vb = candle_nn::VarBuilder::from_tensors(
|
let fake_vb = candle_nn::VarBuilder::from_tensors(
|
||||||
[
|
[
|
||||||
@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.collect(),
|
.collect(),
|
||||||
candle::DType::F32,
|
candle::DType::F32,
|
||||||
&device,
|
&candle::Device::Cpu,
|
||||||
);
|
);
|
||||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||||
|
@ -143,7 +143,14 @@ fn main() -> Result<()> {
|
|||||||
let config_filename = api.get("config.json")?;
|
let config_filename = api.get("config.json")?;
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
let mut filenames = vec![];
|
||||||
|
for rfilename in [
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
] {
|
||||||
|
let filename = api.get(rfilename)?;
|
||||||
|
filenames.push(filename);
|
||||||
|
}
|
||||||
|
|
||||||
if args.rank.is_none() {
|
if args.rank.is_none() {
|
||||||
let children: Vec<_> = (0..args.num_shards)
|
let children: Vec<_> = (0..args.num_shards)
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
# candle-mamba-minimal: minimal implementation of Mamba
|
|
||||||
|
|
||||||
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
|
|
||||||
|
|
||||||
## Running the example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
|
|
||||||
Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016.
|
|
||||||
|
|
||||||
The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer.
|
|
||||||
```
|
|
@ -1,287 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
mod model;
|
|
||||||
use model::{Config, Model};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Module, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the </s> token"),
|
|
||||||
};
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for _ in 0..sample_len {
|
|
||||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
|
||||||
enum Which {
|
|
||||||
Mamba130m,
|
|
||||||
Mamba370m,
|
|
||||||
Mamba790m,
|
|
||||||
Mamba1_4b,
|
|
||||||
Mamba2_8b,
|
|
||||||
Mamba2_8bSlimPj,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for Which {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{:?}", self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_id(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
|
||||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
|
||||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
|
||||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
|
||||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
|
||||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn revision(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Mamba130m
|
|
||||||
| Self::Mamba370m
|
|
||||||
| Self::Mamba790m
|
|
||||||
| Self::Mamba1_4b
|
|
||||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
|
||||||
Self::Mamba2_8b => "refs/pr/4",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "mamba130m")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
args.model_id
|
|
||||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision
|
|
||||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => api
|
|
||||||
.model("EleutherAI/gpt-neox-20b".to_string())
|
|
||||||
.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let config_filename = match args.config_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("config.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => {
|
|
||||||
vec![repo.get("model.safetensors")?]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
|
||||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,204 +0,0 @@
|
|||||||
/// This follows the lines of:
|
|
||||||
/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py
|
|
||||||
/// Simple, minimal implementation of Mamba in one file of PyTorch.
|
|
||||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
|
||||||
use candle_nn::{RmsNorm, VarBuilder};
|
|
||||||
|
|
||||||
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
pub struct Config {
|
|
||||||
d_model: usize,
|
|
||||||
n_layer: usize,
|
|
||||||
vocab_size: usize,
|
|
||||||
pad_vocab_size_multiple: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
fn vocab_size(&self) -> usize {
|
|
||||||
let pad = self.pad_vocab_size_multiple;
|
|
||||||
(self.vocab_size + pad - 1) / pad * pad
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
|
||||||
(self.d_model + 15) / 16
|
|
||||||
}
|
|
||||||
|
|
||||||
fn d_conv(&self) -> usize {
|
|
||||||
4
|
|
||||||
}
|
|
||||||
|
|
||||||
fn d_state(&self) -> usize {
|
|
||||||
16
|
|
||||||
}
|
|
||||||
|
|
||||||
fn d_inner(&self) -> usize {
|
|
||||||
self.d_model * 2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct MambaBlock {
|
|
||||||
in_proj: Linear,
|
|
||||||
conv1d: candle_nn::Conv1d,
|
|
||||||
x_proj: Linear,
|
|
||||||
dt_proj: Linear,
|
|
||||||
a_log: Tensor,
|
|
||||||
d: Tensor,
|
|
||||||
out_proj: Linear,
|
|
||||||
dt_rank: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MambaBlock {
|
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let d_inner = cfg.d_inner();
|
|
||||||
let d_conv = cfg.d_conv();
|
|
||||||
let d_state = cfg.d_state();
|
|
||||||
let dt_rank = cfg.dt_rank();
|
|
||||||
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
|
|
||||||
let conv_cfg = candle_nn::Conv1dConfig {
|
|
||||||
groups: d_inner,
|
|
||||||
padding: d_conv - 1,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?;
|
|
||||||
let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
|
|
||||||
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
|
|
||||||
let a_log = vb.get((d_inner, d_state), "A_log")?;
|
|
||||||
let d = vb.get(d_inner, "D")?;
|
|
||||||
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
|
|
||||||
Ok(Self {
|
|
||||||
in_proj,
|
|
||||||
conv1d,
|
|
||||||
x_proj,
|
|
||||||
dt_proj,
|
|
||||||
a_log,
|
|
||||||
d,
|
|
||||||
out_proj,
|
|
||||||
dt_rank,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ssm(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let (_d_in, n) = self.a_log.dims2()?;
|
|
||||||
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
|
|
||||||
let d = self.d.to_dtype(candle::DType::F32)?;
|
|
||||||
let x_dbl = xs.apply(&self.x_proj)?;
|
|
||||||
let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?;
|
|
||||||
let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?;
|
|
||||||
let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?;
|
|
||||||
let delta = delta.contiguous()?.apply(&self.dt_proj)?;
|
|
||||||
// softplus without threshold
|
|
||||||
let delta = (delta.exp()? + 1.)?.log()?;
|
|
||||||
let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?;
|
|
||||||
Ok(ss)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275
|
|
||||||
fn selective_scan(
|
|
||||||
u: &Tensor,
|
|
||||||
delta: &Tensor,
|
|
||||||
a: &Tensor,
|
|
||||||
b: &Tensor,
|
|
||||||
c: &Tensor,
|
|
||||||
d: &Tensor,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
let (b_sz, l, d_in) = u.dims3()?;
|
|
||||||
let n = a.dim(1)?;
|
|
||||||
let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1
|
|
||||||
let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?;
|
|
||||||
let delta_b_u = delta
|
|
||||||
.broadcast_mul(&b.reshape((b_sz, 1, l, n))?)?
|
|
||||||
.broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?;
|
|
||||||
let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?;
|
|
||||||
let mut ys = Vec::with_capacity(l);
|
|
||||||
for i in 0..l {
|
|
||||||
xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?;
|
|
||||||
let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?;
|
|
||||||
ys.push(y)
|
|
||||||
}
|
|
||||||
let ys = Tensor::stack(ys.as_slice(), 1)?;
|
|
||||||
ys + u.broadcast_mul(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for MambaBlock {
|
|
||||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let (_b_sz, seq_len, _dim) = xs.dims3()?;
|
|
||||||
let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
|
|
||||||
let (xs, res) = (&xs_and_res[0], &xs_and_res[1]);
|
|
||||||
let xs = xs
|
|
||||||
.t()?
|
|
||||||
.apply(&self.conv1d)?
|
|
||||||
.narrow(D::Minus1, 0, seq_len)?
|
|
||||||
.t()?;
|
|
||||||
let xs = candle_nn::ops::silu(&xs)?;
|
|
||||||
let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?;
|
|
||||||
ys.apply(&self.out_proj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct ResidualBlock {
|
|
||||||
mixer: MambaBlock,
|
|
||||||
norm: RmsNorm,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResidualBlock {
|
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
|
|
||||||
let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?;
|
|
||||||
Ok(Self { mixer, norm })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for ResidualBlock {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
xs.apply(&self.norm)?.apply(&self.mixer)? + xs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct Model {
|
|
||||||
embedding: candle_nn::Embedding,
|
|
||||||
layers: Vec<ResidualBlock>,
|
|
||||||
norm_f: RmsNorm,
|
|
||||||
lm_head: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Model {
|
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
|
|
||||||
let mut layers = Vec::with_capacity(cfg.n_layer);
|
|
||||||
let vb_l = vb.pp("layers");
|
|
||||||
for layer_idx in 0..cfg.n_layer {
|
|
||||||
let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?;
|
|
||||||
layers.push(layer)
|
|
||||||
}
|
|
||||||
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
|
|
||||||
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
|
|
||||||
Ok(Self {
|
|
||||||
embedding,
|
|
||||||
layers,
|
|
||||||
norm_f,
|
|
||||||
lm_head,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for Model {
|
|
||||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
|
||||||
let (_b_size, seq_len) = input_ids.dims2()?;
|
|
||||||
let mut xs = self.embedding.forward(input_ids)?;
|
|
||||||
for layer in self.layers.iter() {
|
|
||||||
xs = layer.forward(&xs)?
|
|
||||||
}
|
|
||||||
xs.narrow(1, seq_len - 1, 1)?
|
|
||||||
.apply(&self.norm_f)?
|
|
||||||
.apply(&self.lm_head)
|
|
||||||
}
|
|
||||||
}
|
|
@ -155,8 +155,8 @@ struct Args {
|
|||||||
#[arg(long, short = 'n', default_value_t = 100)]
|
#[arg(long, short = 'n', default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long, default_value = "lmz/candle-mistral")]
|
||||||
model_id: Option<String>,
|
model_id: String,
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
#[arg(long, default_value = "main")]
|
||||||
revision: String,
|
revision: String,
|
||||||
@ -207,18 +207,8 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id,
|
|
||||||
None => {
|
|
||||||
if args.quantized {
|
|
||||||
"lmz/candle-mistral".to_string()
|
|
||||||
} else {
|
|
||||||
"mistralai/Mistral-7B-v0.1".to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
model_id,
|
args.model_id,
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
args.revision,
|
args.revision,
|
||||||
));
|
));
|
||||||
@ -235,7 +225,10 @@ fn main() -> Result<()> {
|
|||||||
if args.quantized {
|
if args.quantized {
|
||||||
vec![repo.get("model-q4k.gguf")?]
|
vec![repo.get("model-q4k.gguf")?]
|
||||||
} else {
|
} else {
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
vec![
|
||||||
|
repo.get("pytorch_model-00001-of-00002.safetensors")?,
|
||||||
|
repo.get("pytorch_model-00002-of-00002.safetensors")?,
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -244,14 +237,13 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let (model, device) = if args.quantized {
|
let (model, device) = if args.quantized {
|
||||||
let filename = &filenames[0];
|
let filename = &filenames[0];
|
||||||
let vb =
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
|
||||||
let model = QMistral::new(&config, vb)?;
|
let model = QMistral::new(&config, vb)?;
|
||||||
(Model::Quantized(model), device)
|
(Model::Quantized(model), Device::Cpu)
|
||||||
} else {
|
} else {
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let dtype = if device.is_cuda() {
|
let dtype = if device.is_cuda() {
|
||||||
DType::BF16
|
DType::BF16
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
# candle-mixtral: 8x7b LLM using a sparse mixture of experts.
|
|
||||||
|
|
||||||
Mixtral-8x7B-v0.1 is a pretrained generative LLM with 56 billion parameters.
|
|
||||||
|
|
||||||
- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release.
|
|
||||||
- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub.
|
|
||||||
|
|
||||||
## Running the example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example mixtral --release -- --prompt "def print_prime(n): "
|
|
||||||
def print_prime(n): # n is the number of prime numbers to be printed
|
|
||||||
i = 2
|
|
||||||
count = 0
|
|
||||||
while (count < n):
|
|
||||||
if (isPrime(i)):
|
|
||||||
print(i)
|
|
||||||
count += 1
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
def isPrime(n):
|
|
||||||
for x in range(2, int(n**0.5)+1):
|
|
||||||
if (n % x == 0):
|
|
||||||
...
|
|
||||||
```
|
|
@ -1,241 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle_transformers::models::mixtral::{Config, Model};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("</s>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the </s> token"),
|
|
||||||
};
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, start_pos)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
use_flash_attn: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 100)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
|
|
||||||
model_id: String,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
|
||||||
revision: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
args.model_id,
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision,
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config = Config::v0_1_8x7b(args.use_flash_attn);
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let dtype = if device.is_cuda() {
|
|
||||||
DType::BF16
|
|
||||||
} else {
|
|
||||||
DType::F32
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -321,7 +321,7 @@ impl MusicgenDecoder {
|
|||||||
let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
|
let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
|
||||||
let mut xs = inputs_embeds.broadcast_add(&positions)?;
|
let mut xs = inputs_embeds.broadcast_add(&positions)?;
|
||||||
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
|
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
|
||||||
for decoder_layer in self.layers.iter_mut() {
|
for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
|
||||||
xs = decoder_layer.forward(&xs, &attention_mask, None)?;
|
xs = decoder_layer.forward(&xs, &attention_mask, None)?;
|
||||||
}
|
}
|
||||||
let xs = self.layer_norm.forward(&xs)?;
|
let xs = self.layer_norm.forward(&xs)?;
|
||||||
|
@ -1,33 +1,14 @@
|
|||||||
# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
|
# candle-phi: 1.3b LLM with state of the art performance for <10b models.
|
||||||
|
|
||||||
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
|
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
|
||||||
[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
|
only 1.3 billion parameters but with state of the art performance compared to
|
||||||
only 1.3 and 2.7 billion parameters but with state of the art performance compared to
|
|
||||||
models with up to 10 billion parameters.
|
models with up to 10 billion parameters.
|
||||||
|
|
||||||
The candle implementation provides both the standard version as well as a
|
The candle implementation provides both the standard version as well as a
|
||||||
quantized variant.
|
quantized variant.
|
||||||
|
|
||||||
## Running some examples
|
## Running some example
|
||||||
|
|
||||||
For the v2 version.
|
|
||||||
```bash
|
|
||||||
$ cargo run --example phi --release -- --model 2 \
|
|
||||||
--prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?"
|
|
||||||
|
|
||||||
A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?
|
|
||||||
|
|
||||||
Solution:
|
|
||||||
The potential energy of the skier is converted into kinetic energy as it slides down the slope. The formula for potential energy is mgh, where m is mass, g is acceleration due to gravity (9.8 m/s^2), and h is height. Since there's no friction, all the potential energy is converted into kinetic energy at the bottom of the slope. The formula for kinetic energy is 1/2mv^2, where v is velocity. We can equate these two formulas:
|
|
||||||
mgh = 1/2mv^2
|
|
||||||
Solving for v, we get:
|
|
||||||
v = sqrt(2gh)
|
|
||||||
Substituting the given values, we get:
|
|
||||||
v = sqrt(2*9.8*40) = 28 m/s
|
|
||||||
Therefore, the skier speed at the bottom of the slope is 28 m/s.
|
|
||||||
```
|
|
||||||
|
|
||||||
For the v1.5 version.
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
|
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@ use anyhow::{Error as E, Result};
|
|||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||||
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
|
||||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
@ -19,7 +18,6 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
MixFormer(MixFormer),
|
MixFormer(MixFormer),
|
||||||
Phi(Phi),
|
|
||||||
Quantized(QMixFormer),
|
Quantized(QMixFormer),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,7 +84,6 @@ impl TextGeneration {
|
|||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
let logits = match &mut self.model {
|
let logits = match &mut self.model {
|
||||||
Model::MixFormer(m) => m.forward(&input)?,
|
Model::MixFormer(m) => m.forward(&input)?,
|
||||||
Model::Phi(m) => m.forward(&input)?,
|
|
||||||
Model::Quantized(m) => m.forward(&input)?,
|
Model::Quantized(m) => m.forward(&input)?,
|
||||||
};
|
};
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
@ -120,16 +117,12 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum WhichModel {
|
enum WhichModel {
|
||||||
#[value(name = "1")]
|
#[value(name = "1")]
|
||||||
V1,
|
V1,
|
||||||
#[value(name = "1.5")]
|
#[value(name = "1.5")]
|
||||||
V1_5,
|
V1_5,
|
||||||
#[value(name = "2")]
|
|
||||||
V2,
|
|
||||||
#[value(name = "2-old")]
|
|
||||||
V2Old,
|
|
||||||
PuffinPhiV2,
|
PuffinPhiV2,
|
||||||
PhiHermes,
|
PhiHermes,
|
||||||
}
|
}
|
||||||
@ -150,10 +143,7 @@ struct Args {
|
|||||||
verbose_prompt: bool,
|
verbose_prompt: bool,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: Option<String>,
|
prompt: String,
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
mmlu_dir: Option<String>,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -168,13 +158,13 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
#[arg(long, short = 'n', default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
#[arg(long, default_value = "2")]
|
#[arg(long, default_value = "1.5")]
|
||||||
model: WhichModel,
|
model: WhichModel,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -235,7 +225,6 @@ fn main() -> Result<()> {
|
|||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||||
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
"lmz/candle-quantized-phi".to_string()
|
"lmz/candle-quantized-phi".to_string()
|
||||||
}
|
}
|
||||||
@ -250,12 +239,9 @@ fn main() -> Result<()> {
|
|||||||
"main".to_string()
|
"main".to_string()
|
||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => "refs/pr/8".to_string(),
|
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||||
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||||
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
|
||||||
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
|
||||||
"main".to_string()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -264,34 +250,27 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer_filename = match args.tokenizer {
|
let tokenizer_filename = match args.tokenizer {
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
|
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
|
||||||
repo.get("tokenizer.json")?
|
|
||||||
}
|
|
||||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
let filenames = match args.weight_file {
|
let filename = match args.weight_file {
|
||||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
Some(weight_file) => std::path::PathBuf::from(weight_file),
|
||||||
None => {
|
None => {
|
||||||
if args.quantized {
|
if args.quantized {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
|
||||||
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
|
||||||
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
|
||||||
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
|
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
|
||||||
&repo,
|
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
|
||||||
"model.safetensors.index.json",
|
|
||||||
)?,
|
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -300,52 +279,24 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = || match args.model {
|
let config = match args.model {
|
||||||
WhichModel::V1 => Config::v1(),
|
WhichModel::V1 => Config::v1(),
|
||||||
WhichModel::V1_5 => Config::v1_5(),
|
WhichModel::V1_5 => Config::v1_5(),
|
||||||
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
|
||||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let (model, device) = if args.quantized {
|
||||||
let model = if args.quantized {
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||||
let config = config();
|
let model = QMixFormer::new(&config, vb)?;
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
(Model::Quantized(model), Device::Cpu)
|
||||||
&filenames[0],
|
|
||||||
&device,
|
|
||||||
)?;
|
|
||||||
let model = match args.model {
|
|
||||||
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
|
||||||
_ => QMixFormer::new(&config, vb)?,
|
|
||||||
};
|
|
||||||
Model::Quantized(model)
|
|
||||||
} else {
|
} else {
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let device = candle_examples::device(args.cpu)?;
|
||||||
match args.model {
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
let model = MixFormer::new(&config, vb)?;
|
||||||
let config_filename = repo.get("config.json")?;
|
(Model::MixFormer(model), device)
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let config: PhiConfig = serde_json::from_str(&config)?;
|
|
||||||
let phi = Phi::new(&config, vb)?;
|
|
||||||
Model::Phi(phi)
|
|
||||||
}
|
|
||||||
WhichModel::V2Old => {
|
|
||||||
let config = config();
|
|
||||||
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
|
||||||
}
|
|
||||||
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
|
|
||||||
let config = config();
|
|
||||||
Model::MixFormer(MixFormer::new(&config, vb)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
match (args.prompt, args.mmlu_dir) {
|
|
||||||
(None, None) | (Some(_), Some(_)) => {
|
|
||||||
anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
|
|
||||||
}
|
|
||||||
(Some(prompt), None) => {
|
|
||||||
let mut pipeline = TextGeneration::new(
|
let mut pipeline = TextGeneration::new(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -357,93 +308,6 @@ fn main() -> Result<()> {
|
|||||||
args.verbose_prompt,
|
args.verbose_prompt,
|
||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
pipeline.run(&prompt, args.sample_len)?;
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
}
|
|
||||||
(None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mmlu<P: AsRef<std::path::Path>>(
|
|
||||||
mut model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
device: &Device,
|
|
||||||
mmlu_dir: P,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
|
|
||||||
let dir_entry = dir_entry.path();
|
|
||||||
let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
|
|
||||||
None => "".to_string(),
|
|
||||||
Some(v) => match v.strip_suffix("_test") {
|
|
||||||
None => v.replace('_', " "),
|
|
||||||
Some(v) => v.replace('_', " "),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
println!("reading {dir_entry:?}");
|
|
||||||
let dir_entry = std::fs::File::open(dir_entry)?;
|
|
||||||
let mut reader = csv::ReaderBuilder::new()
|
|
||||||
.has_headers(false)
|
|
||||||
.from_reader(dir_entry);
|
|
||||||
let token_a = tokenizer.token_to_id("A").unwrap();
|
|
||||||
let token_b = tokenizer.token_to_id("B").unwrap();
|
|
||||||
let token_c = tokenizer.token_to_id("C").unwrap();
|
|
||||||
let token_d = tokenizer.token_to_id("D").unwrap();
|
|
||||||
for row in reader.records() {
|
|
||||||
let row = match row {
|
|
||||||
Err(_) => continue,
|
|
||||||
Ok(row) => row,
|
|
||||||
};
|
|
||||||
if row.len() < 5 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let question = row.get(0).unwrap();
|
|
||||||
let answer_a = row.get(1).unwrap();
|
|
||||||
let answer_b = row.get(2).unwrap();
|
|
||||||
let answer_c = row.get(3).unwrap();
|
|
||||||
let answer_d = row.get(4).unwrap();
|
|
||||||
let answer = row.get(5).unwrap();
|
|
||||||
let prompt = format!(
|
|
||||||
"{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
|
|
||||||
"The following are multiple choice questions (with answers) about"
|
|
||||||
);
|
|
||||||
let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
|
|
||||||
let tokens = tokens.get_ids().to_vec();
|
|
||||||
let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
|
|
||||||
let logits = match &mut model {
|
|
||||||
Model::MixFormer(m) => {
|
|
||||||
m.clear_kv_cache();
|
|
||||||
m.forward(&input)?
|
|
||||||
}
|
|
||||||
Model::Phi(m) => {
|
|
||||||
m.clear_kv_cache();
|
|
||||||
m.forward(&input)?
|
|
||||||
}
|
|
||||||
Model::Quantized(m) => {
|
|
||||||
m.clear_kv_cache();
|
|
||||||
m.forward(&input)?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
|
||||||
let pr_a = logits_v[token_a as usize];
|
|
||||||
let pr_b = logits_v[token_b as usize];
|
|
||||||
let pr_c = logits_v[token_c as usize];
|
|
||||||
let pr_d = logits_v[token_d as usize];
|
|
||||||
let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
|
|
||||||
"A"
|
|
||||||
} else if pr_b > pr_c && pr_b > pr_d {
|
|
||||||
"B"
|
|
||||||
} else if pr_c > pr_d {
|
|
||||||
"C"
|
|
||||||
} else {
|
|
||||||
"D"
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("{prompt}\n -> {model_answer} vs {answer}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -132,8 +132,7 @@ impl T5ModelBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
|
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||||
let device = Device::Cpu;
|
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
||||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
|
|
||||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,19 +26,6 @@ cargo run --example quantized --release -- --prompt "The best thing about coding
|
|||||||
> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
|
> The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
|
||||||
```
|
```
|
||||||
|
|
||||||
Using the mixtral sparse mixture of expert model:
|
|
||||||
```bash
|
|
||||||
|
|
||||||
$ cargo run --example quantized --release -- --which mixtral --prompt "Lebesgue's integral is superior to Riemann's because "
|
|
||||||
> avx: true, neon: false, simd128: false, f16c: true
|
|
||||||
> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
|
|
||||||
> loaded 995 tensors (26.44GB) in 0.03s
|
|
||||||
Lebesgue's integral is superior to Riemann's because 1. it is defined for a wider class of functions, those which are absolutely integrable; 2. the definition does not involve limits in two variables---one being computed before the other (which makes some computations more difficult); and 3. interchange of order of integration is easier to establish than with Riemann's integral. On the other hand, Lebesgue's integral applies only for bounded functions defined on finite intervals; it does not provide numerical values for improper integrals. The latter are best evaluated using Cauchy's limit definition.
|
|
||||||
|
|
||||||
The reason $f(x) = x^2$ is discontinuous at the ends of its interval of definition, and Riemann's integral requires continuity on the whole of an open interval containing it (see our earlier post), sine no such function exists with this property, is that the endpoints are infinite in measure for Lebesgue's integral.
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Command-line flags
|
## Command-line flags
|
||||||
|
|
||||||
Run with `--help` to see all options.
|
Run with `--help` to see all options.
|
||||||
|
@ -9,7 +9,7 @@ use std::io::Write;
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
use candle::quantized::{ggml_file, gguf_file};
|
use candle::quantized::{ggml_file, gguf_file};
|
||||||
use candle::Tensor;
|
use candle::{Device, Tensor};
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
@ -45,28 +45,14 @@ enum Which {
|
|||||||
L13bCode,
|
L13bCode,
|
||||||
#[value(name = "32b-code")]
|
#[value(name = "32b-code")]
|
||||||
L34bCode,
|
L34bCode,
|
||||||
#[value(name = "7b-leo")]
|
|
||||||
Leo7b,
|
|
||||||
#[value(name = "13b-leo")]
|
|
||||||
Leo13b,
|
|
||||||
#[value(name = "7b-mistral")]
|
#[value(name = "7b-mistral")]
|
||||||
Mistral7b,
|
Mistral7b,
|
||||||
#[value(name = "7b-mistral-instruct")]
|
#[value(name = "7b-mistral-instruct")]
|
||||||
Mistral7bInstruct,
|
Mistral7bInstruct,
|
||||||
#[value(name = "7b-mistral-instruct-v0.2")]
|
|
||||||
Mistral7bInstructV02,
|
|
||||||
#[value(name = "7b-zephyr-a")]
|
#[value(name = "7b-zephyr-a")]
|
||||||
Zephyr7bAlpha,
|
Zephyr7bAlpha,
|
||||||
#[value(name = "7b-zephyr-b")]
|
#[value(name = "7b-zephyr-b")]
|
||||||
Zephyr7bBeta,
|
Zephyr7bBeta,
|
||||||
#[value(name = "7b-open-chat-3.5")]
|
|
||||||
OpenChat35,
|
|
||||||
#[value(name = "7b-starling-a")]
|
|
||||||
Starling7bAlpha,
|
|
||||||
#[value(name = "mixtral")]
|
|
||||||
Mixtral,
|
|
||||||
#[value(name = "mixtral-instruct")]
|
|
||||||
MixtralInstruct,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -80,20 +66,12 @@ impl Which {
|
|||||||
| Self::L70bChat
|
| Self::L70bChat
|
||||||
| Self::L7bCode
|
| Self::L7bCode
|
||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode
|
| Self::L34bCode => false,
|
||||||
| Self::Leo7b
|
// Zephyr is a fine tuned version of mistral and should be treated in the same way.
|
||||||
| Self::Leo13b => false,
|
Self::Zephyr7bAlpha
|
||||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
|
||||||
// same way. Starling is a fine tuned version of OpenChat.
|
|
||||||
Self::OpenChat35
|
|
||||||
| Self::Starling7bAlpha
|
|
||||||
| Self::Zephyr7bAlpha
|
|
||||||
| Self::Zephyr7bBeta
|
| Self::Zephyr7bBeta
|
||||||
| Self::Mixtral
|
|
||||||
| Self::MixtralInstruct
|
|
||||||
| Self::Mistral7b
|
| Self::Mistral7b
|
||||||
| Self::Mistral7bInstruct
|
| Self::Mistral7bInstruct => true,
|
||||||
| Self::Mistral7bInstructV02 => true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,73 +86,17 @@ impl Which {
|
|||||||
| Self::L7bCode
|
| Self::L7bCode
|
||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode
|
| Self::L34bCode
|
||||||
| Self::Leo7b
|
|
||||||
| Self::Leo13b
|
|
||||||
| Self::Mixtral
|
|
||||||
| Self::MixtralInstruct
|
|
||||||
| Self::Mistral7b
|
| Self::Mistral7b
|
||||||
| Self::Mistral7bInstruct
|
| Self::Mistral7bInstruct => false,
|
||||||
| Self::Mistral7bInstructV02
|
|
||||||
| Self::OpenChat35
|
|
||||||
| Self::Starling7bAlpha => false,
|
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_open_chat(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
Self::L7b
|
|
||||||
| Self::L13b
|
|
||||||
| Self::L70b
|
|
||||||
| Self::L7bChat
|
|
||||||
| Self::L13bChat
|
|
||||||
| Self::L70bChat
|
|
||||||
| Self::L7bCode
|
|
||||||
| Self::L13bCode
|
|
||||||
| Self::L34bCode
|
|
||||||
| Self::Leo7b
|
|
||||||
| Self::Leo13b
|
|
||||||
| Self::Mixtral
|
|
||||||
| Self::MixtralInstruct
|
|
||||||
| Self::Mistral7b
|
|
||||||
| Self::Mistral7bInstruct
|
|
||||||
| Self::Mistral7bInstructV02
|
|
||||||
| Self::Zephyr7bAlpha
|
|
||||||
| Self::Zephyr7bBeta => false,
|
|
||||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tokenizer_repo(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Which::L7b
|
|
||||||
| Which::L13b
|
|
||||||
| Which::L70b
|
|
||||||
| Which::L7bChat
|
|
||||||
| Which::L13bChat
|
|
||||||
| Which::L70bChat
|
|
||||||
| Which::L7bCode
|
|
||||||
| Which::L13bCode
|
|
||||||
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
|
|
||||||
Which::Leo7b => "LeoLM/leo-hessianai-7b",
|
|
||||||
Which::Leo13b => "LeoLM/leo-hessianai-13b",
|
|
||||||
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
|
|
||||||
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
Which::Mistral7b
|
|
||||||
| Which::Mistral7bInstruct
|
|
||||||
| Which::Mistral7bInstructV02
|
|
||||||
| Which::Zephyr7bAlpha
|
|
||||||
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
|
||||||
Which::OpenChat35 => "openchat/openchat_3.5",
|
|
||||||
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
/// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp
|
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
@ -235,7 +157,11 @@ impl Args {
|
|||||||
Some(config) => std::path::PathBuf::from(config),
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let repo = self.which.tokenizer_repo();
|
let repo = if self.which.is_mistral() {
|
||||||
|
"mistralai/Mistral-7B-v0.1"
|
||||||
|
} else {
|
||||||
|
"hf-internal-testing/llama-tokenizer"
|
||||||
|
};
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
api.get("tokenizer.json")?
|
api.get("tokenizer.json")?
|
||||||
}
|
}
|
||||||
@ -266,22 +192,6 @@ impl Args {
|
|||||||
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
|
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
|
||||||
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
|
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
|
||||||
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
|
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
|
||||||
Which::Leo7b => (
|
|
||||||
"TheBloke/leo-hessianai-7B-GGUF",
|
|
||||||
"leo-hessianai-7b.Q4_K_M.gguf",
|
|
||||||
),
|
|
||||||
Which::Leo13b => (
|
|
||||||
"TheBloke/leo-hessianai-13B-GGUF",
|
|
||||||
"leo-hessianai-13b.Q4_K_M.gguf",
|
|
||||||
),
|
|
||||||
Which::Mixtral => (
|
|
||||||
"TheBloke/Mixtral-8x7B-v0.1-GGUF",
|
|
||||||
"mixtral-8x7b-v0.1.Q4_K_M.gguf",
|
|
||||||
),
|
|
||||||
Which::MixtralInstruct => (
|
|
||||||
"TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
|
|
||||||
"mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
|
|
||||||
),
|
|
||||||
Which::Mistral7b => (
|
Which::Mistral7b => (
|
||||||
"TheBloke/Mistral-7B-v0.1-GGUF",
|
"TheBloke/Mistral-7B-v0.1-GGUF",
|
||||||
"mistral-7b-v0.1.Q4_K_S.gguf",
|
"mistral-7b-v0.1.Q4_K_S.gguf",
|
||||||
@ -290,10 +200,6 @@ impl Args {
|
|||||||
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
|
||||||
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
|
||||||
),
|
),
|
||||||
Which::Mistral7bInstructV02 => (
|
|
||||||
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
|
|
||||||
"mistral-7b-instruct-v0.2.Q4_K_S.gguf",
|
|
||||||
),
|
|
||||||
Which::Zephyr7bAlpha => (
|
Which::Zephyr7bAlpha => (
|
||||||
"TheBloke/zephyr-7B-alpha-GGUF",
|
"TheBloke/zephyr-7B-alpha-GGUF",
|
||||||
"zephyr-7b-alpha.Q4_K_M.gguf",
|
"zephyr-7b-alpha.Q4_K_M.gguf",
|
||||||
@ -301,11 +207,6 @@ impl Args {
|
|||||||
Which::Zephyr7bBeta => {
|
Which::Zephyr7bBeta => {
|
||||||
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
|
||||||
}
|
}
|
||||||
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
|
|
||||||
Which::Starling7bAlpha => (
|
|
||||||
"TheBloke/Starling-LM-7B-alpha-GGUF",
|
|
||||||
"starling-lm-7b-alpha.Q4_K_M.gguf",
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -361,16 +262,15 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let model_path = args.model()?;
|
let model_path = args.model()?;
|
||||||
let mut file = std::fs::File::open(&model_path)?;
|
let mut file = std::fs::File::open(&model_path)?;
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let device = candle_examples::device(false)?;
|
|
||||||
|
|
||||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||||
Some("gguf") => {
|
Some("gguf") => {
|
||||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
let model = gguf_file::Content::read(&mut file)?;
|
||||||
let mut total_size_in_bytes = 0;
|
let mut total_size_in_bytes = 0;
|
||||||
for (_, tensor) in model.tensor_infos.iter() {
|
for (_, tensor) in model.tensor_infos.iter() {
|
||||||
let elem_count = tensor.shape.elem_count();
|
let elem_count = tensor.shape.elem_count();
|
||||||
total_size_in_bytes +=
|
total_size_in_bytes +=
|
||||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
|
||||||
}
|
}
|
||||||
println!(
|
println!(
|
||||||
"loaded {:?} tensors ({}) in {:.2}s",
|
"loaded {:?} tensors ({}) in {:.2}s",
|
||||||
@ -378,16 +278,15 @@ fn main() -> anyhow::Result<()> {
|
|||||||
&format_size(total_size_in_bytes),
|
&format_size(total_size_in_bytes),
|
||||||
start.elapsed().as_secs_f32(),
|
start.elapsed().as_secs_f32(),
|
||||||
);
|
);
|
||||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
ModelWeights::from_gguf(model, &mut file)?
|
||||||
}
|
}
|
||||||
Some("ggml" | "bin") | Some(_) | None => {
|
Some("ggml" | "bin") | Some(_) | None => {
|
||||||
let model = ggml_file::Content::read(&mut file, &device)
|
let model = ggml_file::Content::read(&mut file)?;
|
||||||
.map_err(|e| e.with_path(model_path))?;
|
|
||||||
let mut total_size_in_bytes = 0;
|
let mut total_size_in_bytes = 0;
|
||||||
for (_, tensor) in model.tensors.iter() {
|
for (_, tensor) in model.tensors.iter() {
|
||||||
let elem_count = tensor.shape().elem_count();
|
let elem_count = tensor.shape().elem_count();
|
||||||
total_size_in_bytes +=
|
total_size_in_bytes +=
|
||||||
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
|
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
||||||
}
|
}
|
||||||
println!(
|
println!(
|
||||||
"loaded {:?} tensors ({}) in {:.2}s",
|
"loaded {:?} tensors ({}) in {:.2}s",
|
||||||
@ -403,20 +302,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L13bChat
|
| Which::L13bChat
|
||||||
| Which::L7bCode
|
| Which::L7bCode
|
||||||
| Which::L13bCode
|
| Which::L13bCode
|
||||||
| Which::L34bCode
|
| Which::L34bCode => 1,
|
||||||
| Which::Leo7b
|
Which::Mistral7b
|
||||||
| Which::Leo13b => 1,
|
|
||||||
Which::Mixtral
|
|
||||||
| Which::MixtralInstruct
|
|
||||||
| Which::Mistral7b
|
|
||||||
| Which::Mistral7bInstruct
|
| Which::Mistral7bInstruct
|
||||||
| Which::Mistral7bInstructV02
|
|
||||||
| Which::Zephyr7bAlpha
|
| Which::Zephyr7bAlpha
|
||||||
| Which::Zephyr7bBeta
|
| Which::Zephyr7bBeta
|
||||||
| Which::L70b
|
| Which::L70b
|
||||||
| Which::L70bChat
|
| Which::L70bChat => 8,
|
||||||
| Which::OpenChat35
|
|
||||||
| Which::Starling7bAlpha => 8,
|
|
||||||
};
|
};
|
||||||
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
|
||||||
}
|
}
|
||||||
@ -448,9 +340,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt.pop();
|
prompt.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if args.which.is_open_chat() {
|
if args.which.is_zephyr() {
|
||||||
format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
|
|
||||||
} else if args.which.is_zephyr() {
|
|
||||||
if prompt_index == 0 || is_interactive {
|
if prompt_index == 0 || is_interactive {
|
||||||
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
|
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
|
||||||
} else {
|
} else {
|
||||||
@ -488,7 +378,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let start_prompt_processing = std::time::Instant::now();
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
let mut next_token = {
|
let mut next_token = {
|
||||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, 0)?;
|
let logits = model.forward(&input, 0)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
logits_processor.sample(&logits)?
|
logits_processor.sample(&logits)?
|
||||||
@ -500,16 +390,12 @@ fn main() -> anyhow::Result<()> {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let eos_token = if args.which.is_open_chat() {
|
let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
|
||||||
"<|end_of_turn|>"
|
|
||||||
} else {
|
|
||||||
"</s>"
|
|
||||||
};
|
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
let mut sampled = 0;
|
||||||
for index in 0..to_sample {
|
for index in 0..to_sample {
|
||||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let logits = if args.repeat_penalty == 1. {
|
let logits = if args.repeat_penalty == 1. {
|
||||||
|
@ -8,16 +8,9 @@ Python package with:
|
|||||||
pip install "gymnasium[accept-rom-license]"
|
pip install "gymnasium[accept-rom-license]"
|
||||||
```
|
```
|
||||||
|
|
||||||
In order to run the examples, use the following commands. Note the additional
|
In order to run the example, use the following command. Note the additional
|
||||||
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
|
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
|
||||||
crate.
|
crate.
|
||||||
|
|
||||||
For the Policy Gradient example:
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
|
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
|
||||||
```
|
|
||||||
|
|
||||||
For the Deep Deterministic Policy Gradient example:
|
|
||||||
```bash
|
|
||||||
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
|
|
||||||
```
|
```
|
||||||
|
@ -78,7 +78,7 @@ class EpisodicLifeEnv(gym.Wrapper):
|
|||||||
# then update lives to handle bonus lives
|
# then update lives to handle bonus lives
|
||||||
lives = self.env.unwrapped.ale.lives()
|
lives = self.env.unwrapped.ale.lives()
|
||||||
if lives < self.lives and lives > 0:
|
if lives < self.lives and lives > 0:
|
||||||
# for Qbert sometimes we stay in lives == 0 condition for a few frames
|
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
|
||||||
# so its important to keep lives > 0, so that we only reset once
|
# so its important to keep lives > 0, so that we only reset once
|
||||||
# the environment advertises done.
|
# the environment advertises done.
|
||||||
done = True
|
done = True
|
||||||
|
@ -8,8 +8,6 @@ use candle_nn::{
|
|||||||
};
|
};
|
||||||
use rand::{distributions::Uniform, thread_rng, Rng};
|
use rand::{distributions::Uniform, thread_rng, Rng};
|
||||||
|
|
||||||
use super::gym_env::GymEnv;
|
|
||||||
|
|
||||||
pub struct OuNoise {
|
pub struct OuNoise {
|
||||||
mu: f64,
|
mu: f64,
|
||||||
theta: f64,
|
theta: f64,
|
||||||
@ -451,106 +449,3 @@ impl DDPG<'_> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// The impact of the q value of the next state on the current state's q value.
|
|
||||||
const GAMMA: f64 = 0.99;
|
|
||||||
// The weight for updating the target networks.
|
|
||||||
const TAU: f64 = 0.005;
|
|
||||||
// The capacity of the replay buffer used for sampling training data.
|
|
||||||
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
|
|
||||||
// The training batch size for each training iteration.
|
|
||||||
const TRAINING_BATCH_SIZE: usize = 100;
|
|
||||||
// The total number of episodes.
|
|
||||||
const MAX_EPISODES: usize = 100;
|
|
||||||
// The maximum length of an episode.
|
|
||||||
const EPISODE_LENGTH: usize = 200;
|
|
||||||
// The number of training iterations after one episode finishes.
|
|
||||||
const TRAINING_ITERATIONS: usize = 200;
|
|
||||||
|
|
||||||
// Ornstein-Uhlenbeck process parameters.
|
|
||||||
const MU: f64 = 0.0;
|
|
||||||
const THETA: f64 = 0.15;
|
|
||||||
const SIGMA: f64 = 0.1;
|
|
||||||
|
|
||||||
const ACTOR_LEARNING_RATE: f64 = 1e-4;
|
|
||||||
const CRITIC_LEARNING_RATE: f64 = 1e-3;
|
|
||||||
|
|
||||||
pub fn run() -> Result<()> {
|
|
||||||
let env = GymEnv::new("Pendulum-v1")?;
|
|
||||||
println!("action space: {}", env.action_space());
|
|
||||||
println!("observation space: {:?}", env.observation_space());
|
|
||||||
|
|
||||||
let size_state = env.observation_space().iter().product::<usize>();
|
|
||||||
let size_action = env.action_space();
|
|
||||||
|
|
||||||
let mut agent = DDPG::new(
|
|
||||||
&Device::Cpu,
|
|
||||||
size_state,
|
|
||||||
size_action,
|
|
||||||
true,
|
|
||||||
ACTOR_LEARNING_RATE,
|
|
||||||
CRITIC_LEARNING_RATE,
|
|
||||||
GAMMA,
|
|
||||||
TAU,
|
|
||||||
REPLAY_BUFFER_CAPACITY,
|
|
||||||
OuNoise::new(MU, THETA, SIGMA, size_action)?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let mut rng = rand::thread_rng();
|
|
||||||
|
|
||||||
for episode in 0..MAX_EPISODES {
|
|
||||||
// let mut state = env.reset(episode as u64)?;
|
|
||||||
let mut state = env.reset(rng.gen::<u64>())?;
|
|
||||||
|
|
||||||
let mut total_reward = 0.0;
|
|
||||||
for _ in 0..EPISODE_LENGTH {
|
|
||||||
let mut action = 2.0 * agent.actions(&state)?;
|
|
||||||
action = action.clamp(-2.0, 2.0);
|
|
||||||
|
|
||||||
let step = env.step(vec![action])?;
|
|
||||||
total_reward += step.reward;
|
|
||||||
|
|
||||||
agent.remember(
|
|
||||||
&state,
|
|
||||||
&Tensor::new(vec![action], &Device::Cpu)?,
|
|
||||||
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
|
|
||||||
&step.state,
|
|
||||||
step.terminated,
|
|
||||||
step.truncated,
|
|
||||||
);
|
|
||||||
|
|
||||||
if step.terminated || step.truncated {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
state = step.state;
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("episode {episode} with total reward of {total_reward}");
|
|
||||||
|
|
||||||
for _ in 0..TRAINING_ITERATIONS {
|
|
||||||
agent.train(TRAINING_BATCH_SIZE)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("Testing...");
|
|
||||||
agent.train = false;
|
|
||||||
for episode in 0..10 {
|
|
||||||
// let mut state = env.reset(episode as u64)?;
|
|
||||||
let mut state = env.reset(rng.gen::<u64>())?;
|
|
||||||
let mut total_reward = 0.0;
|
|
||||||
for _ in 0..EPISODE_LENGTH {
|
|
||||||
let mut action = 2.0 * agent.actions(&state)?;
|
|
||||||
action = action.clamp(-2.0, 2.0);
|
|
||||||
|
|
||||||
let step = env.step(vec![action])?;
|
|
||||||
total_reward += step.reward;
|
|
||||||
|
|
||||||
if step.terminated || step.truncated {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
state = step.state;
|
|
||||||
}
|
|
||||||
println!("episode {episode} with total reward of {total_reward}");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -6,32 +6,139 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::Result;
|
|
||||||
use clap::{Parser, Subcommand};
|
|
||||||
|
|
||||||
mod gym_env;
|
mod gym_env;
|
||||||
mod vec_gym_env;
|
mod vec_gym_env;
|
||||||
|
|
||||||
mod ddpg;
|
mod ddpg;
|
||||||
mod policy_gradient;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
use candle::{Device, Result, Tensor};
|
||||||
|
use clap::Parser;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
// The impact of the q value of the next state on the current state's q value.
|
||||||
|
const GAMMA: f64 = 0.99;
|
||||||
|
// The weight for updating the target networks.
|
||||||
|
const TAU: f64 = 0.005;
|
||||||
|
// The capacity of the replay buffer used for sampling training data.
|
||||||
|
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
|
||||||
|
// The training batch size for each training iteration.
|
||||||
|
const TRAINING_BATCH_SIZE: usize = 100;
|
||||||
|
// The total number of episodes.
|
||||||
|
const MAX_EPISODES: usize = 100;
|
||||||
|
// The maximum length of an episode.
|
||||||
|
const EPISODE_LENGTH: usize = 200;
|
||||||
|
// The number of training iterations after one episode finishes.
|
||||||
|
const TRAINING_ITERATIONS: usize = 200;
|
||||||
|
|
||||||
|
// Ornstein-Uhlenbeck process parameters.
|
||||||
|
const MU: f64 = 0.0;
|
||||||
|
const THETA: f64 = 0.15;
|
||||||
|
const SIGMA: f64 = 0.1;
|
||||||
|
|
||||||
|
const ACTOR_LEARNING_RATE: f64 = 1e-4;
|
||||||
|
const CRITIC_LEARNING_RATE: f64 = 1e-3;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[command(subcommand)]
|
/// Run on CPU rather than on GPU.
|
||||||
command: Command,
|
#[arg(long)]
|
||||||
}
|
cpu: bool,
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
enum Command {
|
#[arg(long)]
|
||||||
Pg,
|
tracing: bool,
|
||||||
Ddpg,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
match args.command {
|
|
||||||
Command::Pg => policy_gradient::run()?,
|
let _guard = if args.tracing {
|
||||||
Command::Ddpg => ddpg::run()?,
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let env = gym_env::GymEnv::new("Pendulum-v1")?;
|
||||||
|
println!("action space: {}", env.action_space());
|
||||||
|
println!("observation space: {:?}", env.observation_space());
|
||||||
|
|
||||||
|
let size_state = env.observation_space().iter().product::<usize>();
|
||||||
|
let size_action = env.action_space();
|
||||||
|
|
||||||
|
let mut agent = ddpg::DDPG::new(
|
||||||
|
&Device::Cpu,
|
||||||
|
size_state,
|
||||||
|
size_action,
|
||||||
|
true,
|
||||||
|
ACTOR_LEARNING_RATE,
|
||||||
|
CRITIC_LEARNING_RATE,
|
||||||
|
GAMMA,
|
||||||
|
TAU,
|
||||||
|
REPLAY_BUFFER_CAPACITY,
|
||||||
|
ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
|
||||||
|
for episode in 0..MAX_EPISODES {
|
||||||
|
// let mut state = env.reset(episode as u64)?;
|
||||||
|
let mut state = env.reset(rng.gen::<u64>())?;
|
||||||
|
|
||||||
|
let mut total_reward = 0.0;
|
||||||
|
for _ in 0..EPISODE_LENGTH {
|
||||||
|
let mut action = 2.0 * agent.actions(&state)?;
|
||||||
|
action = action.clamp(-2.0, 2.0);
|
||||||
|
|
||||||
|
let step = env.step(vec![action])?;
|
||||||
|
total_reward += step.reward;
|
||||||
|
|
||||||
|
agent.remember(
|
||||||
|
&state,
|
||||||
|
&Tensor::new(vec![action], &Device::Cpu)?,
|
||||||
|
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
|
||||||
|
&step.state,
|
||||||
|
step.terminated,
|
||||||
|
step.truncated,
|
||||||
|
);
|
||||||
|
|
||||||
|
if step.terminated || step.truncated {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
state = step.state;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("episode {episode} with total reward of {total_reward}");
|
||||||
|
|
||||||
|
for _ in 0..TRAINING_ITERATIONS {
|
||||||
|
agent.train(TRAINING_BATCH_SIZE)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Testing...");
|
||||||
|
agent.train = false;
|
||||||
|
for episode in 0..10 {
|
||||||
|
// let mut state = env.reset(episode as u64)?;
|
||||||
|
let mut state = env.reset(rng.gen::<u64>())?;
|
||||||
|
let mut total_reward = 0.0;
|
||||||
|
for _ in 0..EPISODE_LENGTH {
|
||||||
|
let mut action = 2.0 * agent.actions(&state)?;
|
||||||
|
action = action.clamp(-2.0, 2.0);
|
||||||
|
|
||||||
|
let step = env.step(vec![action])?;
|
||||||
|
total_reward += step.reward;
|
||||||
|
|
||||||
|
if step.terminated || step.truncated {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
state = step.state;
|
||||||
|
}
|
||||||
|
println!("episode {episode} with total reward of {total_reward}");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,146 +0,0 @@
|
|||||||
use super::gym_env::{GymEnv, Step};
|
|
||||||
use candle::{DType, Device, Error, Module, Result, Tensor};
|
|
||||||
use candle_nn::{
|
|
||||||
linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
|
|
||||||
ParamsAdamW, VarBuilder, VarMap,
|
|
||||||
};
|
|
||||||
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
|
|
||||||
|
|
||||||
fn new_model(
|
|
||||||
input_shape: &[usize],
|
|
||||||
num_actions: usize,
|
|
||||||
dtype: DType,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<(impl Module, VarMap)> {
|
|
||||||
let input_size = input_shape.iter().product();
|
|
||||||
|
|
||||||
let mut varmap = VarMap::new();
|
|
||||||
let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
|
|
||||||
|
|
||||||
let model = seq()
|
|
||||||
.add(linear(input_size, 32, var_builder.pp("lin1"))?)
|
|
||||||
.add(Activation::Relu)
|
|
||||||
.add(linear(32, num_actions, var_builder.pp("lin2"))?);
|
|
||||||
|
|
||||||
Ok((model, varmap))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
|
|
||||||
let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
|
|
||||||
let mut acc_reward = 0f64;
|
|
||||||
for (i, reward) in rewards.iter_mut().enumerate().rev() {
|
|
||||||
if steps[i].terminated {
|
|
||||||
acc_reward = 0.0;
|
|
||||||
}
|
|
||||||
acc_reward += *reward;
|
|
||||||
*reward = acc_reward;
|
|
||||||
}
|
|
||||||
rewards
|
|
||||||
}
|
|
||||||
|
|
||||||
fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
|
|
||||||
let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
|
|
||||||
let mut rng = rng;
|
|
||||||
Ok(distribution.sample(&mut rng))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run() -> Result<()> {
|
|
||||||
let env = GymEnv::new("CartPole-v1")?;
|
|
||||||
|
|
||||||
println!("action space: {:?}", env.action_space());
|
|
||||||
println!("observation space: {:?}", env.observation_space());
|
|
||||||
|
|
||||||
let (model, varmap) = new_model(
|
|
||||||
env.observation_space(),
|
|
||||||
env.action_space(),
|
|
||||||
DType::F32,
|
|
||||||
&Device::Cpu,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let optimizer_params = ParamsAdamW {
|
|
||||||
lr: 0.01,
|
|
||||||
weight_decay: 0.01,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
|
|
||||||
|
|
||||||
let mut rng = rand::thread_rng();
|
|
||||||
|
|
||||||
for epoch_idx in 0..100 {
|
|
||||||
let mut state = env.reset(rng.gen::<u64>())?;
|
|
||||||
let mut steps: Vec<Step<i64>> = vec![];
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let action = {
|
|
||||||
let action_probs: Vec<f32> =
|
|
||||||
softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
|
|
||||||
.squeeze(0)?
|
|
||||||
.to_vec1()?;
|
|
||||||
weighted_sample(action_probs, &mut rng)? as i64
|
|
||||||
};
|
|
||||||
|
|
||||||
let step = env.step(action)?;
|
|
||||||
steps.push(step.copy_with_obs(&state));
|
|
||||||
|
|
||||||
if step.terminated || step.truncated {
|
|
||||||
state = env.reset(rng.gen::<u64>())?;
|
|
||||||
if steps.len() > 5000 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
state = step.state;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
|
|
||||||
let episodes: i64 = steps
|
|
||||||
.iter()
|
|
||||||
.map(|s| (s.terminated || s.truncated) as i64)
|
|
||||||
.sum();
|
|
||||||
println!(
|
|
||||||
"epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
|
|
||||||
epoch_idx,
|
|
||||||
episodes,
|
|
||||||
total_reward / episodes as f64
|
|
||||||
);
|
|
||||||
|
|
||||||
let batch_size = steps.len();
|
|
||||||
|
|
||||||
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
|
|
||||||
.to_dtype(DType::F32)?
|
|
||||||
.detach()?;
|
|
||||||
|
|
||||||
let actions_mask = {
|
|
||||||
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
|
|
||||||
let actions_mask: Vec<Tensor> = actions
|
|
||||||
.iter()
|
|
||||||
.map(|&action| {
|
|
||||||
// One-hot encoding
|
|
||||||
let mut action_mask = vec![0.0; env.action_space()];
|
|
||||||
action_mask[action as usize] = 1.0;
|
|
||||||
|
|
||||||
Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
|
|
||||||
.unwrap()
|
|
||||||
.to_dtype(DType::F32)
|
|
||||||
.unwrap()
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Tensor::stack(&actions_mask, 0)?.detach()?
|
|
||||||
};
|
|
||||||
|
|
||||||
let states = {
|
|
||||||
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
|
|
||||||
Tensor::stack(&states, 0)?.detach()?
|
|
||||||
};
|
|
||||||
|
|
||||||
let log_probs = actions_mask
|
|
||||||
.mul(&log_softmax(&model.forward(&states)?, 1)?)?
|
|
||||||
.sum(1)?;
|
|
||||||
|
|
||||||
let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
|
|
||||||
optimizer.backward_step(&loss)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -236,15 +236,16 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let config = Config::replit_code_v1_5_3b();
|
let config = Config::replit_code_v1_5_3b();
|
||||||
let model = if args.quantized {
|
let (model, device) = if args.quantized {
|
||||||
let vb =
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
|
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
|
||||||
Model::Q(Q::new(&config, vb.pp("transformer"))?)
|
(model, Device::Cpu)
|
||||||
} else {
|
} else {
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||||
Model::M(M::new(&config, vb.pp("transformer"))?)
|
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
|
||||||
|
(model, device)
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
# candle-repvgg
|
|
||||||
|
|
||||||
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
|
|
||||||
|
|
||||||
This candle implementation uses a pre-trained RepVGG network for inference. The
|
|
||||||
classification head has been trained on the ImageNet dataset and returns the
|
|
||||||
probabilities for the top-5 classes.
|
|
||||||
|
|
||||||
## Running an example
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
|
||||||
model built
|
|
||||||
mountain bike, all-terrain bike, off-roader: 61.70%
|
|
||||||
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
|
|
||||||
unicycle, monocycle : 4.88%
|
|
||||||
crash helmet : 0.15%
|
|
||||||
moped : 0.04%
|
|
||||||
|
|
||||||
```
|
|
@ -1,111 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, D};
|
|
||||||
use candle_nn::{Module, VarBuilder};
|
|
||||||
use candle_transformers::models::repvgg;
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
A0,
|
|
||||||
A1,
|
|
||||||
A2,
|
|
||||||
B0,
|
|
||||||
B1,
|
|
||||||
B2,
|
|
||||||
B3,
|
|
||||||
B1G4,
|
|
||||||
B2G4,
|
|
||||||
B3G4,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_filename(&self) -> String {
|
|
||||||
let name = match self {
|
|
||||||
Self::A0 => "a0",
|
|
||||||
Self::A1 => "a1",
|
|
||||||
Self::A2 => "a2",
|
|
||||||
Self::B0 => "b0",
|
|
||||||
Self::B1 => "b1",
|
|
||||||
Self::B2 => "b2",
|
|
||||||
Self::B3 => "b3",
|
|
||||||
Self::B1G4 => "b1g4",
|
|
||||||
Self::B2G4 => "b2g4",
|
|
||||||
Self::B3G4 => "b3g4",
|
|
||||||
};
|
|
||||||
format!("timm/repvgg_{}.rvgg_in1k", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> repvgg::Config {
|
|
||||||
match self {
|
|
||||||
Self::A0 => repvgg::Config::a0(),
|
|
||||||
Self::A1 => repvgg::Config::a1(),
|
|
||||||
Self::A2 => repvgg::Config::a2(),
|
|
||||||
Self::B0 => repvgg::Config::b0(),
|
|
||||||
Self::B1 => repvgg::Config::b1(),
|
|
||||||
Self::B2 => repvgg::Config::b2(),
|
|
||||||
Self::B3 => repvgg::Config::b3(),
|
|
||||||
Self::B1G4 => repvgg::Config::b1g4(),
|
|
||||||
Self::B2G4 => repvgg::Config::b2g4(),
|
|
||||||
Self::B3G4 => repvgg::Config::b3g4(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
image: String,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(value_enum, long, default_value_t=Which::A0)]
|
|
||||||
which: Which,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
|
||||||
println!("loaded image {image:?}");
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let model_name = args.which.model_filename();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.model(model_name);
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
|
|
||||||
println!("model built");
|
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
|
||||||
.i(0)?
|
|
||||||
.to_vec1::<f32>()?;
|
|
||||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
|
||||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
|
||||||
for &(category_idx, pr) in prs.iter().take(5) {
|
|
||||||
println!(
|
|
||||||
"{:24}: {:.2}%",
|
|
||||||
candle_examples::imagenet::CLASSES[category_idx],
|
|
||||||
100. * pr
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -8,7 +8,7 @@ XL using Rust and [candle](https://github.com/huggingface/candle).
|
|||||||
The `stable-diffusion` example is a conversion of
|
The `stable-diffusion` example is a conversion of
|
||||||
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
|
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
|
||||||
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
|
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
|
||||||
as well as Stable Diffusion XL 1.0, and Turbo.
|
as well as Stable Diffusion XL 1.0.
|
||||||
|
|
||||||
## Getting the weights
|
## Getting the weights
|
||||||
|
|
||||||
@ -23,26 +23,16 @@ cargo run --example stable-diffusion --release --features=cuda,cudnn \
|
|||||||
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
|
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
|
||||||
```
|
```
|
||||||
|
|
||||||
The final image is named `sd_final.png` by default. The Turbo version is much
|
The final image is named `sd_final.png` by default.
|
||||||
faster than previous versions, to give it a try add a `--sd-version turbo` flag,
|
The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
|
||||||
e.g.:
|
original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example stable-diffusion --release --features=cuda,cudnn \
|
|
||||||
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" --sd-version turbo
|
|
||||||
```
|
|
||||||
|
|
||||||
The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising
|
|
||||||
Diffusion Implicit Model scheduler (DDIM). The original paper and some code can
|
|
||||||
be found in the [associated repo](https://github.com/ermongroup/ddim).
|
|
||||||
The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
|
|
||||||
|
|
||||||
### Command-line flags
|
### Command-line flags
|
||||||
|
|
||||||
- `--prompt`: the prompt to be used to generate the image.
|
- `--prompt`: the prompt to be used to generate the image.
|
||||||
- `--uncond-prompt`: the optional unconditional prompt.
|
- `--uncond-prompt`: the optional unconditional prompt.
|
||||||
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`,
|
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
|
||||||
`xl`, or `turbo`.
|
`xl`.
|
||||||
- `--cpu`: use the cpu rather than the gpu (much slower).
|
- `--cpu`: use the cpu rather than the gpu (much slower).
|
||||||
- `--height`, `--width`: set the height and width for the generated image.
|
- `--height`, `--width`: set the height and width for the generated image.
|
||||||
- `--n-steps`: the number of steps to be used in the diffusion process.
|
- `--n-steps`: the number of steps to be used in the diffusion process.
|
||||||
|
@ -11,6 +11,8 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
const GUIDANCE_SCALE: f64 = 7.5;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -61,8 +63,8 @@ struct Args {
|
|||||||
sliced_attention_size: Option<usize>,
|
sliced_attention_size: Option<usize>,
|
||||||
|
|
||||||
/// The number of steps to run the diffusion for.
|
/// The number of steps to run the diffusion for.
|
||||||
#[arg(long)]
|
#[arg(long, default_value_t = 30)]
|
||||||
n_steps: Option<usize>,
|
n_steps: usize,
|
||||||
|
|
||||||
/// The number of samples to generate.
|
/// The number of samples to generate.
|
||||||
#[arg(long, default_value_t = 1)]
|
#[arg(long, default_value_t = 1)]
|
||||||
@ -85,9 +87,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_f16: bool,
|
use_f16: bool,
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
guidance_scale: Option<f64>,
|
|
||||||
|
|
||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
img2img: Option<String>,
|
img2img: Option<String>,
|
||||||
|
|
||||||
@ -103,7 +102,6 @@ enum StableDiffusionVersion {
|
|||||||
V1_5,
|
V1_5,
|
||||||
V2_1,
|
V2_1,
|
||||||
Xl,
|
Xl,
|
||||||
Turbo,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
@ -122,13 +120,12 @@ impl StableDiffusionVersion {
|
|||||||
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
|
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
|
||||||
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||||
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||||
Self::Turbo => "stabilityai/sdxl-turbo",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unet_file(&self, use_f16: bool) -> &'static str {
|
fn unet_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"unet/diffusion_pytorch_model.fp16.safetensors"
|
"unet/diffusion_pytorch_model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -140,7 +137,7 @@ impl StableDiffusionVersion {
|
|||||||
|
|
||||||
fn vae_file(&self, use_f16: bool) -> &'static str {
|
fn vae_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"vae/diffusion_pytorch_model.fp16.safetensors"
|
"vae/diffusion_pytorch_model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -152,7 +149,7 @@ impl StableDiffusionVersion {
|
|||||||
|
|
||||||
fn clip_file(&self, use_f16: bool) -> &'static str {
|
fn clip_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"text_encoder/model.fp16.safetensors"
|
"text_encoder/model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -164,7 +161,7 @@ impl StableDiffusionVersion {
|
|||||||
|
|
||||||
fn clip2_file(&self, use_f16: bool) -> &'static str {
|
fn clip2_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"text_encoder_2/model.fp16.safetensors"
|
"text_encoder_2/model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -192,7 +189,7 @@ impl ModelFile {
|
|||||||
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
|
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
|
||||||
"openai/clip-vit-base-patch32"
|
"openai/clip-vit-base-patch32"
|
||||||
}
|
}
|
||||||
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
|
StableDiffusionVersion::Xl => {
|
||||||
// This seems similar to the patch32 version except some very small
|
// This seems similar to the patch32 version except some very small
|
||||||
// difference in the split regex.
|
// difference in the split regex.
|
||||||
"openai/clip-vit-large-patch14"
|
"openai/clip-vit-large-patch14"
|
||||||
@ -209,11 +206,7 @@ impl ModelFile {
|
|||||||
Self::Vae => {
|
Self::Vae => {
|
||||||
// Override for SDXL when using f16 weights.
|
// Override for SDXL when using f16 weights.
|
||||||
// See https://github.com/huggingface/candle/issues/1060
|
// See https://github.com/huggingface/candle/issues/1060
|
||||||
if matches!(
|
if version == StableDiffusionVersion::Xl && use_f16 {
|
||||||
version,
|
|
||||||
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
|
|
||||||
) && use_f16
|
|
||||||
{
|
|
||||||
(
|
(
|
||||||
"madebyollin/sdxl-vae-fp16-fix",
|
"madebyollin/sdxl-vae-fp16-fix",
|
||||||
"diffusion_pytorch_model.safetensors",
|
"diffusion_pytorch_model.safetensors",
|
||||||
@ -268,7 +261,6 @@ fn text_embeddings(
|
|||||||
use_f16: bool,
|
use_f16: bool,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
use_guide_scale: bool,
|
|
||||||
first: bool,
|
first: bool,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let tokenizer_file = if first {
|
let tokenizer_file = if first {
|
||||||
@ -293,6 +285,16 @@ fn text_embeddings(
|
|||||||
}
|
}
|
||||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
let mut uncond_tokens = tokenizer
|
||||||
|
.encode(uncond_prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||||
|
uncond_tokens.push(pad_id)
|
||||||
|
}
|
||||||
|
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
println!("Building the Clip transformer.");
|
println!("Building the Clip transformer.");
|
||||||
let clip_weights_file = if first {
|
let clip_weights_file = if first {
|
||||||
ModelFile::Clip
|
ModelFile::Clip
|
||||||
@ -308,24 +310,8 @@ fn text_embeddings(
|
|||||||
let text_model =
|
let text_model =
|
||||||
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
|
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
|
||||||
let text_embeddings = text_model.forward(&tokens)?;
|
let text_embeddings = text_model.forward(&tokens)?;
|
||||||
|
|
||||||
let text_embeddings = if use_guide_scale {
|
|
||||||
let mut uncond_tokens = tokenizer
|
|
||||||
.encode(uncond_prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
|
||||||
uncond_tokens.push(pad_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
|
||||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||||
|
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
|
||||||
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
|
|
||||||
} else {
|
|
||||||
text_embeddings.to_dtype(dtype)?
|
|
||||||
};
|
|
||||||
Ok(text_embeddings)
|
Ok(text_embeddings)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -370,7 +356,6 @@ fn run(args: Args) -> Result<()> {
|
|||||||
unet_weights,
|
unet_weights,
|
||||||
tracing,
|
tracing,
|
||||||
use_f16,
|
use_f16,
|
||||||
guidance_scale,
|
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
img2img,
|
img2img,
|
||||||
img2img_strength,
|
img2img_strength,
|
||||||
@ -389,24 +374,6 @@ fn run(args: Args) -> Result<()> {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let guidance_scale = match guidance_scale {
|
|
||||||
Some(guidance_scale) => guidance_scale,
|
|
||||||
None => match sd_version {
|
|
||||||
StableDiffusionVersion::V1_5
|
|
||||||
| StableDiffusionVersion::V2_1
|
|
||||||
| StableDiffusionVersion::Xl => 7.5,
|
|
||||||
StableDiffusionVersion::Turbo => 0.,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let n_steps = match n_steps {
|
|
||||||
Some(n_steps) => n_steps,
|
|
||||||
None => match sd_version {
|
|
||||||
StableDiffusionVersion::V1_5
|
|
||||||
| StableDiffusionVersion::V2_1
|
|
||||||
| StableDiffusionVersion::Xl => 30,
|
|
||||||
StableDiffusionVersion::Turbo => 1,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
|
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
|
||||||
let sd_config = match sd_version {
|
let sd_config = match sd_version {
|
||||||
StableDiffusionVersion::V1_5 => {
|
StableDiffusionVersion::V1_5 => {
|
||||||
@ -418,19 +385,13 @@ fn run(args: Args) -> Result<()> {
|
|||||||
StableDiffusionVersion::Xl => {
|
StableDiffusionVersion::Xl => {
|
||||||
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
|
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
|
||||||
}
|
}
|
||||||
StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
|
|
||||||
sliced_attention_size,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
let use_guide_scale = guidance_scale > 1.0;
|
|
||||||
|
|
||||||
let which = match sd_version {
|
let which = match sd_version {
|
||||||
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
|
StableDiffusionVersion::Xl => vec![true, false],
|
||||||
_ => vec![true],
|
_ => vec![true],
|
||||||
};
|
};
|
||||||
let text_embeddings = which
|
let text_embeddings = which
|
||||||
@ -446,18 +407,16 @@ fn run(args: Args) -> Result<()> {
|
|||||||
use_f16,
|
use_f16,
|
||||||
&device,
|
&device,
|
||||||
dtype,
|
dtype,
|
||||||
use_guide_scale,
|
|
||||||
*first,
|
*first,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
|
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
|
||||||
println!("{text_embeddings:?}");
|
println!("{text_embeddings:?}");
|
||||||
|
|
||||||
println!("Building the autoencoder.");
|
println!("Building the autoencoder.");
|
||||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
||||||
let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
|
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
|
||||||
let init_latent_dist = match &img2img {
|
let init_latent_dist = match &img2img {
|
||||||
None => None,
|
None => None,
|
||||||
Some(image) => {
|
Some(image) => {
|
||||||
@ -467,7 +426,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("Building the unet.");
|
println!("Building the unet.");
|
||||||
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
||||||
let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
|
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
|
||||||
|
|
||||||
let t_start = if img2img.is_some() {
|
let t_start = if img2img.is_some() {
|
||||||
n_steps - (n_steps as f64 * img2img_strength) as usize
|
n_steps - (n_steps as f64 * img2img_strength) as usize
|
||||||
@ -475,19 +434,11 @@ fn run(args: Args) -> Result<()> {
|
|||||||
0
|
0
|
||||||
};
|
};
|
||||||
let bsize = 1;
|
let bsize = 1;
|
||||||
|
|
||||||
let vae_scale = match sd_version {
|
|
||||||
StableDiffusionVersion::V1_5
|
|
||||||
| StableDiffusionVersion::V2_1
|
|
||||||
| StableDiffusionVersion::Xl => 0.18215,
|
|
||||||
StableDiffusionVersion::Turbo => 0.13025,
|
|
||||||
};
|
|
||||||
|
|
||||||
for idx in 0..num_samples {
|
for idx in 0..num_samples {
|
||||||
let timesteps = scheduler.timesteps();
|
let timesteps = scheduler.timesteps();
|
||||||
let latents = match &init_latent_dist {
|
let latents = match &init_latent_dist {
|
||||||
Some(init_latent_dist) => {
|
Some(init_latent_dist) => {
|
||||||
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
|
let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
|
||||||
if t_start < timesteps.len() {
|
if t_start < timesteps.len() {
|
||||||
let noise = latents.randn_like(0f64, 1f64)?;
|
let noise = latents.randn_like(0f64, 1f64)?;
|
||||||
scheduler.add_noise(&latents, noise, timesteps[t_start])?
|
scheduler.add_noise(&latents, noise, timesteps[t_start])?
|
||||||
@ -514,31 +465,21 @@ fn run(args: Args) -> Result<()> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let latent_model_input = if use_guide_scale {
|
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||||
Tensor::cat(&[&latents, &latents], 0)?
|
|
||||||
} else {
|
|
||||||
latents.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
|
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
|
||||||
let noise_pred =
|
let noise_pred =
|
||||||
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
|
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
|
||||||
|
|
||||||
let noise_pred = if use_guide_scale {
|
|
||||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||||
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
||||||
|
let noise_pred =
|
||||||
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
|
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
|
||||||
} else {
|
|
||||||
noise_pred
|
|
||||||
};
|
|
||||||
|
|
||||||
latents = scheduler.step(&noise_pred, timestep, &latents)?;
|
latents = scheduler.step(&noise_pred, timestep, &latents)?;
|
||||||
let dt = start_time.elapsed().as_secs_f32();
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
|
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
|
||||||
|
|
||||||
if args.intermediary_images {
|
if args.intermediary_images {
|
||||||
let image = vae.decode(&(&latents / vae_scale)?)?;
|
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||||
let image_filename =
|
let image_filename =
|
||||||
@ -552,7 +493,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
idx + 1,
|
idx + 1,
|
||||||
num_samples
|
num_samples
|
||||||
);
|
);
|
||||||
let image = vae.decode(&(&latents / vae_scale)?)?;
|
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||||
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||||
|
@ -234,14 +234,13 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let (model, device) = if args.quantized {
|
let (model, device) = if args.quantized {
|
||||||
let filename = &filenames[0];
|
let filename = &filenames[0];
|
||||||
let vb =
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
|
||||||
let model = QStableLM::new(&config, vb)?;
|
let model = QStableLM::new(&config, vb)?;
|
||||||
(Model::Quantized(model), Device::Cpu)
|
(Model::Quantized(model), Device::Cpu)
|
||||||
} else {
|
} else {
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let dtype = if device.is_cuda() {
|
let dtype = if device.is_cuda() {
|
||||||
DType::BF16
|
DType::BF16
|
||||||
} else {
|
} else {
|
||||||
|
@ -96,9 +96,25 @@ impl T5ModelBuilder {
|
|||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
let config_filename = api.get("config.json")?;
|
let config_filename = api.get("config.json")?;
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
|
let weights_filename = if model_id == "google/flan-t5-xxl" {
|
||||||
{
|
vec![
|
||||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
api.get("model-00001-of-00005.safetensors")?,
|
||||||
|
api.get("model-00002-of-00005.safetensors")?,
|
||||||
|
api.get("model-00003-of-00005.safetensors")?,
|
||||||
|
api.get("model-00004-of-00005.safetensors")?,
|
||||||
|
api.get("model-00005-of-00005.safetensors")?,
|
||||||
|
]
|
||||||
|
} else if model_id == "google/flan-ul2" {
|
||||||
|
vec![
|
||||||
|
api.get("model-00001-of-00008.safetensors")?,
|
||||||
|
api.get("model-00002-of-00008.safetensors")?,
|
||||||
|
api.get("model-00003-of-00008.safetensors")?,
|
||||||
|
api.get("model-00004-of-00008.safetensors")?,
|
||||||
|
api.get("model-00005-of-00008.safetensors")?,
|
||||||
|
api.get("model-00006-of-00008.safetensors")?,
|
||||||
|
api.get("model-00007-of-00008.safetensors")?,
|
||||||
|
api.get("model-00008-of-00008.safetensors")?,
|
||||||
|
]
|
||||||
} else {
|
} else {
|
||||||
vec![api.get("model.safetensors")?]
|
vec![api.get("model.safetensors")?]
|
||||||
};
|
};
|
||||||
|
@ -8,7 +8,7 @@ the model itself.
|
|||||||
## Running an example
|
## Running an example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
|
cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -128,13 +128,7 @@ impl Decoder {
|
|||||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||||
let no_speech_token = m::NO_SPEECH_TOKENS
|
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
||||||
.iter()
|
|
||||||
.find_map(|token| token_id(&tokenizer, token).ok());
|
|
||||||
let no_speech_token = match no_speech_token {
|
|
||||||
None => anyhow::bail!("unable to find any non-speech token"),
|
|
||||||
Some(n) => n,
|
|
||||||
};
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
@ -518,7 +512,11 @@ fn main() -> Result<()> {
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let config = repo.get("config.json")?;
|
let config = repo.get("config.json")?;
|
||||||
let tokenizer = repo.get("tokenizer.json")?;
|
let tokenizer = if args.model == WhichModel::LargeV3 {
|
||||||
|
panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
|
||||||
|
} else {
|
||||||
|
repo.get("tokenizer.json")?
|
||||||
|
};
|
||||||
let model = repo.get("model.safetensors")?;
|
let model = repo.get("model.safetensors")?;
|
||||||
(config, tokenizer, model)
|
(config, tokenizer, model)
|
||||||
};
|
};
|
||||||
@ -557,10 +555,8 @@ fn main() -> Result<()> {
|
|||||||
println!("loaded mel: {:?}", mel.dims());
|
println!("loaded mel: {:?}", mel.dims());
|
||||||
|
|
||||||
let mut model = if args.quantized {
|
let mut model = if args.quantized {
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
let vb =
|
||||||
&weights_filename,
|
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||||
&device,
|
|
||||||
)?;
|
|
||||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||||
} else {
|
} else {
|
||||||
let vb =
|
let vb =
|
||||||
|
@ -74,9 +74,9 @@ impl TextGeneration {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
let mut generated_tokens = 0usize;
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
let eos_token = match self.tokenizer.get_token("</s>") {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
None => anyhow::bail!("cannot find the </s> token"),
|
||||||
};
|
};
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
@ -218,7 +218,21 @@ fn main() -> Result<()> {
|
|||||||
.split(',')
|
.split(',')
|
||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
None => match args.which {
|
||||||
|
Which::L6b => vec![
|
||||||
|
repo.get("model-00001-of-00002.safetensors")?,
|
||||||
|
repo.get("model-00002-of-00002.safetensors")?,
|
||||||
|
],
|
||||||
|
Which::L34b => vec![
|
||||||
|
repo.get("model-00001-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00002-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00003-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00004-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00005-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00006-of-00007.safetensors")?,
|
||||||
|
repo.get("model-00007-of-00007.safetensors")?,
|
||||||
|
],
|
||||||
|
},
|
||||||
};
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
|||||||
let func = candle_nn::func(move |xs| {
|
let func = candle_nn::func(move |xs| {
|
||||||
let xs = conv.forward(xs)?;
|
let xs = conv.forward(xs)?;
|
||||||
let xs = match &bn {
|
let xs = match &bn {
|
||||||
Some(bn) => xs.apply_t(bn, false)?,
|
Some(bn) => bn.forward(&xs)?,
|
||||||
None => xs,
|
None => xs,
|
||||||
};
|
};
|
||||||
let xs = if leaky {
|
let xs = if leaky {
|
||||||
|
@ -43,7 +43,6 @@ pub fn report(
|
|||||||
confidence_threshold: f32,
|
confidence_threshold: f32,
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let pred = pred.to_device(&Device::Cpu)?;
|
|
||||||
let (npreds, pred_size) = pred.dims2()?;
|
let (npreds, pred_size) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 5;
|
let nclasses = pred_size - 5;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
|
@ -32,7 +32,7 @@ Image source:
|
|||||||
### Pose Estimation
|
### Pose Estimation
|
||||||
```bash
|
```bash
|
||||||
cargo run --example yolo-v8 --release -- \
|
cargo run --example yolo-v8 --release -- \
|
||||||
candle-examples/examples/yolo-v8/assets/bike.jpg --task pose
|
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
|
||||||
```
|
```
|
||||||
|
|
||||||

|

|
||||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
|||||||
mod model;
|
mod model;
|
||||||
use model::{Multiples, YoloV8, YoloV8Pose};
|
use model::{Multiples, YoloV8, YoloV8Pose};
|
||||||
|
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
use candle::{DType, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{Module, VarBuilder};
|
use candle_nn::{Module, VarBuilder};
|
||||||
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
@ -61,7 +61,6 @@ pub fn report_detect(
|
|||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
legend_size: u32,
|
legend_size: u32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let pred = pred.to_device(&Device::Cpu)?;
|
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
let nclasses = pred_size - 4;
|
let nclasses = pred_size - 4;
|
||||||
// The bounding boxes grouped by (maximum) class index.
|
// The bounding boxes grouped by (maximum) class index.
|
||||||
@ -154,7 +153,6 @@ pub fn report_pose(
|
|||||||
confidence_threshold: f32,
|
confidence_threshold: f32,
|
||||||
nms_threshold: f32,
|
nms_threshold: f32,
|
||||||
) -> Result<DynamicImage> {
|
) -> Result<DynamicImage> {
|
||||||
let pred = pred.to_device(&Device::Cpu)?;
|
|
||||||
let (pred_size, npreds) = pred.dims2()?;
|
let (pred_size, npreds) = pred.dims2()?;
|
||||||
if pred_size != 17 * 3 + 4 + 1 {
|
if pred_size != 17 * 3 + 4 + 1 {
|
||||||
candle::bail!("unexpected pred-size {pred_size}");
|
candle::bail!("unexpected pred-size {pred_size}");
|
||||||
|
@ -117,30 +117,3 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
|
|||||||
image.save(p).map_err(candle::Error::wrap)?;
|
image.save(p).map_err(candle::Error::wrap)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Loads the safetensors files for a model from the hub based on a json index file.
|
|
||||||
pub fn hub_load_safetensors(
|
|
||||||
repo: &hf_hub::api::sync::ApiRepo,
|
|
||||||
json_file: &str,
|
|
||||||
) -> Result<Vec<std::path::PathBuf>> {
|
|
||||||
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
|
|
||||||
let json_file = std::fs::File::open(json_file)?;
|
|
||||||
let json: serde_json::Value =
|
|
||||||
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
|
|
||||||
let weight_map = match json.get("weight_map") {
|
|
||||||
None => candle::bail!("no weight map in {json_file:?}"),
|
|
||||||
Some(serde_json::Value::Object(map)) => map,
|
|
||||||
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
|
|
||||||
};
|
|
||||||
let mut safetensors_files = std::collections::HashSet::new();
|
|
||||||
for value in weight_map.values() {
|
|
||||||
if let Some(file) = value.as_str() {
|
|
||||||
safetensors_files.insert(file.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let safetensors_files = safetensors_files
|
|
||||||
.iter()
|
|
||||||
.map(|v| repo.get(v).map_err(candle::Error::wrap))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
Ok(safetensors_files)
|
|
||||||
}
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.3.3"
|
version = "0.3.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
|
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.0", package = "candle-core" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
bindgen_cuda = "0.1.1"
|
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
|
num_cpus = "1.15.0"
|
||||||
|
rayon = "1.7.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
candle-nn = { path = "../candle-nn", version = "0.3.0", features = ["cuda"] }
|
||||||
|
@ -2,32 +2,44 @@
|
|||||||
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||||
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
use rayon::prelude::*;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
const KERNEL_FILES: [&str; 17] = [
|
const KERNEL_FILES: [&str; 17] = [
|
||||||
"kernels/flash_api.cu",
|
"flash_api.cu",
|
||||||
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
|
"flash_fwd_hdim128_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
|
"flash_fwd_hdim160_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
|
"flash_fwd_hdim192_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
|
"flash_fwd_hdim224_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
|
"flash_fwd_hdim256_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
|
"flash_fwd_hdim32_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
|
"flash_fwd_hdim64_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
|
"flash_fwd_hdim96_fp16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
|
"flash_fwd_hdim128_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
|
"flash_fwd_hdim160_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
|
"flash_fwd_hdim192_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
|
"flash_fwd_hdim224_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
|
"flash_fwd_hdim256_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
|
"flash_fwd_hdim32_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
|
"flash_fwd_hdim64_bf16_sm80.cu",
|
||||||
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
|
"flash_fwd_hdim96_bf16_sm80.cu",
|
||||||
];
|
];
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
|
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|
||||||
|
|_| num_cpus::get_physical(),
|
||||||
|
|s| usize::from_str(&s).unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
rayon::ThreadPoolBuilder::new()
|
||||||
|
.num_threads(num_cpus)
|
||||||
|
.build_global()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=build.rs");
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
for kernel_file in KERNEL_FILES.iter() {
|
for kernel_file in KERNEL_FILES.iter() {
|
||||||
println!("cargo:rerun-if-changed={kernel_file}");
|
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
|
||||||
}
|
}
|
||||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
||||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
||||||
@ -54,30 +66,223 @@ fn main() -> Result<()> {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
set_cuda_include_dir()?;
|
||||||
|
|
||||||
let kernels = KERNEL_FILES.iter().collect();
|
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||||
let builder = bindgen_cuda::Builder::default()
|
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||||
.kernel_paths(kernels)
|
|
||||||
.out_dir(build_dir.clone())
|
let compute_cap = compute_cap()?;
|
||||||
|
|
||||||
|
let out_file = build_dir.join("libflashattention.a");
|
||||||
|
|
||||||
|
let kernel_dir = PathBuf::from("kernels");
|
||||||
|
let cu_files: Vec<_> = KERNEL_FILES
|
||||||
|
.iter()
|
||||||
|
.map(|f| {
|
||||||
|
let mut obj_file = out_dir.join(f);
|
||||||
|
obj_file.set_extension("o");
|
||||||
|
(kernel_dir.join(f), obj_file)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
|
||||||
|
let should_compile = if out_file.exists() {
|
||||||
|
kernel_dir
|
||||||
|
.read_dir()
|
||||||
|
.expect("kernels folder should exist")
|
||||||
|
.any(|entry| {
|
||||||
|
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
|
||||||
|
let in_modified = entry.metadata().unwrap().modified().unwrap();
|
||||||
|
in_modified.duration_since(*out_modified).is_ok()
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
};
|
||||||
|
if should_compile {
|
||||||
|
cu_files
|
||||||
|
.par_iter()
|
||||||
|
.map(|(cu_file, obj_file)| {
|
||||||
|
let mut command = std::process::Command::new("nvcc");
|
||||||
|
command
|
||||||
.arg("-std=c++17")
|
.arg("-std=c++17")
|
||||||
.arg("-O3")
|
.arg("-O3")
|
||||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||||
|
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||||
|
.arg("-c")
|
||||||
|
.args(["-o", obj_file.to_str().unwrap()])
|
||||||
|
.args(["--default-stream", "per-thread"])
|
||||||
.arg("-Icutlass/include")
|
.arg("-Icutlass/include")
|
||||||
.arg("--expt-relaxed-constexpr")
|
.arg("--expt-relaxed-constexpr")
|
||||||
.arg("--expt-extended-lambda")
|
.arg("--expt-extended-lambda")
|
||||||
.arg("--use_fast_math")
|
.arg("--use_fast_math")
|
||||||
.arg("--verbose");
|
.arg("--verbose");
|
||||||
|
if let Ok(ccbin_path) = &ccbin_env {
|
||||||
let out_file = build_dir.join("libflashattention.a");
|
command
|
||||||
builder.build_lib(out_file);
|
.arg("-allow-unsupported-compiler")
|
||||||
|
.args(["-ccbin", ccbin_path]);
|
||||||
|
}
|
||||||
|
command.arg(cu_file);
|
||||||
|
let output = command
|
||||||
|
.spawn()
|
||||||
|
.context("failed spawning nvcc")?
|
||||||
|
.wait_with_output()?;
|
||||||
|
if !output.status.success() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||||
|
&command,
|
||||||
|
String::from_utf8_lossy(&output.stdout),
|
||||||
|
String::from_utf8_lossy(&output.stderr)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.collect::<Result<()>>()?;
|
||||||
|
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
|
||||||
|
let mut command = std::process::Command::new("nvcc");
|
||||||
|
command
|
||||||
|
.arg("--lib")
|
||||||
|
.args(["-o", out_file.to_str().unwrap()])
|
||||||
|
.args(obj_files);
|
||||||
|
let output = command
|
||||||
|
.spawn()
|
||||||
|
.context("failed spawning nvcc")?
|
||||||
|
.wait_with_output()?;
|
||||||
|
if !output.status.success() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||||
|
&command,
|
||||||
|
String::from_utf8_lossy(&output.stdout),
|
||||||
|
String::from_utf8_lossy(&output.stderr)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
println!("cargo:rustc-link-lib=flashattention");
|
println!("cargo:rustc-link-lib=flashattention");
|
||||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
|
||||||
|
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
|
||||||
|
finishing to run for some reason. Calling nvcc manually worked fine.
|
||||||
|
cc::Build::new()
|
||||||
|
.cuda(true)
|
||||||
|
.include("cutlass/include")
|
||||||
|
.flag("--expt-relaxed-constexpr")
|
||||||
|
.flag("--default-stream")
|
||||||
|
.flag("per-thread")
|
||||||
|
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
|
||||||
|
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
|
||||||
|
.compile("flashattn");
|
||||||
|
*/
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn set_cuda_include_dir() -> Result<()> {
|
||||||
|
// NOTE: copied from cudarc build.rs.
|
||||||
|
let env_vars = [
|
||||||
|
"CUDA_PATH",
|
||||||
|
"CUDA_ROOT",
|
||||||
|
"CUDA_TOOLKIT_ROOT_DIR",
|
||||||
|
"CUDNN_LIB",
|
||||||
|
];
|
||||||
|
let env_vars = env_vars
|
||||||
|
.into_iter()
|
||||||
|
.map(std::env::var)
|
||||||
|
.filter_map(Result::ok)
|
||||||
|
.map(Into::<PathBuf>::into);
|
||||||
|
|
||||||
|
let roots = [
|
||||||
|
"/usr",
|
||||||
|
"/usr/local/cuda",
|
||||||
|
"/opt/cuda",
|
||||||
|
"/usr/lib/cuda",
|
||||||
|
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||||
|
"C:/CUDA",
|
||||||
|
];
|
||||||
|
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||||
|
let root = env_vars
|
||||||
|
.chain(roots)
|
||||||
|
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||||
|
.context("cannot find include/cuda.h")?;
|
||||||
|
println!(
|
||||||
|
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||||
|
root.join("include").display()
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
fn compute_cap() -> Result<usize> {
|
||||||
|
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||||
|
|
||||||
|
// Try to parse compute caps from env
|
||||||
|
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||||
|
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||||
|
compute_cap_str
|
||||||
|
.parse::<usize>()
|
||||||
|
.context("Could not parse compute cap")?
|
||||||
|
} else {
|
||||||
|
// Use nvidia-smi to get the current compute cap
|
||||||
|
let out = std::process::Command::new("nvidia-smi")
|
||||||
|
.arg("--query-gpu=compute_cap")
|
||||||
|
.arg("--format=csv")
|
||||||
|
.output()
|
||||||
|
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||||
|
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||||
|
let mut lines = out.lines();
|
||||||
|
assert_eq!(
|
||||||
|
lines.next().context("missing line in stdout")?,
|
||||||
|
"compute_cap"
|
||||||
|
);
|
||||||
|
let cap = lines
|
||||||
|
.next()
|
||||||
|
.context("missing line in stdout")?
|
||||||
|
.replace('.', "");
|
||||||
|
let cap = cap
|
||||||
|
.parse::<usize>()
|
||||||
|
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||||
|
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||||
|
cap
|
||||||
|
};
|
||||||
|
|
||||||
|
// Grab available GPU codes from nvcc and select the highest one
|
||||||
|
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||||
|
let out = std::process::Command::new("nvcc")
|
||||||
|
.arg("--list-gpu-code")
|
||||||
|
.output()
|
||||||
|
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||||
|
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||||
|
|
||||||
|
let out = out.lines().collect::<Vec<&str>>();
|
||||||
|
let mut codes = Vec::with_capacity(out.len());
|
||||||
|
for code in out {
|
||||||
|
let code = code.split('_').collect::<Vec<&str>>();
|
||||||
|
if !code.is_empty() && code.contains(&"sm") {
|
||||||
|
if let Ok(num) = code[1].parse::<usize>() {
|
||||||
|
codes.push(num);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
codes.sort();
|
||||||
|
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||||
|
(codes, max_nvcc_code)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check that nvcc supports the asked compute caps
|
||||||
|
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if compute_cap > max_nvcc_code {
|
||||||
|
anyhow::bail!(
|
||||||
|
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(compute_cap)
|
||||||
|
}
|
||||||
|
@ -1,62 +0,0 @@
|
|||||||
#include <cmath>
|
|
||||||
|
|
||||||
#include <cute/tensor.hpp>
|
|
||||||
|
|
||||||
#include <cutlass/cutlass.h>
|
|
||||||
#include <cutlass/array.h>
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
namespace flash {
|
|
||||||
|
|
||||||
using namespace cute;
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <bool Is_causal, typename Engine, typename Layout>
|
|
||||||
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
|
|
||||||
const int col_idx_offset_,
|
|
||||||
const int max_seqlen_k,
|
|
||||||
const int row_idx_offset,
|
|
||||||
const int max_seqlen_q,
|
|
||||||
const int warp_row_stride,
|
|
||||||
const float alibi_slope) {
|
|
||||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
|
||||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
|
||||||
const int lane_id = threadIdx.x % 32;
|
|
||||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
|
||||||
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
|
|
||||||
#pragma unroll
|
|
||||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
|
||||||
const int col_idx_base = col_idx_offset + nj * 8;
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
|
||||||
const int col_idx = col_idx_base + j;
|
|
||||||
#pragma unroll
|
|
||||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
|
||||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else { // Bias depends on both row_idx and col_idx
|
|
||||||
#pragma unroll
|
|
||||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
|
||||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
|
||||||
const int row_idx = row_idx_base + i * 8;
|
|
||||||
#pragma unroll
|
|
||||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
|
||||||
const int col_idx_base = col_idx_offset + nj * 8;
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
|
||||||
const int col_idx = col_idx_base + j;
|
|
||||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace flash
|
|
@ -14,12 +14,9 @@ struct BlockInfo {
|
|||||||
template<typename Params>
|
template<typename Params>
|
||||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
|
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
|
||||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
|
||||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
|
||||||
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
|
|
||||||
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
|
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,10 +32,8 @@ struct BlockInfo {
|
|||||||
|
|
||||||
const int sum_s_q;
|
const int sum_s_q;
|
||||||
const int sum_s_k;
|
const int sum_s_k;
|
||||||
const int actual_seqlen_q;
|
const uint32_t actual_seqlen_q;
|
||||||
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
|
const uint32_t actual_seqlen_k;
|
||||||
const int seqlen_k_cache;
|
|
||||||
const int actual_seqlen_k;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -7,6 +7,15 @@
|
|||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// #ifdef OLD_GENERATOR_PATH
|
||||||
|
// #include <ATen/CUDAGeneratorImpl.h>
|
||||||
|
// #else
|
||||||
|
// #include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||||
|
// #endif
|
||||||
|
//
|
||||||
|
// #include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||||
|
|
||||||
|
|
||||||
constexpr int TOTAL_DIM = 0;
|
constexpr int TOTAL_DIM = 0;
|
||||||
constexpr int H_DIM = 1;
|
constexpr int H_DIM = 1;
|
||||||
constexpr int D_DIM = 2;
|
constexpr int D_DIM = 2;
|
||||||
@ -44,7 +53,6 @@ struct Flash_fwd_params : public Qkv_params {
|
|||||||
|
|
||||||
// The O matrix (output).
|
// The O matrix (output).
|
||||||
void * __restrict__ o_ptr;
|
void * __restrict__ o_ptr;
|
||||||
void * __restrict__ oaccum_ptr;
|
|
||||||
|
|
||||||
// The stride between rows of O.
|
// The stride between rows of O.
|
||||||
index_t o_batch_stride;
|
index_t o_batch_stride;
|
||||||
@ -56,10 +64,9 @@ struct Flash_fwd_params : public Qkv_params {
|
|||||||
|
|
||||||
// The pointer to the softmax sum.
|
// The pointer to the softmax sum.
|
||||||
void * __restrict__ softmax_lse_ptr;
|
void * __restrict__ softmax_lse_ptr;
|
||||||
void * __restrict__ softmax_lseaccum_ptr;
|
|
||||||
|
|
||||||
// The dimensions.
|
// The dimensions.
|
||||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||||
|
|
||||||
// The scaling factors for the kernel.
|
// The scaling factors for the kernel.
|
||||||
float scale_softmax;
|
float scale_softmax;
|
||||||
@ -69,30 +76,8 @@ struct Flash_fwd_params : public Qkv_params {
|
|||||||
int * __restrict__ cu_seqlens_q;
|
int * __restrict__ cu_seqlens_q;
|
||||||
int * __restrict__ cu_seqlens_k;
|
int * __restrict__ cu_seqlens_k;
|
||||||
|
|
||||||
// If provided, the actual length of each k sequence.
|
|
||||||
int * __restrict__ seqused_k;
|
|
||||||
|
|
||||||
int *__restrict__ blockmask;
|
int *__restrict__ blockmask;
|
||||||
|
|
||||||
// The K_new and V_new matrices.
|
|
||||||
void * __restrict__ knew_ptr;
|
|
||||||
void * __restrict__ vnew_ptr;
|
|
||||||
|
|
||||||
// The stride between rows of the Q, K and V matrices.
|
|
||||||
index_t knew_batch_stride;
|
|
||||||
index_t vnew_batch_stride;
|
|
||||||
index_t knew_row_stride;
|
|
||||||
index_t vnew_row_stride;
|
|
||||||
index_t knew_head_stride;
|
|
||||||
index_t vnew_head_stride;
|
|
||||||
|
|
||||||
// The cos and sin matrices for rotary embedding.
|
|
||||||
void * __restrict__ rotary_cos_ptr;
|
|
||||||
void * __restrict__ rotary_sin_ptr;
|
|
||||||
|
|
||||||
// The indices to index into the KV cache.
|
|
||||||
int *__restrict__ cache_batch_idx;
|
|
||||||
|
|
||||||
// The dropout probability (probability of keeping an activation).
|
// The dropout probability (probability of keeping an activation).
|
||||||
float p_dropout;
|
float p_dropout;
|
||||||
// uint32_t p_dropout_in_uint;
|
// uint32_t p_dropout_in_uint;
|
||||||
@ -103,22 +88,11 @@ struct Flash_fwd_params : public Qkv_params {
|
|||||||
float rp_dropout;
|
float rp_dropout;
|
||||||
float scale_softmax_rp_dropout;
|
float scale_softmax_rp_dropout;
|
||||||
|
|
||||||
// Local window size
|
// Random state.
|
||||||
int window_size_left, window_size_right;
|
// at::PhiloxCudaState philox_args;
|
||||||
|
|
||||||
bool is_bf16;
|
bool is_bf16;
|
||||||
bool is_causal;
|
bool is_causal;
|
||||||
|
|
||||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
|
||||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
|
||||||
bool is_seqlens_k_cumulative;
|
|
||||||
|
|
||||||
bool is_rotary_interleaved;
|
|
||||||
|
|
||||||
int num_splits; // For split-KV version
|
|
||||||
|
|
||||||
void * __restrict__ alibi_slopes_ptr;
|
|
||||||
index_t alibi_slopes_batch_stride;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -158,14 +132,10 @@ struct Flash_bwd_params : public Flash_fwd_params {
|
|||||||
|
|
||||||
// The pointer to the softmax d sum.
|
// The pointer to the softmax d sum.
|
||||||
void *__restrict__ dsoftmax_sum;
|
void *__restrict__ dsoftmax_sum;
|
||||||
|
|
||||||
bool deterministic;
|
|
||||||
index_t dq_accum_split_stride;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||||
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream);
|
|
||||||
|
|
||||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure);
|
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure);
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) {
|
// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||||
|
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
|
||||||
|
// });
|
||||||
|
// }
|
||||||
|
|
||||||
|
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
FP16_SWITCH(!params.is_bf16, [&] {
|
FP16_SWITCH(!params.is_bf16, [&] {
|
||||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||||
// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
|
|
||||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||||
// } else {
|
|
||||||
// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
|
|
||||||
// }
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -18,7 +20,6 @@ extern "C" void run_mha(
|
|||||||
void *v_ptr,
|
void *v_ptr,
|
||||||
void *o_ptr,
|
void *o_ptr,
|
||||||
void *softmax_lse_ptr,
|
void *softmax_lse_ptr,
|
||||||
void *alibi_slopes_ptr,
|
|
||||||
|
|
||||||
int32_t *cu_seqlens_q_ptr,
|
int32_t *cu_seqlens_q_ptr,
|
||||||
int32_t *cu_seqlens_k_ptr,
|
int32_t *cu_seqlens_k_ptr,
|
||||||
@ -27,7 +28,6 @@ extern "C" void run_mha(
|
|||||||
uint32_t k_batch_stride,
|
uint32_t k_batch_stride,
|
||||||
uint32_t v_batch_stride,
|
uint32_t v_batch_stride,
|
||||||
uint32_t o_batch_stride,
|
uint32_t o_batch_stride,
|
||||||
uint32_t alibi_slopes_batch_stride,
|
|
||||||
|
|
||||||
uint32_t q_row_stride,
|
uint32_t q_row_stride,
|
||||||
uint32_t k_row_stride,
|
uint32_t k_row_stride,
|
||||||
@ -51,11 +51,8 @@ extern "C" void run_mha(
|
|||||||
uint32_t seqlen_q_rounded,
|
uint32_t seqlen_q_rounded,
|
||||||
uint32_t seqlen_k_rounded,
|
uint32_t seqlen_k_rounded,
|
||||||
|
|
||||||
int is_bf16,
|
|
||||||
int is_causal,
|
int is_causal,
|
||||||
|
int is_bf16
|
||||||
int window_size_left,
|
|
||||||
int window_size_right
|
|
||||||
) {
|
) {
|
||||||
Flash_fwd_params params;
|
Flash_fwd_params params;
|
||||||
// Reset the parameters
|
// Reset the parameters
|
||||||
@ -68,14 +65,12 @@ extern "C" void run_mha(
|
|||||||
params.o_ptr = o_ptr;
|
params.o_ptr = o_ptr;
|
||||||
|
|
||||||
params.softmax_lse_ptr = softmax_lse_ptr;
|
params.softmax_lse_ptr = softmax_lse_ptr;
|
||||||
params.alibi_slopes_ptr = alibi_slopes_ptr;
|
|
||||||
|
|
||||||
// All stride are in elements, not bytes.
|
// All stride are in elements, not bytes.
|
||||||
params.q_batch_stride = q_batch_stride;
|
params.q_batch_stride = q_batch_stride;
|
||||||
params.k_batch_stride = k_batch_stride;
|
params.k_batch_stride = k_batch_stride;
|
||||||
params.v_batch_stride = v_batch_stride;
|
params.v_batch_stride = v_batch_stride;
|
||||||
params.o_batch_stride = o_batch_stride;
|
params.o_batch_stride = o_batch_stride;
|
||||||
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
|
|
||||||
|
|
||||||
params.q_row_stride = q_row_stride;
|
params.q_row_stride = q_row_stride;
|
||||||
params.k_row_stride = k_row_stride;
|
params.k_row_stride = k_row_stride;
|
||||||
@ -97,6 +92,7 @@ extern "C" void run_mha(
|
|||||||
params.seqlen_k_rounded = seqlen_k_rounded;
|
params.seqlen_k_rounded = seqlen_k_rounded;
|
||||||
params.d = d;
|
params.d = d;
|
||||||
params.d_rounded = d_rounded;
|
params.d_rounded = d_rounded;
|
||||||
|
params.is_causal = is_causal;
|
||||||
|
|
||||||
// Set the different scale values.
|
// Set the different scale values.
|
||||||
params.scale_softmax = softmax_scale;
|
params.scale_softmax = softmax_scale;
|
||||||
@ -110,14 +106,6 @@ extern "C" void run_mha(
|
|||||||
params.cu_seqlens_q = cu_seqlens_q_ptr;
|
params.cu_seqlens_q = cu_seqlens_q_ptr;
|
||||||
params.cu_seqlens_k = cu_seqlens_k_ptr;
|
params.cu_seqlens_k = cu_seqlens_k_ptr;
|
||||||
params.p_ptr = nullptr; // used for `return_softmax`.
|
params.p_ptr = nullptr; // used for `return_softmax`.
|
||||||
params.seqused_k = nullptr;
|
|
||||||
|
|
||||||
params.is_causal = is_causal;
|
|
||||||
params.window_size_left = window_size_left;
|
|
||||||
params.window_size_right = window_size_right;
|
|
||||||
|
|
||||||
params.is_seqlens_k_cumulative = true;
|
|
||||||
params.num_splits = 1;
|
|
||||||
|
|
||||||
cudaStream_t stream = 0; // Use the default stream.
|
cudaStream_t stream = 0; // Use the default stream.
|
||||||
run_mha_fwd(params, stream);
|
run_mha_fwd(params, stream);
|
||||||
|
@ -1,9 +1,18 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
// template<>
|
||||||
|
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::bfloat16_t;
|
||||||
|
// if (params.p_dropout == 1.f) {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||||
|
// } else {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
template<>
|
template<>
|
||||||
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||||
|
@ -1,9 +1,31 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
// template<>
|
||||||
|
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::half_t;
|
||||||
|
// if (params.p_dropout == 1.f) {
|
||||||
|
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
|
||||||
|
// // 1st ones are good for H100, A100
|
||||||
|
// // 2nd one is good for A6000 bc we get slightly better occupancy
|
||||||
|
// } else {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
|
||||||
|
// // 1st one is good for H100, A100, A6000
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
|
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
|
||||||
|
@ -1,9 +1,16 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
// template<>
|
||||||
|
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::bfloat16_t;
|
||||||
|
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// });
|
||||||
|
// }
|
||||||
template<>
|
template<>
|
||||||
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||||
|
@ -1,9 +1,26 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
// template<>
|
||||||
|
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::half_t;
|
||||||
|
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
|
||||||
|
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
|
||||||
|
// // For A100, H100, 1st is fastest.
|
||||||
|
// });
|
||||||
|
// }
|
||||||
template<>
|
template<>
|
||||||
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
|
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
|
||||||
|
@ -1,10 +1,16 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
template<>
|
// template<>
|
||||||
void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::bfloat16_t;
|
||||||
|
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// });
|
||||||
|
// }
|
||||||
|
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,26 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
// template<>
|
||||||
|
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::half_t;
|
||||||
|
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// // This one is slightly faster for causal?
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
|
||||||
|
// });
|
||||||
|
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
|
||||||
|
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
|
||||||
|
// }
|
||||||
template<>
|
template<>
|
||||||
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
|
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
template<>
|
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
||||||
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
template<>
|
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
||||||
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
|
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
template<>
|
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
||||||
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
template<>
|
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
||||||
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
|
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
@ -1,9 +1,22 @@
|
|||||||
// Copyright (c) 2023, Tri Dao.
|
// Copyright (c) 2023, Tri Dao.
|
||||||
|
|
||||||
// Splitting the different head dimensions to different files to speed up compilation.
|
// Splitting the different head dimensions to different files to speed up compilation.
|
||||||
// This file is auto-generated. See "generate_kernels.py"
|
|
||||||
|
|
||||||
#include "flash_fwd_launch_template.h"
|
#include "flash_fwd_launch_template.h"
|
||||||
|
|
||||||
|
// template<>
|
||||||
|
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
|
// using elem_type = cutlass::half_t;
|
||||||
|
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||||
|
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||||
|
// // For dropout there might be a lot of register spilling?
|
||||||
|
// // These two are very slow due to register spilling
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
|
||||||
|
// // This one is slightly slower
|
||||||
|
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
|
||||||
|
// });
|
||||||
|
// }
|
||||||
template<>
|
template<>
|
||||||
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||||
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
|
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user