Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00

Compare commits: bf16_metal ... meshgrid-f (1 commit)

Commit: 20da4f44ef

7  .github/dependabot.yml  (vendored)
@ -1,7 +0,0 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5
72  .github/workflows/ci_cuda.yaml  (vendored)

@ -5,15 +5,47 @@ on:
pull_request:

jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]

test-cuda:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
needs: start-runner # required to start the main job when the runner is ready
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
permissions:
contents: write
packages: write
@ -24,10 +56,32 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda)
run: cargo test --features cuda
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
BIN  .github/workflows/maturin.yml  (vendored)
Binary file not shown.
8  .github/workflows/python.yml  (vendored)

@ -39,12 +39,6 @@ jobs:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}

- name: Install
working-directory: ./candle-pyo3
run: |
@ -52,7 +46,7 @@ jobs:
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
python -m maturin develop -r

- name: Check style
working-directory: ./candle-pyo3
@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
[760](https://github.com/huggingface/candle/pull/760).
- Add the Segment-Anything Model (SAM) as an example
[773](https://github.com/huggingface/candle/pull/773).
- TinyViT backbone for the segment anything example
- TinyViT backbone for the segemnt anything example
[787](https://github.com/huggingface/candle/pull/787).
- Shape with holes support
[770](https://github.com/huggingface/candle/pull/770).
40  Cargo.toml

@ -7,19 +7,20 @@ members = [
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-tests",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"

[workspace.package]
version = "0.4.0"
version = "0.3.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -31,18 +32,9 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
candle-nn = { path = "./candle-nn", version = "0.4.0" }
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -50,27 +42,25 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "50.0.0" }
parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false }
tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

[profile.release-with-debug]
inherits = "release"
74  README.md

@ -51,33 +51,23 @@ For more advanced examples, please have a look at the following section.
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

We also provide a some command line based examples using state of the art models:

- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model.
pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
@ -85,7 +75,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">

- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.

<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

@ -105,19 +95,12 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.

Run them using commands like:
```
@ -133,7 +116,7 @@ There are also some wasm examples for whisper and
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

For LLaMA2, run the following command to retrieve the weight files and start a
@ -150,20 +133,10 @@ And then head over to
<!--- ANCHOR: useful_libraries --->

## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
ergonomic LoRA implementation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.

If you have an addition to this list, please submit a pull request.

@ -182,35 +155,23 @@ If you have an addition to this list, please submit a pull request.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- LLaMA v1 and v2.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba
- Phi v1.5.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- T5.
- Bert.
- Yi-6B and Yi-34B.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- TrOCR.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
@ -249,7 +210,6 @@ Cheatsheet:
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.

## FAQ
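For readers skimming the README hunks above, the cheatsheet they point at boils down to a couple of tensor calls. A minimal sketch, assuming the `candle_core` crate as used elsewhere in this diff (this is not a quote from the README itself):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two random matrices; matmul also batches over any leading dimensions.
    let a = Tensor::randn(0f32, 1.0, (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1.0, (3, 4), &device)?;
    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}
```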
@ -11,11 +11,11 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -28,7 +28,6 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
#[rustfmt::skip]
#[test]
fn book_hub_2() {
{
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
@ -46,10 +45,9 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
assert_eq!(weights.len(), 206);
}

// #[rustfmt::skip]
// #[test]
// fn book_hub_3() {
{
#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
@ -104,7 +102,6 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}

#[rustfmt::skip]
#[test]
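The `book_hub_2`/`book_hub_3` hunks above exercise loading safetensors weights from the Hugging Face Hub through a memory map. A condensed sketch of that pattern, assuming the `hf-hub`, `memmap2` and `anyhow` dependencies from the workspace; the model id is illustrative:

```rust
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.model("bert-base-uncased".to_string());
    let weights_filename = repo.get("model.safetensors")?;
    // Memory-map the file and parse the safetensors buffer without copying it.
    let file = std::fs::File::open(&weights_filename)?;
    let mmap = unsafe { Mmap::map(&file)? };
    let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu)?;
    println!("loaded {} tensors", weights.len());
    Ok(())
}
```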
@ -12,9 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true}
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -34,8 +32,6 @@ zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
criterion = { workspace = true }

[features]
default = []
@ -43,8 +39,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]

[[bench]]
name = "bench_main"
harness = false
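The `[features]` block above is what gates the cuda, metal, mkl and accelerate backends. A hedged sketch of the feature-driven device selection this enables (the helper name is made up; the same idea appears in the benchmark harness removed later in this diff):

```rust
use candle_core::{Device, Result};

/// Pick a device based on the compile-time features declared above.
fn best_device() -> Result<Device> {
    if cfg!(feature = "metal") {
        Device::new_metal(0)
    } else if cfg!(feature = "cuda") {
        Device::new_cuda(0)
    } else {
        Ok(Device::Cpu)
    }
}
```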
@ -1,9 +0,0 @@
mod benchmarks;

use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
);
@ -1,43 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(a: &Tensor) {
a.affine(12.34, 56.78).unwrap();
}

fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;

let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();

let flops = b * m * k * dtype.size_in_bytes();

let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
}
}

criterion_group!(benches, criterion_benchmark);
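The removed benchmark above times `affine`, i.e. an elementwise `x * mul + add`, and reports throughput as bytes touched per iteration (element count times the dtype size). A standalone sketch with the same shape and constants, outside of Criterion:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Same shape as the benchmark: (1, 1024, 1024) in f32.
    let x = Tensor::zeros((1, 1024, 1024), DType::F32, &Device::Cpu)?;
    let y = x.affine(12.34, 56.78)?; // elementwise x * 12.34 + 56.78
    // Bytes per call, as fed to Throughput::Bytes in the benchmark.
    let bytes = 1024 * 1024 * DType::F32.size_in_bytes();
    println!("{bytes} bytes per call, output shape {:?}", y.shape());
    Ok(())
}
```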
@ -1,44 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&lhs), black_box(&rhs));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,66 +0,0 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
fn sync(&self) -> Result<()> {
|
||||
match self {
|
||||
Device::Cpu => Ok(()),
|
||||
Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Ok(device.synchronize()?);
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
#[cfg(feature = "metal")]
|
||||
return Ok(device.wait_until_completed()?);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
panic!("Metal device without metal feature enabled: {:?}", device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let cpu_type = if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
};
|
||||
format!("{}_{}", cpu_type, name.into())
|
||||
}
|
||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BenchDeviceHandler {
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
impl BenchDeviceHandler {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut devices = Vec::new();
|
||||
if cfg!(feature = "metal") {
|
||||
devices.push(Device::new_metal(0)?);
|
||||
} else if cfg!(feature = "cuda") {
|
||||
devices.push(Device::new_cuda(0)?);
|
||||
}
|
||||
devices.push(Device::Cpu);
|
||||
Ok(Self { devices })
|
||||
}
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_uniform(a: &Tensor) {
|
||||
a.rand_like(-1.0, 123.0).unwrap();
|
||||
}
|
||||
|
||||
fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,64 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||
a.where_cond(b, c).unwrap();
|
||||
}
|
||||
|
||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||
let mut arr = [0u8; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
arr[i] = (i % 2) as u8;
|
||||
i += 1;
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
const M: usize = 1024;
|
||||
const K: usize = 1024;
|
||||
const SIZE: usize = B * M * K;
|
||||
|
||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||
|
||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
||||
|
||||
let elements = B * M * K;
|
||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(
|
||||
black_box(&tensor),
|
||||
black_box(&on_true),
|
||||
black_box(&on_false),
|
||||
);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = BenchDeviceHandler::new().unwrap();
|
||||
for d in device.devices {
|
||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -8,10 +8,11 @@ use anyhow::Result;
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
}
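The timing example above runs a 3x3 `conv2d` with padding 0, stride 1, dilation 1 and a single group on a 96x96 input. For reference, a small sketch of the usual output-size arithmetic this implies (the helper is illustrative, not part of candle):

```rust
/// Output length of a convolution along one spatial dimension.
fn conv_out_len(input: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    // 96x96 input, 3x3 kernel, padding 0, stride 1, dilation 1 -> 94x94 output.
    assert_eq!(conv_out_len(96, 3, 0, 1, 1), 94);
}
```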
@ -1,5 +1,5 @@
|
||||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||
use candle_core::{Device, Result};
|
||||
use candle_core::quantized::{gguf_file, k_quants, QTensor};
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use rayon::prelude::*;
|
||||
|
||||
@ -11,7 +11,12 @@ enum QuantizationMode {
|
||||
}
|
||||
|
||||
impl QuantizationMode {
|
||||
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
|
||||
fn quantize(
|
||||
&self,
|
||||
name: &str,
|
||||
tensor: QTensor,
|
||||
default: fn(&Tensor) -> Result<QTensor>,
|
||||
) -> Result<QTensor> {
|
||||
match self {
|
||||
Self::Llama => {
|
||||
// Same behavior as the llama.cpp quantization.
|
||||
@ -19,9 +24,9 @@ impl QuantizationMode {
|
||||
if should_quantize {
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
if name == "output.weight" {
|
||||
QTensor::quantize(&tensor, GgmlDType::Q6K)
|
||||
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
|
||||
} else {
|
||||
QTensor::quantize(&tensor, dtype)
|
||||
default(&tensor)
|
||||
}
|
||||
} else {
|
||||
Ok(tensor)
|
||||
@ -55,27 +60,6 @@ enum Quantization {
|
||||
F32,
|
||||
}
|
||||
|
||||
impl Quantization {
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
Quantization::Q4_0 => GgmlDType::Q4_0,
|
||||
Quantization::Q4_1 => GgmlDType::Q4_1,
|
||||
Quantization::Q5_0 => GgmlDType::Q5_0,
|
||||
Quantization::Q5_1 => GgmlDType::Q5_1,
|
||||
Quantization::Q8_0 => GgmlDType::Q8_0,
|
||||
Quantization::Q8_1 => GgmlDType::Q8_1,
|
||||
Quantization::Q2k => GgmlDType::Q2K,
|
||||
Quantization::Q3k => GgmlDType::Q3K,
|
||||
Quantization::Q4k => GgmlDType::Q4K,
|
||||
Quantization::Q5k => GgmlDType::Q5K,
|
||||
Quantization::Q6k => GgmlDType::Q6K,
|
||||
Quantization::Q8k => GgmlDType::Q8K,
|
||||
Quantization::F16 => GgmlDType::F16,
|
||||
Quantization::F32 => GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Format {
|
||||
Safetensors,
|
||||
@ -118,7 +102,7 @@ enum Command {
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file(s), in safetensors format.
|
||||
/// The input file, in gguf format.
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The output file, in gguf format.
|
||||
@ -133,15 +117,6 @@ enum Command {
|
||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||
mode: QuantizationMode,
|
||||
},
|
||||
|
||||
Dequantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
|
||||
/// The output file, in safetensors format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -150,12 +125,7 @@ struct Args {
|
||||
command: Command,
|
||||
}
|
||||
|
||||
fn run_ls(
|
||||
file: &std::path::PathBuf,
|
||||
format: Option<Format>,
|
||||
verbose: bool,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
None => match Format::infer(file) {
|
||||
@ -196,7 +166,7 @@ fn run_ls(
|
||||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
|
||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
for tensor_info in tensors.iter() {
|
||||
println!(
|
||||
@ -221,7 +191,7 @@ fn run_ls(
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, qtensor) in tensors.iter() {
|
||||
@ -262,8 +232,37 @@ fn run_quantize_safetensors(
|
||||
}
|
||||
println!("tensors: {}", tensors.len());
|
||||
|
||||
let dtype = q.dtype();
|
||||
let block_size = dtype.block_size();
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
let block_size = match q {
|
||||
Quantization::Q4_0 => k_quants::QK4_0,
|
||||
Quantization::Q4_1 => k_quants::QK4_1,
|
||||
Quantization::Q5_0 => k_quants::QK5_0,
|
||||
Quantization::Q5_1 => k_quants::QK5_1,
|
||||
Quantization::Q8_0 => k_quants::QK8_0,
|
||||
Quantization::Q8_1 => k_quants::QK8_1,
|
||||
Quantization::Q2k
|
||||
| Quantization::Q3k
|
||||
| Quantization::Q4k
|
||||
| Quantization::Q5k
|
||||
| Quantization::Q6k
|
||||
| Quantization::Q8k => k_quants::QK_K,
|
||||
Quantization::F16 | Quantization::F32 => 1,
|
||||
};
|
||||
|
||||
let qtensors = tensors
|
||||
.into_par_iter()
|
||||
@ -271,9 +270,9 @@ fn run_quantize_safetensors(
|
||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||
let tensor = if should_quantize {
|
||||
QTensor::quantize(&tensor, dtype)?
|
||||
quantize_fn(&tensor)?
|
||||
} else {
|
||||
QTensor::quantize(&tensor, GgmlDType::F32)?
|
||||
QTensor::quantize::<f32>(&tensor)?
|
||||
};
|
||||
Ok((name, tensor))
|
||||
})
|
||||
@ -286,29 +285,11 @@ fn run_quantize_safetensors(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_dequantize(
|
||||
in_file: std::path::PathBuf,
|
||||
out_file: std::path::PathBuf,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let mut in_file = std::fs::File::open(in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
qmode: QuantizationMode,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if in_files.is_empty() {
|
||||
candle_core::bail!("no specified input files")
|
||||
@ -334,15 +315,31 @@ fn run_quantize(
|
||||
let content = gguf_file::Content::read(&mut in_)?;
|
||||
println!("tensors: {}", content.tensor_infos.len());
|
||||
|
||||
let dtype = q.dtype();
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
|
||||
let qtensors = content
|
||||
.tensor_infos
|
||||
.par_iter()
|
||||
.map(|(name, _)| {
|
||||
println!(" quantizing {name}");
|
||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||
let tensor = content.tensor(&mut in_file, name, device)?;
|
||||
let tensor = qmode.quantize(name, tensor, dtype)?;
|
||||
let tensor = content.tensor(&mut in_file, name)?;
|
||||
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
@ -362,7 +359,6 @@ fn run_quantize(
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = Device::Cpu;
|
||||
match args.command {
|
||||
Command::Ls {
|
||||
files,
|
||||
@ -374,7 +370,7 @@ fn main() -> anyhow::Result<()> {
|
||||
if multiple_files {
|
||||
println!("--- {file:?} ---");
|
||||
}
|
||||
run_ls(file, format.clone(), verbose, &device)?
|
||||
run_ls(file, format.clone(), verbose)?
|
||||
}
|
||||
}
|
||||
Command::Quantize {
|
||||
@ -382,8 +378,7 @@ fn main() -> anyhow::Result<()> {
|
||||
out_file,
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
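The tensor-tools changes above replace the per-block-type `QTensor::quantize::<k_quants::BlockQ4K>` style calls with a single `QTensor::quantize(&tensor, GgmlDType)` entry point selected via `Quantization::dtype()`. A small round-trip sketch against that newer API; the shape is arbitrary but keeps the last dimension a multiple of the k-quant block size (256), mirroring the `should_quantize` check:

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::randn(0f32, 1.0, (64, 256), &Device::Cpu)?;
    // Quantize to 4-bit k-quants, then dequantize back to f32.
    let q = QTensor::quantize(&t, GgmlDType::Q4K)?;
    let back = q.dequantize(&Device::Cpu)?;
    println!("dequantized shape: {:?}", back.dims());
    Ok(())
}
```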
@ -39,14 +39,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;

fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;

fn conv2d(
&self,
_l: &Layout,
@ -15,17 +15,6 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
|
||||
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
|
||||
/// elements having dependencies on the latter ones, e.g. the first element if any is the
|
||||
@ -68,11 +57,6 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::ConvTranspose1D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Conv2D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
@ -114,7 +98,7 @@ impl Tensor {
|
||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest1D(node)
|
||||
| Op::UpsampleNearest2D { arg: node, .. }
|
||||
| Op::UpsampleNearest2D(node)
|
||||
| Op::AvgPool2D { arg: node, .. }
|
||||
| Op::MaxPool2D { arg: node, .. }
|
||||
| Op::Copy(node)
|
||||
@ -166,16 +150,10 @@ impl Tensor {
|
||||
if node.is_variable() {
|
||||
continue;
|
||||
}
|
||||
let grad = grads
|
||||
.remove(node)
|
||||
.expect("candle internal error - grad not populated");
|
||||
// https://github.com/huggingface/candle/issues/1241
|
||||
// Ideally, we would make these operations in place where possible to ensure that we
|
||||
// do not have to allocate too often. Here we just call `.detach` to avoid computing
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach() };
|
||||
let grad = grads.remove(node).unwrap();
|
||||
// TODO: We should perform all these operations in place (or at least not track the
|
||||
// whole graph). The only drawback would be if we wanted to support grad of grad but
|
||||
// this is out of scope.
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
@ -230,44 +208,7 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
} => {
|
||||
// The output height for conv_transpose1d is:
|
||||
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
|
||||
let grad_l_in = grad.dim(2)?;
|
||||
let k_size = kernel.dim(2)?;
|
||||
let out_size =
|
||||
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
|
||||
let out_padding = arg.dim(2)? - out_size;
|
||||
let grad_arg = grad.conv_transpose1d(
|
||||
kernel,
|
||||
*padding,
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
||||
let grad_kernel = arg
|
||||
.transpose(0, 1)?
|
||||
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
let (_, _, k0) = kernel.dims3()?;
|
||||
let (_, _, g_k0) = grad_kernel.dims3()?;
|
||||
let grad_kernel = if g_k0 != k0 {
|
||||
grad_kernel.narrow(2, 0, k0)?
|
||||
} else {
|
||||
grad_kernel
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
@ -297,18 +238,8 @@ impl Tensor {
|
||||
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||
.transpose(0, 1)?;
|
||||
let sum_grad = grads.or_insert(kernel)?;
|
||||
let (_, _, k0, k1) = kernel.dims4()?;
|
||||
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
|
||||
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
|
||||
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
|
||||
} else {
|
||||
grad_kernel
|
||||
};
|
||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||
}
|
||||
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose1d",
|
||||
})?,
|
||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "conv-transpose2d",
|
||||
})?,
|
||||
@ -350,27 +281,9 @@ impl Tensor {
|
||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest1d",
|
||||
})?,
|
||||
Op::UpsampleNearest2D {
|
||||
arg,
|
||||
target_h,
|
||||
target_w,
|
||||
} => {
|
||||
let (_n, c, h, w) = arg.dims4()?;
|
||||
if target_h % h != 0 || target_w % w != 0 {
|
||||
crate::bail!("backward not supported for non integer upscaling factors")
|
||||
}
|
||||
let scale_h = target_h / h;
|
||||
let scale_w = target_w / w;
|
||||
|
||||
if scale_h != scale_w {
|
||||
crate::bail!("backward not supported for non uniform upscaling factors")
|
||||
};
|
||||
let kernel =
|
||||
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
|
||||
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = conv_sum;
|
||||
}
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
Op::SliceScatter0(lhs, rhs, start_rhs) => {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
|
||||
@ -567,38 +480,16 @@ impl Tensor {
|
||||
+ 0.5)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Erf) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
|
||||
let erf_grad =
|
||||
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
|
||||
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::GeluErf) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
|
||||
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
|
||||
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
|
||||
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
|
||||
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
|
||||
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
|
||||
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
|
||||
Op::Unary(_, UnaryOp::GeluErf) => {
|
||||
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Relu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let zeros = arg.zeros_like()?;
|
||||
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
|
||||
let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
|
||||
let combined_mask = (positive_mask + negative_exp_mask)?;
|
||||
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
|
||||
}
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::Powf(arg, e) => {
|
||||
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
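The new `Op::Conv1D` backward branch above builds the input gradient with `conv_transpose1d` and picks `out_padding` so that the transposed convolution lands exactly back on the forward input length. The length relation it relies on, restated from the in-code comment:

```latex
l_{out} = (l_{in} - 1)\cdot \mathrm{stride} - 2\cdot \mathrm{padding} + \mathrm{dilation}\cdot(k_{\mathrm{size}} - 1) + \mathrm{output\_padding} + 1
```

Solving this for the output padding with the forward input length as the target is exactly the `out_padding = arg.dim(2)? - out_size` computation in the hunk.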
@ -25,33 +25,6 @@ impl ParamsConv1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParamsConvTranspose1D {
|
||||
pub(crate) b_size: usize,
|
||||
pub(crate) l_in: usize,
|
||||
pub(crate) c_out: usize,
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) k_size: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) output_padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
}
|
||||
|
||||
impl ParamsConvTranspose1D {
|
||||
pub(crate) fn l_out(&self) -> usize {
|
||||
(self.l_in - 1) * self.stride - 2 * self.padding
|
||||
+ self.dilation * (self.k_size - 1)
|
||||
+ self.output_padding
|
||||
+ 1
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
let l_out = self.l_out();
|
||||
vec![self.b_size, self.c_out, l_out]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum CudnnFwdAlgo {
|
||||
ImplicitGemm,
|
||||
@ -187,49 +160,6 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies a 1D transposed convolution over the input tensor.
|
||||
pub fn conv_transpose1d(
|
||||
&self,
|
||||
kernel: &Self,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = ParamsConvTranspose1D {
|
||||
b_size,
|
||||
l_in,
|
||||
k_size,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
};
|
||||
let storage = self.storage().conv_transpose1d(
|
||||
self.layout(),
|
||||
&kernel.storage(),
|
||||
kernel.layout(),
|
||||
¶ms,
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding: params.padding,
|
||||
output_padding: params.output_padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
||||
let storage =
|
||||
self.storage()
|
||||
|
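The `conv_transpose1d` method added above takes an input of shape (b_size, c_in, l_in), a kernel of shape (c_in, c_out, k_size), and its parameters in the order padding, output_padding, stride, dilation. A hedged usage sketch checking the resulting length against `ParamsConvTranspose1D::l_out`:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input (b=1, c_in=4, l_in=10), kernel (c_in=4, c_out=2, k_size=3).
    let x = Tensor::randn(0f32, 1.0, (1, 4, 10), &dev)?;
    let k = Tensor::randn(0f32, 1.0, (4, 2, 3), &dev)?;
    // padding = 0, output_padding = 0, stride = 2, dilation = 1.
    let y = x.conv_transpose1d(&k, 0, 0, 2, 1)?;
    // l_out = (10 - 1) * 2 - 0 + 1 * (3 - 1) + 0 + 1 = 21.
    assert_eq!(y.dims(), &[1, 2, 21]);
    Ok(())
}
```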
@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-select" })?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let n_ids = match self.ids_l.dims() {
|
||||
@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
|
||||
@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
||||
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
};
|
||||
for left_i in 0..ids_left_len {
|
||||
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
|
||||
@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let dim = self.dim;
|
||||
@ -1256,74 +1256,6 @@ impl Map1 for Im2Col {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||
|
||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
const OP: &'static str = "conv_transpose1d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||
let l_out = p.l_out();
|
||||
|
||||
// Output shape: [b_size, c_out, l_out].
|
||||
let dst_elems = p.c_out * l_out * p.b_size;
|
||||
let dst = vec![T::zero(); dst_elems];
|
||||
let dst_s0 = p.c_out * l_out;
|
||||
let dst_s1 = l_out;
|
||||
let dst_s2 = 1;
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
||||
let cont_s0 = p.l_in * p.c_in;
|
||||
let cont_s1 = p.c_in;
|
||||
for b_idx in 0..p.b_size {
|
||||
for l_idx in 0..p.l_in {
|
||||
for c_idx in 0..p.c_in {
|
||||
let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
|
||||
let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
|
||||
inp_cont[dst_idx] = inp[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k_idx in 0..p.k_size {
|
||||
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
for l_idx in 0..p.l_in {
|
||||
let out_idx = l_idx * p.stride + k_idx * p.dilation;
|
||||
if out_idx < p.padding {
|
||||
continue;
|
||||
}
|
||||
let out_idx = out_idx - p.padding;
|
||||
if out_idx < l_out {
|
||||
let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
|
||||
let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
|
||||
let mut d = T::zero();
|
||||
unsafe {
|
||||
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
|
||||
}
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to
|
||||
// parallelise the different tasks so no two threads can try to
|
||||
// write at the same location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
@ -2503,16 +2435,6 @@ impl BackendStorage for CpuStorage {
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
@ -2617,25 +2539,25 @@ impl BackendStorage for CpuStorage {
|
||||
Self::U8(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::U32(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::I64(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1149,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
inp_l: &Layout,
|
||||
k: &CudaSlice<T>,
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_in_k, c_out, l_k)
|
||||
// Input shape: (b_size, c_in, l_in)
|
||||
let p = &self.0;
|
||||
let l_out = p.l_out();
|
||||
let dst_el = p.c_out * l_out * p.b_size;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
|
||||
let ds = if dims.len() == 3 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el,
|
||||
l_out,
|
||||
p.stride,
|
||||
p.padding,
|
||||
p.output_padding,
|
||||
p.dilation,
|
||||
&ds,
|
||||
inp,
|
||||
k,
|
||||
&out,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -1857,19 +1808,6 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(res_t)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice =
|
||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
fn conv2d(
|
||||
&self,
|
||||
|
@ -8,14 +8,12 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
pub enum DeviceLocation {
|
||||
Cpu,
|
||||
Cuda { gpu_id: usize },
|
||||
Metal { gpu_id: usize },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
Cuda(crate::CudaDevice),
|
||||
Metal(crate::MetalDevice),
|
||||
}
|
||||
|
||||
pub trait NdArray {
|
||||
@ -130,15 +128,10 @@ impl Device {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn new_metal(ordinal: usize) -> Result<Self> {
|
||||
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
match self {
|
||||
Self::Cpu => CpuDevice.set_seed(seed),
|
||||
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
|
||||
Self::Cuda(c) => c.set_seed(seed),
|
||||
Self::Metal(m) => m.set_seed(seed),
|
||||
}
|
||||
}
|
||||
|
||||
@ -146,7 +139,6 @@ impl Device {
|
||||
match (self, rhs) {
|
||||
(Self::Cpu, Self::Cpu) => true,
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@ -155,20 +147,21 @@ impl Device {
|
||||
match self {
|
||||
Self::Cpu => DeviceLocation::Cpu,
|
||||
Self::Cuda(device) => device.location(),
|
||||
Device::Metal(device) => device.location(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_cpu(&self) -> bool {
|
||||
matches!(self, Self::Cpu)
|
||||
match self {
|
||||
Self::Cpu => true,
|
||||
Self::Cuda(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_cuda(&self) -> bool {
|
||||
matches!(self, Self::Cuda(_))
|
||||
}
|
||||
|
||||
pub fn is_metal(&self) -> bool {
|
||||
matches!(self, Self::Metal(_))
|
||||
match self {
|
||||
Self::Cpu => false,
|
||||
Self::Cuda(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||
@ -192,18 +185,8 @@ impl Device {
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
|
||||
if dtype == DType::F16 || dtype == DType::BF16 {
|
||||
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
|
||||
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
|
||||
} else {
|
||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.rand_uniform(shape, dtype, lo, up)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -230,18 +213,8 @@ impl Device {
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
|
||||
if dtype == DType::F16 || dtype == DType::BF16 {
|
||||
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
|
||||
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
|
||||
} else {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.rand_normal(shape, dtype, mean, std)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -265,10 +238,6 @@ impl Device {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -282,10 +251,6 @@ impl Device {
|
||||
let storage = device.zeros_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.zeros_impl(shape, dtype)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -297,11 +262,6 @@ impl Device {
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = array.to_cpu_storage();
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -313,11 +273,6 @@ impl Device {
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = S::to_cpu_storage_owned(data);
|
||||
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,9 +14,6 @@ impl Tensor {
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
crate::DeviceLocation::Metal { gpu_id } => {
|
||||
format!(", metal:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(f, "Tensor[")?;
|
||||
@ -479,9 +476,6 @@ impl std::fmt::Display for Tensor {
|
||||
crate::DeviceLocation::Cuda { gpu_id } => {
|
||||
format!(", cuda:{}", gpu_id)
|
||||
}
|
||||
crate::DeviceLocation::Metal { gpu_id } => {
|
||||
format!(", metal:{}", gpu_id)
|
||||
}
|
||||
};
|
||||
|
||||
write!(
|
||||
|
@ -79,16 +79,6 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
|
@ -1,223 +0,0 @@
|
||||
#![allow(dead_code)]
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalDevice;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MetalStorage;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalError {
|
||||
#[error("{0}")]
|
||||
Message(String),
|
||||
}
|
||||
|
||||
impl From<String> for MetalError {
|
||||
fn from(e: String) -> Self {
|
||||
MetalError::Message(e)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! fail {
|
||||
() => {
|
||||
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
|
||||
};
|
||||
}
|
||||
|
||||
impl crate::backend::BackendStorage for MetalStorage {
|
||||
type Device = MetalDevice;
|
||||
|
||||
fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> DType {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn device(&self) -> &Self::Device {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConv1D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConvTranspose2D,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: (usize, usize, usize, usize),
|
||||
_: &Layout,
|
||||
_: &Layout,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::backend::BackendDevice for MetalDevice {
|
||||
type Storage = MetalStorage;
|
||||
fn new(_: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn set_seed(&self, _: u64) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn same_device(&self, _: &Self) -> bool {
|
||||
fail!()
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
|
||||
use crate::{DType, DeviceLocation, Layout, Shape};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MatMulUnexpectedStriding {
|
||||
@ -152,9 +152,6 @@ pub enum Error {
|
||||
#[error("the candle crate has not been built with cuda support")]
|
||||
NotCompiledWithCudaSupport,
|
||||
|
||||
#[error("the candle crate has not been built with metal support")]
|
||||
NotCompiledWithMetalSupport,
|
||||
|
||||
#[error("cannot find tensor {path}")]
|
||||
CannotFindTensor { path: String },
|
||||
|
||||
@ -162,9 +159,6 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
Cuda(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
#[error("Metal error {0}")]
|
||||
Metal(#[from] MetalError),
|
||||
|
||||
#[error(transparent)]
|
||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||
|
||||
|
@ -64,7 +64,7 @@ impl Tensor {
|
||||
#[derive(Debug)]
|
||||
/// Generic structure used to index a slice of the tensor
|
||||
pub enum TensorIndexer {
|
||||
/// This selects the elements for which an index has some specific value.
|
||||
/// This selects the elemnts for which an index has some specific value.
|
||||
Select(usize),
|
||||
/// This is a regular slice, purely indexing a chunk of the tensor
|
||||
Narrow(Bound<usize>, Bound<usize>),
|
||||
@ -104,31 +104,37 @@ impl From<&Tensor> for TensorIndexer {
|
||||
}
|
||||
}
|
||||
|
||||
trait RB: RangeBounds<usize> {}
|
||||
impl RB for Range<usize> {}
|
||||
impl RB for RangeFrom<usize> {}
|
||||
impl RB for RangeFull {}
|
||||
impl RB for RangeInclusive<usize> {}
|
||||
impl RB for RangeTo<usize> {}
|
||||
impl RB for RangeToInclusive<usize> {}
|
||||
macro_rules! impl_from_range {
|
||||
($range_type:ty) => {
|
||||
impl From<$range_type> for TensorIndexer {
|
||||
fn from(range: $range_type) -> Self {
|
||||
use std::ops::Bound::*;
|
||||
|
||||
impl<T: RB> From<T> for TensorIndexer {
|
||||
fn from(range: T) -> Self {
|
||||
use std::ops::Bound::*;
|
||||
let start = match range.start_bound() {
|
||||
Included(idx) => Included(*idx),
|
||||
Excluded(idx) => Excluded(*idx),
|
||||
Unbounded => Unbounded,
|
||||
};
|
||||
let end = match range.end_bound() {
|
||||
Included(idx) => Included(*idx),
|
||||
Excluded(idx) => Excluded(*idx),
|
||||
Unbounded => Unbounded,
|
||||
};
|
||||
TensorIndexer::Narrow(start, end)
|
||||
}
|
||||
let start = match range.start_bound() {
|
||||
Included(idx) => Included(*idx),
|
||||
Excluded(idx) => Excluded(*idx),
|
||||
Unbounded => Unbounded,
|
||||
};
|
||||
|
||||
let end = match range.end_bound() {
|
||||
Included(idx) => Included(*idx),
|
||||
Excluded(idx) => Excluded(*idx),
|
||||
Unbounded => Unbounded,
|
||||
};
|
||||
|
||||
TensorIndexer::Narrow(start, end)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_from_range!(Range<usize>);
|
||||
impl_from_range!(RangeFrom<usize>);
|
||||
impl_from_range!(RangeFull);
|
||||
impl_from_range!(RangeInclusive<usize>);
|
||||
impl_from_range!(RangeTo<usize>);
|
||||
impl_from_range!(RangeToInclusive<usize>);
|
||||
|
||||
/// Trait used to implement multiple signatures for ease of use of the slicing
|
||||
/// of a tensor
|
||||
pub trait IndexOp<T> {
|
||||
|
@ -49,12 +49,9 @@ mod device;
|
||||
pub mod display;
|
||||
mod dtype;
|
||||
mod dummy_cuda_backend;
|
||||
mod dummy_metal_backend;
|
||||
pub mod error;
|
||||
mod indexer;
|
||||
pub mod layout;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal_backend;
|
||||
#[cfg(feature = "mkl")]
|
||||
mod mkl;
|
||||
pub mod npy;
|
||||
@ -72,7 +69,7 @@ pub mod utils;
|
||||
mod variable;
|
||||
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use device::{Device, DeviceLocation, NdArray};
|
||||
pub use device::{Device, DeviceLocation};
|
||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
@ -90,12 +87,6 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
@ -123,20 +114,14 @@ pub trait Module {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl Module for quantized::QMatMul {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// A trait defining a module with forward method using a single tensor argument and a flag to
|
||||
// separate the training and evaluation behaviors.
|
||||
pub trait ModuleT {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl<M: Module> ModuleT for M {
|
||||
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
|
||||
self.forward(xs)
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
@ -90,16 +90,6 @@ pub enum Op {
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
ConvTranspose1D {
|
||||
arg: Tensor,
|
||||
kernel: Tensor,
|
||||
padding: usize,
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
},
|
||||
|
||||
#[allow(dead_code)]
|
||||
Conv2D {
|
||||
arg: Tensor,
|
||||
@ -132,11 +122,7 @@ pub enum Op {
|
||||
},
|
||||
|
||||
UpsampleNearest1D(Tensor),
|
||||
UpsampleNearest2D {
|
||||
arg: Tensor,
|
||||
target_h: usize,
|
||||
target_w: usize,
|
||||
},
|
||||
UpsampleNearest2D(Tensor),
|
||||
|
||||
Cat(Vec<Tensor>, usize),
|
||||
|
||||
@ -188,18 +174,6 @@ pub trait CustomOp1 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
@ -235,20 +209,6 @@ pub trait CustomOp2 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
@ -291,22 +251,6 @@ pub trait CustomOp3 {
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
_: &MetalStorage,
|
||||
_: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn bwd(
|
||||
&self,
|
||||
_arg1: &Tensor,
|
||||
@ -592,13 +536,13 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
|
||||
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Recip, "recip", v, v.recip());
|
||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||
|
||||
/// Tanh based approximation of the `gelu` operation
|
||||
/// GeluErf is the more precise one.
|
||||
/// `gelu` operation
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
impl UnaryOpT for Gelu {
|
||||
const NAME: &'static str = "gelu";
|
||||
@ -688,8 +632,6 @@ impl UnaryOpT for Gelu {
|
||||
}
|
||||
}
|
||||
|
||||
/// `erf` operation
|
||||
/// <https://en.wikipedia.org/wiki/Error_function>
|
||||
impl UnaryOpT for Erf {
|
||||
const NAME: &'static str = "erf";
|
||||
const KERNEL: &'static str = "uerf";
|
||||
@ -724,40 +666,6 @@ impl UnaryOpT for Erf {
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
const KERNEL: &'static str = "uabs";
|
||||
const V: Self = Abs;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.abs()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(v: i64) -> i64 {
|
||||
v.abs()
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Ceil {
|
||||
const NAME: &'static str = "ceil";
|
||||
const KERNEL: &'static str = "uceil";
|
||||
@ -979,10 +887,6 @@ impl BackpropOp {
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn is_none(&self) -> bool {
|
||||
self.0.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for BackpropOp {
|
||||
|
@ -217,13 +217,6 @@ impl Object {
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
|
||||
let mut args = args.tuple()?;
|
||||
args.remove(0).reduce()?
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
@ -234,11 +227,13 @@ impl Object {
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||
let mut path = dir_name.to_path_buf();
|
||||
path.push(file_path);
|
||||
Ok(Some(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
storage_size,
|
||||
}))
|
||||
}
|
||||
@ -350,10 +345,8 @@ impl Stack {
|
||||
module_name,
|
||||
class_name,
|
||||
} => {
|
||||
if module_name == "collections"
|
||||
&& (class_name == "OrderedDict" || class_name == "defaultdict")
|
||||
{
|
||||
// TODO: have a separate ordered dict and a separate default dict.
|
||||
if module_name == "collections" && class_name == "OrderedDict" {
|
||||
// TODO: have a separate ordered dict.
|
||||
Some(Object::Dict(vec![]))
|
||||
} else {
|
||||
None
|
||||
@ -634,16 +627,9 @@ pub struct TensorInfo {
|
||||
pub storage_size: usize,
|
||||
}
|
||||
|
||||
/// Read the tensor info from a .pth file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `file` - The path to the .pth file.
|
||||
/// * `verbose` - Whether to print debug information.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
|
||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
file: P,
|
||||
verbose: bool,
|
||||
key: Option<&str>,
|
||||
) -> Result<Vec<TensorInfo>> {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let zip_reader = std::io::BufReader::new(file);
|
||||
@ -665,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
stack.read_loop(&mut reader)?;
|
||||
let obj = stack.finalize()?;
|
||||
if VERBOSE || verbose {
|
||||
println!("{obj:#?}");
|
||||
println!("{obj:?}");
|
||||
}
|
||||
|
||||
let obj = match obj {
|
||||
Object::Build { callable, args } => match *callable {
|
||||
Object::Reduce { callable, args: _ } => match *callable {
|
||||
@ -681,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
},
|
||||
obj => obj,
|
||||
};
|
||||
|
||||
// If key is provided, then we need to extract the state_dict from the object.
|
||||
let obj = if let Some(key) = key {
|
||||
if let Object::Dict(key_values) = obj {
|
||||
key_values
|
||||
.into_iter()
|
||||
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
|
||||
.map(|(_, v)| v)
|
||||
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
|
||||
} else {
|
||||
obj
|
||||
}
|
||||
} else {
|
||||
obj
|
||||
};
|
||||
|
||||
// If the object is a dict, then we can extract the tensor info from it.
|
||||
// NOTE: We are assuming that the `obj` is state_dict by this stage.
|
||||
if let Object::Dict(key_values) = obj {
|
||||
for (name, value) in key_values.into_iter() {
|
||||
match value.into_tensor_info(name, &dir_name) {
|
||||
@ -721,8 +688,8 @@ pub struct PthTensors {
|
||||
}
|
||||
|
||||
impl PthTensors {
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
||||
let tensor_infos = tensor_infos
|
||||
.into_iter()
|
||||
.map(|ti| (ti.name.to_string(), ti))
|
||||
@ -736,7 +703,6 @@ impl PthTensors {
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
use std::io::Read;
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
@ -745,56 +711,27 @@ impl PthTensors {
|
||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
|
||||
let rank = tensor_info.layout.shape().rank();
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case and when the tensor is fortran contiguous.
|
||||
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let start_offset = tensor_info.layout.start_offset();
|
||||
if start_offset > 0 {
|
||||
std::io::copy(
|
||||
&mut reader.by_ref().take(start_offset as u64),
|
||||
&mut std::io::sink(),
|
||||
)?;
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
&mut reader,
|
||||
)?;
|
||||
|
||||
if rank > 1 && is_fortran_contiguous {
|
||||
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
||||
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
||||
let tensor = tensor.reshape(shape_reversed)?;
|
||||
|
||||
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
||||
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
||||
let tensor = tensor.permute(dim_indeces_reversed)?;
|
||||
Ok(Some(tensor))
|
||||
} else {
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file with a given key.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||
path: P,
|
||||
key: Option<&str>,
|
||||
) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path, key)?;
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path)?;
|
||||
let tensor_names = pth.tensor_infos.keys();
|
||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||
for name in tensor_names {
|
||||
@ -804,11 +741,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||
}
|
||||
Ok(tensors)
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
read_all_with_key(path, None)
|
||||
}
|
||||
|
@ -50,9 +50,14 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut acc = _mm256_setzero_ps();
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
@ -353,7 +358,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
q3 = q3.add(32);
|
||||
|
||||
// Prepare low and high bits
|
||||
// We hardcode the shifts here to avoid loading them into a separate register
|
||||
// We hardcode the shifts here to avoid loading them into a seperate register
|
||||
let q3l_0 = _mm256_and_si256(q3bits, m3);
|
||||
let q3h_0 = if j == 0 {
|
||||
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
|
||||
@ -586,7 +591,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
|
||||
q5 = q5.add(32);
|
||||
|
||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
|
||||
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
|
||||
let q5l_0 = _mm256_and_si256(q5bits, m4);
|
||||
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
|
||||
let q5l_0_right_shift = match j {
|
||||
|
@ -1,43 +0,0 @@
|
||||
#![allow(unused)]
|
||||
use super::GgmlDType;
|
||||
use crate::{Error, MetalDevice, MetalStorage, Result};
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
_self_shape: &crate::Shape,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, crate::Shape)> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
@ -1,9 +1,7 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use super::metal::load_quantized_metal;
|
||||
use super::{k_quants, GgmlDType, QStorage};
|
||||
use crate::{Device, Result};
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -123,22 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
let data: QStorage = match device {
|
||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal backend requires `metal` feature")
|
||||
}
|
||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
||||
};
|
||||
super::QTensor::new(data, dims)
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
@ -146,50 +133,29 @@ pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let block_size = ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
let blck_size = ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::Q4_0 => {
|
||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
@ -197,7 +163,6 @@ pub fn qtensor_from_ggml(
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
device: &Device,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
@ -218,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
}
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
@ -233,14 +198,10 @@ pub struct Content {
|
||||
pub hparams: HParams,
|
||||
pub vocab: Vocab,
|
||||
pub tensors: HashMap<String, super::QTensor>,
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Content> {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
@ -250,16 +211,14 @@ impl Content {
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
tensors.insert(name, tensor);
|
||||
}
|
||||
let device = device.clone();
|
||||
Ok(Self {
|
||||
magic,
|
||||
hparams,
|
||||
vocab,
|
||||
tensors,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::{Device, Result};
|
||||
use crate::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -29,7 +29,6 @@ impl TryFrom<u32> for Magic {
|
||||
pub enum VersionedMagic {
|
||||
GgufV1,
|
||||
GgufV2,
|
||||
GgufV3,
|
||||
}
|
||||
|
||||
impl VersionedMagic {
|
||||
@ -40,8 +39,7 @@ impl VersionedMagic {
|
||||
let versioned_magic = match (magic, version) {
|
||||
(Magic::Gguf, 1) => Self::GgufV1,
|
||||
(Magic::Gguf, 2) => Self::GgufV2,
|
||||
(Magic::Gguf, 3) => Self::GgufV3,
|
||||
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
|
||||
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
|
||||
};
|
||||
Ok(versioned_magic)
|
||||
}
|
||||
@ -59,25 +57,19 @@ impl TensorInfo {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let block_size = self.ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
let blck_size = self.ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(
|
||||
self.ggml_dtype,
|
||||
&raw_data,
|
||||
self.shape.dims().to_vec(),
|
||||
device,
|
||||
)
|
||||
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,9 +84,7 @@ pub struct Content {
|
||||
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
let mut v = vec![0u8; len];
|
||||
reader.read_exact(&mut v)?;
|
||||
@ -294,9 +284,7 @@ impl Value {
|
||||
let value_type = ValueType::from_u32(value_type)?;
|
||||
let len = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
let mut vs = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
@ -393,15 +381,11 @@ impl Content {
|
||||
|
||||
let tensor_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
let metadata_kv_count = match magic {
|
||||
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
reader.read_u64::<LittleEndian>()? as usize
|
||||
}
|
||||
VersionedMagic::GgufV2 => reader.read_u64::<LittleEndian>()? as usize,
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
@ -423,7 +407,7 @@ impl Content {
|
||||
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
}
|
||||
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
|
||||
VersionedMagic::GgufV2 => {
|
||||
let mut dimensions = vec![0; n_dimensions as usize];
|
||||
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
|
||||
dimensions.into_iter().map(|c| c as usize).collect()
|
||||
@ -466,13 +450,12 @@ impl Content {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor info for {name}"),
|
||||
None => crate::bail!("cannot find tensor-infor for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
}
|
||||
}
|
||||
|
||||
@ -524,9 +507,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data = tensor.data()?;
|
||||
let size_in_bytes = data.len();
|
||||
w.write_all(&data)?;
|
||||
let data_ptr = tensor.as_ptr();
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
w.write_all(data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
|
@ -236,9 +236,14 @@ impl GgmlType for BlockQ4_0 {
|
||||
|
||||
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
let nb = n / qk;
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
// Generic implementation.
|
||||
let mut sumf = 0f32;
|
||||
for (xs, ys) in xs.iter().zip(ys.iter()) {
|
||||
@ -1545,13 +1550,13 @@ impl GgmlType for BlockQ5K {
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
||||
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||
ys_index += 1;
|
||||
}
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
||||
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
|
@ -1,234 +0,0 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
||||
use metal::Buffer;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
buffer: Arc<Buffer>,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = device.allocate_zeros(size)?;
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                f32::to_float(&vec, &mut out)?;
            }
            GgmlDType::F16 => {
                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                half::f16::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_0 => {
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_1 => {
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_0 => {
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_1 => {
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_0 => {
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_1 => {
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q2K => {
                let vec: Vec<crate::quantized::BlockQ2K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q3K => {
                let vec: Vec<crate::quantized::BlockQ3K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4K => {
                let vec: Vec<crate::quantized::BlockQ4K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5K => {
                let vec: Vec<crate::quantized::BlockQ5K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q6K => {
                let vec: Vec<crate::quantized::BlockQ6K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8K => {
                let vec: Vec<crate::quantized::BlockQ8K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
            }
        }

        let buffer = self.device.new_buffer_with_data(&out)?;
        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
    }

    pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
        // Quantization only happens on CPU for now.
        let src = src.to_cpu::<f32>()?;
        let elem_count = src.len();
        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
        let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
        qcpu_storage.quantize(&src)?;
        let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
        self.buffer = buffer;
        Ok(())
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.buffer.length() as usize
    }

    pub fn fwd(
        &self,
        self_shape: &Shape,
        storage: &MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(MetalStorage, Shape)> {
        use crate::MetalError;

        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        let (n, k) = self_shape.dims2()?;
        let mut dst_shape = src_shape.dims().to_vec();

        let (b, m) = match dst_shape.len() {
            3 => (dst_shape[0], dst_shape[1]),
            2 => (1, dst_shape[0]),
            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
        };
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        let device = storage.device().clone();
        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
        let command_buffer = device.command_buffer()?;
        candle_metal_kernels::call_quantized_matmul_t(
            device.device(),
            &command_buffer,
            device.kernels(),
            self.dtype.into(),
            (b, m, n, k),
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &self.buffer,
            &dst,
        )
        .map_err(MetalError::from)?;
        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
        Ok((dst_storage, dst_shape))
    }
}
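The `fwd` path above is a transposed matmul: the quantized weight is laid out as `(n, k)` and the activation keeps its leading batch dimensions while its trailing `k` is swapped for `n`. A minimal sketch of that shape bookkeeping against plain `Vec<usize>` dims, with an illustrative function name and `String` errors instead of candle's `Shape`/`Error` types:

```rust
/// Output dims of `x @ w^T` where `w` has dims `(n, k)` and `x` has dims
/// `(..., k)` with rank 2 or 3, mirroring the Metal path above.
fn qmatmul_t_dims(w_dims: (usize, usize), x_dims: &[usize]) -> Result<Vec<usize>, String> {
    let (n, k) = w_dims;
    let mut dst = x_dims.to_vec();
    match dst.len() {
        2 | 3 => (),
        r => return Err(format!("invalid rank {r} for quantized matmul")),
    }
    let last_k = dst.pop().expect("rank checked above");
    if last_k != k {
        return Err(format!("trailing dim {last_k} does not match weight k {k}"));
    }
    dst.push(n); // (b, m, k) x (n, k)^T -> (b, m, n)
    Ok(dst)
}

fn main() {
    assert_eq!(
        qmatmul_t_dims((11008, 4096), &[1, 7, 4096]),
        Ok(vec![1, 7, 11008])
    );
}
```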

pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
    device: &MetalDevice,
    data: &[T],
) -> Result<QStorage> {
    let buffer = device.new_buffer_with_data(data)?;
    let device = device.clone();
    Ok(QStorage::Metal(QMetalStorage {
        dtype: T::DTYPE,
        device,
        buffer,
    }))
}

fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
    let ptr = buffer.contents() as *const T;
    assert!(!ptr.is_null());
    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
    slice.to_vec()
}

impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
    fn from(value: GgmlDType) -> Self {
        match value {
            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
        }
    }
}
@ -1,118 +1,23 @@
|
||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
mod dummy_metal;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal;
|
||||
#[cfg(not(feature = "metal"))]
|
||||
mod metal {
|
||||
pub use super::dummy_metal::*;
|
||||
}
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
pub mod utils;
|
||||
use half::f16;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
storage: QStorage,
|
||||
data: Box<dyn QuantizedType>,
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = dtype.cpu_zeros(elem_count);
|
||||
Ok(QStorage::Cpu(storage))
|
||||
}
|
||||
Device::Metal(metal) => {
|
||||
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||
Ok(QStorage::Metal(storage))
|
||||
}
|
||||
Device::Cuda(_cuda) => {
|
||||
crate::bail!("Cuda ggml quantization not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
Metal(metal::QMetalStorage),
|
||||
}
|
||||
|
||||
impl QStorage {
|
||||
fn block_size(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn device(&self) -> Device {
|
||||
match self {
|
||||
QStorage::Cpu(_storage) => Device::Cpu,
|
||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(&mut self, src: &Storage) -> Result<()> {
|
||||
match (self, src) {
|
||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> Result<Cow<[u8]>> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => {
|
||||
let data_ptr = storage.as_ptr();
|
||||
let size_in_bytes = storage.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
QStorage::Metal(_storage) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
@ -172,25 +77,6 @@ impl GgmlDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// A zero-initialized CPU storage holding `elem_count` elements of this dtype.
|
||||
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
|
||||
match self {
|
||||
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
|
||||
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
|
||||
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
|
||||
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
|
||||
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
|
||||
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
|
||||
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
|
||||
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
|
||||
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
|
||||
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
|
||||
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
|
||||
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
|
||||
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
|
||||
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
|
||||
}
|
||||
}
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
@ -214,7 +100,7 @@ impl GgmlDType {
|
||||
}
|
||||
|
||||
/// The block size, i.e. the number of elements stored in each block.
|
||||
pub fn block_size(&self) -> usize {
|
||||
pub fn blck_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
@ -233,13 +119,9 @@ impl GgmlDType {
|
||||
pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
fn block_size(&self) -> usize;
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
|
||||
fn size(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -247,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.len() * core::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
|
||||
T::from_float(xs, self)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
T::DTYPE
|
||||
}
|
||||
|
||||
fn block_size(&self) -> usize {
|
||||
T::BLCK_SIZE
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
|
||||
let mut ys = vec![0.0f32; elem_count];
|
||||
T::to_float(self.as_slice(), &mut ys)?;
|
||||
Ok(CpuStorage::F32(ys))
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
@ -284,53 +152,56 @@ impl std::fmt::Debug for QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
    let dims = shape.dims();
    if dims.is_empty() {
        crate::bail!("scalar tensor cannot be quantized {shape:?}")
    }
    if dims[dims.len() - 1] % block_size != 0 {
    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
        crate::bail!(
            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
            block_size
            T::BLCK_SIZE
        )
    }
    Ok(())
}

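The divisibility requirement matters because quantized data only exists in whole blocks: Q4_0, for instance, packs 32 elements per block, so a `(4096, 4096)` weight splits cleanly while `(4096, 4100)` does not. A standalone sketch of the same predicate (the helper name here is illustrative, not part of the crate):

```rust
/// True when the trailing dimension splits into whole quantization blocks,
/// which is the invariant `check_shape` enforces above.
fn last_dim_is_blockable(dims: &[usize], block_size: usize) -> bool {
    match dims.last() {
        None => false, // scalar tensors cannot be quantized
        Some(d) => d % block_size == 0,
    }
}

fn main() {
    assert!(last_dim_is_blockable(&[4096, 4096], 32)); // Q4_0: 32 elements per block
    assert!(!last_dim_is_blockable(&[4096, 4100], 32)); // 4100 % 32 != 0
}
```
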
impl QTensor {
|
||||
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape(&shape, storage.block_size())?;
|
||||
Ok(Self { storage, shape })
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape,
|
||||
})
|
||||
}
|
||||
|
||||
    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if elem_count % block_size != 0 {
        check_shape::<T>(shape)?;
        let src = src
            .to_dtype(crate::DType::F32)?
            .flatten_all()?
            .to_vec1::<f32>()?;
        if src.len() % T::BLCK_SIZE != 0 {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
                T::BLCK_SIZE
            )
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize(&src.storage())?;
        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
        T::from_float(&src, &mut data)?;
        Ok(Self {
            storage,
            data: Box::new(data),
            shape: shape.clone(),
        })
    }

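For the device-aware variant that appears on one side of this diff (`QTensor::quantize(&tensor, dtype)` together with `dequantize(&device)`), a hedged round-trip sketch; exact module paths may differ between candle versions:

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 8 x 32 = 256 elements, divisible by the Q4_0 block size of 32.
    let xs = Tensor::arange(0f32, 256f32, &dev)?.reshape((8, 32))?;
    let q = QTensor::quantize(&xs, GgmlDType::Q4_0)?;
    // Quantization is lossy: shapes round-trip exactly, values only approximately.
    let ys = q.dequantize(&dev)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}
```
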
pub fn dtype(&self) -> GgmlDType {
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
self.data.dtype()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
@ -342,19 +213,21 @@ impl QTensor {
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
let is_variable = false;
|
||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
||||
.to_device(device)
|
||||
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||
self.data.to_float(&mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, &self.shape, device)
|
||||
}
|
||||
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.storage.size_in_bytes()
|
||||
self.data.storage_size_in_bytes()
|
||||
}
|
||||
|
||||
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
|
||||
self.storage.data()
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.data.as_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
@ -421,33 +294,21 @@ impl crate::CustomOp1 for QTensor {
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cpu(storage) => storage,
|
||||
QStorage::Metal(_) => crate::bail!("Invalid storage"),
|
||||
};
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let storage = storage.as_slice::<f32>()?;
|
||||
let storage =
|
||||
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
|
||||
self.matmul_t(
|
||||
(dst_shape.elem_count() / n, k, n),
|
||||
storage,
|
||||
&mut dst_storage,
|
||||
)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Metal(metal) => metal,
|
||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||
};
|
||||
self_storage.fwd(&self.shape, storage, layout)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for QMatMul {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
impl QMatMul {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||
Self::Tensor(w) => {
|
||||
|
@ -12,14 +12,6 @@ use core::arch::arm::*;
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||

#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
    // TODO: dotprod
    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}
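As a plain-Rust reference for what this fallback computes: sixteen widening `i8 x i8` products folded into four `i32` lanes. The model below groups products the way the hardware `sdot` instruction would (four consecutive products per lane), which differs from the fallback's pairwise grouping, but every caller immediately reduces the lanes with `vaddvq_s32`, so only the total matters. Illustrative code, not part of the crate:

```rust
/// Scalar model of the widening i8 dot product: each lane accumulates four
/// consecutive products; the horizontal sum matches the NEON helper above.
fn vdotq_s32_model(a: [i8; 16], b: [i8; 16]) -> [i32; 4] {
    let mut acc = [0i32; 4];
    for i in 0..16 {
        acc[i / 4] += a[i] as i32 * b[i] as i32;
    }
    acc
}

fn main() {
    let a = [1i8; 16];
    let b = [2i8; 16];
    let total: i32 = vdotq_s32_model(a, b).iter().sum();
    assert_eq!(total, 32); // 16 lanes * (1 * 2)
}
```
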
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
@ -27,39 +19,71 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut sumv0 = vdupq_n_f32(0.0f32);
|
||||
for i in 0..nb {
|
||||
let mut sumv1 = vdupq_n_f32(0.0f32);
|
||||
for i in (0..nb).step_by(2) {
|
||||
let x0 = &xs[i];
|
||||
let x1 = &xs[i + 1];
|
||||
let y0 = &ys[i];
|
||||
let y1 = &ys[i + 1];
|
||||
|
||||
let m4b = vdupq_n_u8(0x0F);
|
||||
let s8b = vdupq_n_s8(0x8);
|
||||
|
||||
let v0_0 = vld1q_u8(x0.qs.as_ptr());
|
||||
let v0_1 = vld1q_u8(x1.qs.as_ptr());
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
|
||||
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
||||
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
|
||||
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
||||
|
||||
// sub 8
|
||||
let v0_0ls = vsubq_s8(v0_0l, s8b);
|
||||
let v0_0hs = vsubq_s8(v0_0h, s8b);
|
||||
let v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
let v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
|
||||
// load y
|
||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
let v1_1l = vld1q_s8(y1.qs.as_ptr());
|
||||
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
|
||||
|
||||
// TODO: Support dotprod when it's available outside of nightly.
|
||||
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
|
||||
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
|
||||
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
|
||||
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
|
||||
|
||||
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||
|
||||
let pl0 = vdotq_s32(v0_0ls, v1_0l);
|
||||
let ph0 = vdotq_s32(v0_0hs, v1_0h);
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||
x0.d.to_f32() * y0.d.to_f32(),
|
||||
);
|
||||
sumv1 = vmlaq_n_f32(
|
||||
sumv1,
|
||||
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
|
||||
x1.d.to_f32() * y1.d.to_f32(),
|
||||
);
|
||||
}
|
||||
Ok(vaddvq_f32(sumv0))
|
||||
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,29 +94,57 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut sumv0 = vdupq_n_f32(0.0f32);
|
||||
for i in 0..nb {
|
||||
let mut sumv1 = vdupq_n_f32(0.0f32);
|
||||
for i in (0..nb).step_by(2) {
|
||||
let x0 = &xs[i];
|
||||
let x1 = &xs[i + 1];
|
||||
let y0 = &ys[i];
|
||||
let y1 = &ys[i + 1];
|
||||
|
||||
let x0_0 = vld1q_s8(x0.qs.as_ptr());
|
||||
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
|
||||
let x1_0 = vld1q_s8(x1.qs.as_ptr());
|
||||
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
|
||||
|
||||
// load y
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
let y1_0 = vld1q_s8(y1.qs.as_ptr());
|
||||
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
|
||||
|
||||
let p0 = vdotq_s32(x0_0, y0_0);
|
||||
let p1 = vdotq_s32(x0_1, y0_1);
|
||||
// TODO: use dotprod once the intrinsics are available outside of nightly.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
|
||||
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
||||
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
|
||||
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
||||
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(p0, p1)),
|
||||
x0.d.to_f32() * y0.d.to_f32(),
|
||||
);
|
||||
sumv1 = vmlaq_n_f32(
|
||||
sumv1,
|
||||
vcvtq_f32_s32(vaddq_s32(p2, p3)),
|
||||
x1.d.to_f32() * y1.d.to_f32(),
|
||||
);
|
||||
}
|
||||
Ok(vaddvq_f32(sumv0))
|
||||
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,7 +165,10 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy = vdotq_s32(xs, ys);
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
@ -183,16 +238,30 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
@ -212,16 +281,29 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
// TODO: dotprod case.
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
@ -298,14 +380,28 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||
|
||||
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
|
||||
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
|
||||
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
|
||||
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
}
|
||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||
@ -368,15 +464,22 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4bits = vld1q_u8_x2(q4);
|
||||
q4 = q4.add(32);
|
||||
// TODO: dotprod
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||
);
|
||||
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
@ -384,9 +487,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||
);
|
||||
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2) as f32;
|
||||
}
|
||||
@ -464,14 +573,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||
@ -496,14 +618,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
if j == 0 {
|
||||
@ -561,6 +696,7 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
|
||||
let mut is = 0usize;
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let q2bits = vld1q_u8_x2(q2);
|
||||
q2 = q2.add(32);
|
||||
@ -607,7 +743,14 @@ unsafe fn multiply_accum_with_scale(
|
||||
q2bytes: int8x16x2_t,
|
||||
q8bytes: int8x16x2_t,
|
||||
) -> i32 {
|
||||
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
|
||||
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
|
||||
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||
}
|
||||
|
@ -11,6 +11,10 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut acc = f32x4_splat(0.0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
@ -57,6 +61,10 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
if n % QK8_0 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
|
||||
}
|
||||
let nb = n / QK8_0;
|
||||
if nb % 2 != 0 {
|
||||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
|
||||
}
|
||||
unsafe {
|
||||
let mut acc = f32x4_splat(0.0f32);
|
||||
for (x, y) in xs.iter().zip(ys.iter()) {
|
||||
|
@ -203,7 +203,7 @@ impl Shape {
|
||||
|
||||
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
|
||||
/// broadcasted shape. This is to be used for binary pointwise ops.
|
||||
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
|
||||
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
|
||||
let lhs = self;
|
||||
let lhs_dims = lhs.dims();
|
||||
let rhs_dims = rhs.dims();
|
||||
@ -478,6 +478,23 @@ extract_dims!(
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stride() {
|
||||
let shape = Shape::from(());
|
||||
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||
let shape = Shape::from(42);
|
||||
assert_eq!(shape.stride_contiguous(), [1]);
|
||||
let shape = Shape::from((42, 1337));
|
||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||
let shape = Shape::from((299, 792, 458));
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ShapeWithOneHole {
|
||||
fn into_shape(self, el_count: usize) -> Result<Shape>;
|
||||
}
|
||||
@ -610,20 +627,3 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
|
||||
Ok((d1, d2, d3, d4, d).into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stride() {
|
||||
let shape = Shape::from(());
|
||||
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||
let shape = Shape::from(42);
|
||||
assert_eq!(shape.stride_contiguous(), [1]);
|
||||
let shape = Shape::from((42, 1337));
|
||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||
let shape = Shape::from((299, 792, 458));
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
|
||||
pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
Cuda(CudaStorage),
|
||||
Metal(MetalStorage),
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
@ -19,10 +18,6 @@ impl Storage {
|
||||
let storage = storage.try_clone(layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.try_clone(layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,7 +25,6 @@ impl Storage {
|
||||
match self {
|
||||
Self::Cpu(_) => Device::Cpu,
|
||||
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||
Self::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,7 +32,6 @@ impl Storage {
|
||||
match self {
|
||||
Self::Cpu(storage) => storage.dtype(),
|
||||
Self::Cuda(storage) => storage.dtype(),
|
||||
Self::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,10 +65,6 @@ impl Storage {
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,10 +78,6 @@ impl Storage {
|
||||
let storage = storage.powf(layout, alpha)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.powf(layout, alpha)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,10 +91,6 @@ impl Storage {
|
||||
let storage = storage.elu(layout, alpha)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.elu(layout, alpha)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -131,10 +112,6 @@ impl Storage {
|
||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
@ -158,10 +135,6 @@ impl Storage {
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,10 +148,6 @@ impl Storage {
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,10 +161,6 @@ impl Storage {
|
||||
let (storage, shape) = c.cuda_fwd(storage, l)?;
|
||||
Ok((Self::Cuda(storage), shape))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let (storage, shape) = c.metal_fwd(storage, l)?;
|
||||
Ok((Self::Metal(storage), shape))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,10 +181,6 @@ impl Storage {
|
||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
|
||||
Ok((Self::Cuda(s), shape))
|
||||
}
|
||||
(Self::Metal(s1), Self::Metal(s2)) => {
|
||||
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
|
||||
Ok((Self::Metal(s), shape))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -244,10 +205,6 @@ impl Storage {
|
||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||
Ok((Self::Cuda(s), shape))
|
||||
}
|
||||
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
|
||||
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||
Ok((Self::Metal(s), shape))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -262,10 +219,6 @@ impl Storage {
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,10 +239,6 @@ impl Storage {
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
@ -321,10 +270,6 @@ impl Storage {
|
||||
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -334,33 +279,6 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose1d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
self.same_device(kernel, "conv-transpose1d")?;
|
||||
self.same_dtype(kernel, "conv-transpose1d")?;
|
||||
match (self, &kernel) {
|
||||
(Storage::Cpu(inp), Storage::Cpu(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cpu(s))
|
||||
}
|
||||
(Storage::Cuda(inp), Storage::Cuda(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "conv-transpose1d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d(
|
||||
&self,
|
||||
l: &Layout,
|
||||
@ -379,10 +297,6 @@ impl Storage {
|
||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -410,10 +324,6 @@ impl Storage {
|
||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -438,10 +348,6 @@ impl Storage {
|
||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -460,10 +366,6 @@ impl Storage {
|
||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -477,10 +379,6 @@ impl Storage {
|
||||
let storage = storage.upsample_nearest1d(layout, sz)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.upsample_nearest1d(layout, sz)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -494,10 +392,6 @@ impl Storage {
|
||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
Self::Metal(storage) => {
|
||||
let storage = storage.upsample_nearest2d(layout, h, w)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -521,10 +415,6 @@ impl Storage {
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -551,10 +441,6 @@ impl Storage {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes)) => {
|
||||
let storage = s.gather(l, indexes, indexes_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -579,10 +465,6 @@ impl Storage {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -607,10 +489,6 @@ impl Storage {
|
||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -632,10 +510,6 @@ impl Storage {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -663,10 +537,6 @@ impl Storage {
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(Self::Metal(lhs), Self::Metal(rhs)) => {
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Metal(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -686,9 +556,6 @@ impl Storage {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
|
||||
(Self::Metal(src), Self::Metal(dst)) => {
|
||||
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
|
@ -1,4 +1,4 @@
|
||||
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
||||
//! Tensors are N-dimenional matrixes of elements using a single data type.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{
|
||||
@ -6,7 +6,7 @@ use crate::op::{
|
||||
};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Unique identifier for tensors.
|
||||
@ -361,16 +361,6 @@ impl Tensor {
|
||||
Self::new_impl(array, shape, device, false)
|
||||
}
|
||||

    /// Returns a new tensor with all the elements having the same specified value. Note that
    /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
    pub fn full<D: crate::WithDType, S: Into<Shape>>(
        value: D,
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
    }
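Because `full` goes through `broadcast_as`, the result is a broadcast view over a single stored element; a hedged usage sketch that materializes it with `.contiguous()` when a dense buffer is needed, assuming the `full` signature shown on this side of the diff:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A broadcast view over a single stored value...
    let t = Tensor::full(3.5f32, (2, 4), &dev)?;
    // ...made dense before ops that expect a contiguous layout.
    let t = t.contiguous()?;
    assert_eq!(t.to_vec2::<f32>()?, vec![vec![3.5f32; 4]; 2]);
    Ok(())
}
```
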
/// Creates a new 1D tensor from an iterator.
|
||||
pub fn from_iter<D: crate::WithDType>(
|
||||
iter: impl IntoIterator<Item = D>,
|
||||
@ -395,21 +385,11 @@ impl Tensor {
|
||||
step: D,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
if D::is_zero(&step) {
|
||||
bail!("step cannot be zero")
|
||||
}
|
||||
let mut data = vec![];
|
||||
let mut current = start;
|
||||
if step >= D::zero() {
|
||||
while current < end {
|
||||
data.push(current);
|
||||
current += step;
|
||||
}
|
||||
} else {
|
||||
while current > end {
|
||||
data.push(current);
|
||||
current += step;
|
||||
}
|
||||
while current < end {
|
||||
data.push(current);
|
||||
current += step;
|
||||
}
|
||||
let len = data.len();
|
||||
Self::from_vec_impl(data, len, device, false)
|
||||
@ -487,12 +467,6 @@ impl Tensor {
|
||||
broadcast_binary_op!(broadcast_div, div);
|
||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||
broadcast_binary_op!(broadcast_eq, eq);
|
||||
broadcast_binary_op!(broadcast_ne, ne);
|
||||
broadcast_binary_op!(broadcast_lt, lt);
|
||||
broadcast_binary_op!(broadcast_le, le);
|
||||
broadcast_binary_op!(broadcast_gt, gt);
|
||||
broadcast_binary_op!(broadcast_ge, ge);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
@ -539,7 +513,6 @@ impl Tensor {
|
||||
match &*self.storage() {
|
||||
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
}
|
||||
|
||||
@ -679,7 +652,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Split a tensor into the specified number of chunks, this may return less chunks than
|
||||
/// specified.
|
||||
/// specificed.
|
||||
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
|
||||
let dim = dim.to_index(self.shape(), "chunk")?;
|
||||
let size = self.dim(dim)?;
|
||||
@ -804,35 +777,6 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||

    /// Roll the tensor input along the given dimension.
    /// Elements that are shifted beyond the last position are re-introduced at the first position.
    ///
    /// ```rust
    /// # use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let tensor = tensor.roll(1, 0)?;
    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let tensor = tensor.roll(-1, 0)?;
    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
    where
        D: Dim + Clone,
    {
        let dim = dim.to_index(self.shape(), "roll")?;
        let dim_size = self.dim(dim)?;
        let shift = shift.rem_euclid(dim_size as i32) as usize;
        if shift == 0 {
            Ok(self.clone())
        } else {
            let a = self.narrow(dim, 0, dim_size - shift)?;
            let b = self.narrow(dim, dim_size - shift, shift)?;
            Tensor::cat(&[&b, &a], dim)
        }
    }
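The negative-shift case relies on `rem_euclid` always returning a non-negative remainder, so a shift of `-1` over a dimension of size 3 becomes an effective right shift of 2, matching the doc example above. A small sketch of just that normalization (the helper name is illustrative):

```rust
fn normalized_shift(shift: i32, dim_size: usize) -> usize {
    // rem_euclid maps e.g. (-1).rem_euclid(3) to 2, unlike the `%` operator
    // which would yield -1 and underflow the cast to usize.
    shift.rem_euclid(dim_size as i32) as usize
}

fn main() {
    assert_eq!(normalized_shift(-1, 3), 2);
    assert_eq!(normalized_shift(4, 3), 1);
    assert_eq!(normalized_shift(0, 3), 0);
}
```
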
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||
/// input dimensions.
|
||||
///
|
||||
@ -895,20 +839,6 @@ impl Tensor {
|
||||
self.sum_impl(mean_dims, false)? * scale
|
||||
}
|
||||

    /// Returns the unbiased variance over the selected dimension.
    pub fn var_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "var")?;
        let mean = self.mean_keepdim(dim)?;
        let squares = self.broadcast_sub(&mean)?.sqr()?;
        squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64
    }

    /// Returns the unbiased variance over the selected dimension.
    pub fn var<D: Dim>(&self, dim: D) -> Result<Self> {
        let dim = dim.to_index(self.shape(), "var")?;
        self.var_keepdim(dim)?.squeeze(dim)
    }
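`var_keepdim` is the Bessel-corrected estimate: the summed squared deviations are divided by `n - 1` rather than `n`. A hedged check against a hand-computed value, assuming the `var_keepdim` helper shown on this side of the diff:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[[1f32, 2., 3., 4.]], &dev)?; // mean 2.5
    // Squared deviations: 2.25 + 0.25 + 0.25 + 2.25 = 5.0, divided by n - 1 = 3.
    let v = xs.var_keepdim(1)?.to_vec2::<f32>()?;
    assert!((v[0][0] - 5.0 / 3.0).abs() < 1e-5);
    Ok(())
}
```
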
/// Gathers the maximum value across the selected dimension. The resulting shape has the same
|
||||
/// number of dimensions as the original tensor and the select dimension has a single element.
|
||||
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
@ -1033,11 +963,7 @@ impl Tensor {
|
||||
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
|
||||
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
|
||||
arg,
|
||||
target_h,
|
||||
target_w,
|
||||
});
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest2d(self.layout(), target_h, target_w)?;
|
||||
@ -1070,9 +996,6 @@ impl Tensor {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
if h < kernel_size.0 || w < kernel_size.1 {
|
||||
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
||||
}
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
@ -1108,9 +1031,6 @@ impl Tensor {
|
||||
let kernel_size = kernel_size.to_usize2();
|
||||
let stride = stride.to_usize2();
|
||||
let (n, c, h, w) = self.dims4()?;
|
||||
if h < kernel_size.0 || w < kernel_size.1 {
|
||||
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
|
||||
}
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
let h_out = (h - kernel_size.0) / stride.0 + 1;
|
||||
let w_out = (w - kernel_size.1) / stride.1 + 1;
|
||||
@ -1266,16 +1186,14 @@ impl Tensor {
|
||||
op: "scatter-add (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
})?
|
||||
}
|
||||
if indexes.dims() != source.dims() {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (indexes, src)",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
})?
|
||||
}
|
||||
let storage = self.storage().scatter_add(
|
||||
self.layout(),
|
||||
@ -1347,8 +1265,7 @@ impl Tensor {
|
||||
op: "slice-scatter (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: src.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
})?
|
||||
}
|
||||
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
|
||||
self.storage()
|
||||
@ -1382,8 +1299,7 @@ impl Tensor {
|
||||
op: "index-add (self, source)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
})?
|
||||
}
|
||||
// The number of element in indexes must match the dimension on which the add is
|
||||
// performed on the source tensor (and the index values from `indexes` are taken from
|
||||
@ -1394,8 +1310,7 @@ impl Tensor {
|
||||
op: "index-add (ids, source))",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
})?
|
||||
}
|
||||
let storage = self.storage().index_add(
|
||||
self.layout(),
|
||||
@ -1443,8 +1358,7 @@ impl Tensor {
|
||||
op: "gather",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: indexes.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
})?
|
||||
}
|
||||
let storage =
|
||||
self.storage()
|
||||
@ -1518,7 +1432,6 @@ impl Tensor {
|
||||
match &*self.storage() {
|
||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1549,7 +1462,6 @@ impl Tensor {
|
||||
match &*self.storage() {
|
||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1590,7 +1502,6 @@ impl Tensor {
|
||||
match &*self.storage() {
|
||||
Storage::Cpu(storage) => from_cpu_storage(storage),
|
||||
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1833,7 +1744,7 @@ impl Tensor {
|
||||
let is_permutation =
|
||||
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
||||
if !is_permutation {
|
||||
bail!(
|
||||
crate::bail!(
|
||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||
self.dims(),
|
||||
dims
|
||||
@ -1880,23 +1791,17 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor detached from the current graph, gradient are not propagated through
|
||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||
///
|
||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||
pub fn detach(&self) -> Tensor {
|
||||
if self.op.is_none() && !self.is_variable {
|
||||
self.clone()
|
||||
} else {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout.clone(),
|
||||
op: BackpropOp::none(),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout.clone(),
|
||||
op: BackpropOp::none(),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
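// Note, not part of the diff: the two `detach` variants above differ only in signature.
// One side of this comparison returns `Tensor` directly and short-circuits when the
// tensor already has no backprop op and is not a variable; the other always builds a
// fresh node and wraps it in `Result<Tensor>`. In both cases the detached tensor shares
// its storage with `self` and carries `BackpropOp::none()`, so gradients stop there.
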
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
||||
@ -1908,11 +1813,7 @@ impl Tensor {
|
||||
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Metal(metal)) => {
|
||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||
// are the same.
|
||||
@ -1920,9 +1821,6 @@ impl Tensor {
|
||||
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||
_ => {
|
||||
bail!("not implemented yet")
|
||||
}
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||
let tensor_ = Tensor_ {
|
||||
@ -2328,7 +2226,7 @@ impl Tensor {
|
||||
if left == 0 && right == 0 {
|
||||
Ok(self.clone())
|
||||
} else if self.elem_count() == 0 {
|
||||
bail!("cannot use pad_with_same on an empty tensor")
|
||||
crate::bail!("cannot use pad_with_same on an empty tensor")
|
||||
} else if left == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_same")?;
|
||||
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
|
||||
@ -2367,11 +2265,6 @@ impl Tensor {
|
||||
m.forward(self)
|
||||
}
|
||||
|
||||
/// Run the `forward` method of `m` on `self`.
|
||||
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
|
||||
m.forward_t(self, train)
|
||||
}
|
||||
|
||||
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
@ -2486,142 +2379,6 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}

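// Worked example, not part of the diff: for a rank-3 tensor, axis -1 normalizes to 2,
// axis 1 stays 1, and axis 3 (or anything below -3) is rejected. Only `normalize_axis`
// above and candle's public constructors are assumed.
fn normalize_axis_example() -> candle_core::Result<()> {
    use candle_core::{DType, Device, Tensor};
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    assert_eq!(t.normalize_axis(-1)?, 2);
    assert_eq!(t.normalize_axis(1)?, 1);
    assert!(t.normalize_axis(3).is_err());
    Ok(())
}
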
/// Returns a lower triangular matrix of ones of size n by n.
pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.le(&t2)?.to_dtype(dtype)
}

/// Returns an upper triangular matrix of ones of size n by n.
pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.ge(&t2)?.to_dtype(dtype)
}

/// Returns a matrix with a diagonal of ones of size n by n.
pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
let t = Tensor::arange(0u32, n as u32, device)?;
let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
t1.eq(&t2)?.to_dtype(dtype)
}

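// Illustrative sketch, not part of the diff: all three constructors compare a broadcast
// row index against a broadcast column index (le / ge / eq), so for n = 3 `eye` gives
// the identity matrix while `tril2` fills the lower triangle with ones.
fn eye_example() -> candle_core::Result<()> {
    use candle_core::{DType, Device, Tensor};
    let id = Tensor::eye(3, DType::F32, &Device::Cpu)?;
    assert_eq!(
        id.to_vec2::<f32>()?,
        [[1f32, 0., 0.], [0., 1., 0.], [0., 0., 1.]]
    );
    Ok(())
}
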
/// Returns the cumulative sum of elements of the input tensor summed over the specified
/// dimension.
///
/// This operation is most efficient when dim is the last dimension of the tensor.
pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "cumsum")?;
let rank = self.rank();
if rank == 0 {
return Ok(self.clone());
}
let n_axis = self.dim(dim)?;
let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
if rank == 1 {
self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
} else {
let last = rank - 1;
let t = self.transpose(dim, last)?;
let t = t.broadcast_matmul(&triu)?;
t.transpose(dim, last)
}
}

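// Illustrative sketch, not part of the diff: `cumsum` above reduces to a matmul with the
// upper-triangular matrix from `triu2`, e.g. [1, 2, 3] times a 3x3 upper triangle of
// ones gives [1, 3, 6].
fn cumsum_example() -> candle_core::Result<()> {
    use candle_core::{Device, Tensor};
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [1f32, 3., 6.]);
    Ok(())
}
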
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
|
||||
/// content of `src`.
|
||||
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
|
||||
&self,
|
||||
ranges: &[D],
|
||||
src: &Tensor,
|
||||
) -> Result<Self> {
|
||||
let src_dims = src.dims();
|
||||
let self_dims = self.dims();
|
||||
if self_dims.len() != src_dims.len() {
|
||||
bail!(
|
||||
"slice-assign requires input with the same rank {} <> {}",
|
||||
self_dims.len(),
|
||||
src_dims.len()
|
||||
)
|
||||
}
|
||||
if self_dims.len() != ranges.len() {
|
||||
bail!(
|
||||
"slice-assign requires input with the same rank as there are ranges {} <> {}",
|
||||
self_dims.len(),
|
||||
ranges.len()
|
||||
)
|
||||
}
|
||||
let mut src = src.clone();
|
||||
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
|
||||
for (i, range) in ranges.iter().enumerate() {
|
||||
let start_included = match range.start_bound() {
|
||||
std::ops::Bound::Unbounded => 0,
|
||||
std::ops::Bound::Included(v) => *v,
|
||||
std::ops::Bound::Excluded(v) => *v + 1,
|
||||
};
|
||||
let end_excluded = match range.end_bound() {
|
||||
std::ops::Bound::Unbounded => self_dims[i],
|
||||
std::ops::Bound::Included(v) => *v + 1,
|
||||
std::ops::Bound::Excluded(v) => *v,
|
||||
};
|
||||
if end_excluded <= start_included {
|
||||
bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
|
||||
}
|
||||
if self_dims[i] < end_excluded {
|
||||
bail!(
|
||||
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
|
||||
self_dims[i]
|
||||
)
|
||||
}
|
||||
if end_excluded - start_included != src_dims[i] {
|
||||
bail!(
|
||||
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
|
||||
)
|
||||
}
|
||||
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
|
||||
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
|
||||
}
|
||||
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
|
||||
}
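// Note, not part of the diff: `slice_assign` above is mask based. It builds a tensor of
// ones with `src`'s shape, zero-pads both `src` and that mask out to the full shape of
// `self` using the start offsets from `ranges`, and finally selects with `where_cond`,
// so the source values land exactly inside the requested ranges and everything else is
// taken from `self`. The `slice_assign` test further down in this diff shows the
// resulting behaviour on a 4x5 tensor.
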
/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}

/// Pointwise pow operation.
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.mul(&self.log()?)?.exp()
}

/// Broadcasting version of `pow`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
}

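// Illustrative sketch, not part of the diff: `pow` above computes x^y as exp(y * ln x),
// which is only meaningful for strictly positive bases, and `log_sum_exp` is the plain
// log(sum(exp(t), dims)) without a max-subtraction stabilisation step.
fn pow_example() -> candle_core::Result<()> {
    use candle_core::{Device, Tensor};
    let x = Tensor::new(&[2f32, 3.], &Device::Cpu)?;
    let y = Tensor::new(&[3f32, 2.], &Device::Cpu)?;
    // exp(y * ln x) is approximately [8, 9]; exact equality is not guaranteed because
    // of the exp/log round-trip.
    println!("{}", x.pow(&y)?);
    Ok(())
}
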
macro_rules! bin_trait {
|
||||
|
@ -4,7 +4,7 @@ use crate::{Result, Tensor};
macro_rules! test_device {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
@ -15,12 +15,6 @@ macro_rules! test_device {
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}

#[cfg(feature = "metal")]
#[test]
fn $test_metal() -> Result<()> {
$fn_name(&Device::new_metal(0)?)
}
};
}

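// Note, not part of the diff: this macro stamps out one test function per backend. On
// the four-argument side of the comparison, an invocation such as
//     test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
// expands to `conv1d_cpu` (always compiled) plus `conv1d_gpu` and `conv1d_metal` behind
// the corresponding feature gates (the metal gate is visible above), each calling
// `conv1d` with the matching `Device`. The three-argument side omits the metal variant.
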
|
@ -23,10 +23,6 @@ pub fn cuda_is_available() -> bool {
|
||||
cfg!(feature = "cuda")
|
||||
}
|
||||
|
||||
pub fn metal_is_available() -> bool {
|
||||
cfg!(feature = "metal")
|
||||
}
|
||||
|
||||
pub fn with_avx() -> bool {
|
||||
cfg!(target_feature = "avx")
|
||||
}
|
||||
|
@ -107,10 +107,6 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn as_detached_tensor(&self) -> Tensor {
|
||||
self.0.detach()
|
||||
}
|
||||
|
||||
pub fn as_tensor(&self) -> &Tensor {
|
||||
&self.0
|
||||
}
|
||||
|
@ -13,11 +13,6 @@ res = torch.nn.functional.conv1d(t, w)
|
||||
print(res.flatten())
|
||||
res = torch.nn.functional.conv1d(t, w, padding=1)
|
||||
print(res.flatten())
|
||||
|
||||
w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
*/
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -50,15 +45,6 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -493,103 +479,17 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
|
||||
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
|
||||
let loss = res.sqr()?.sum_all()?;
|
||||
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
|
||||
let grads = loss.backward()?;
|
||||
let grad_t = grads.get(&t).unwrap();
|
||||
let grad_w = grads.get(&w).unwrap();
|
||||
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
|
||||
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
|
||||
[
|
||||
[
|
||||
[9.29, -7.03, 7.87, 0.0, 0.0],
|
||||
[-1.8, -7.82, 5.9, 0.0, 0.0],
|
||||
[-3.12, 4.49, 5.52, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
],
|
||||
[
|
||||
[21.73, 3.39, 4.77, 0.0, 0.0],
|
||||
[8.25, 3.73, 27.61, 0.0, 0.0],
|
||||
[-20.55, -5.61, -2.77, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
],
|
||||
[
|
||||
[-8.98, 9.91, -7.15, 0.0, 0.0],
|
||||
[4.93, -0.33, 4.56, 0.0, 0.0],
|
||||
[-6.7, -5.76, -8.05, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
],
|
||||
[
|
||||
[23.54, 6.98, -10.0, 0.0, 0.0],
|
||||
[9.65, 6.18, 18.72, 0.0, 0.0],
|
||||
[3.29, -5.27, 0.79, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
|
||||
[
|
||||
[
|
||||
[-3.47, 7.44, 0.66],
|
||||
[12.89, -3.4, -9.29],
|
||||
[-14.16, -0.83, 7.14]
|
||||
],
|
||||
[
|
||||
[-3.23, 5.37, -3.02],
|
||||
[-2.12, -11.24, 1.94],
|
||||
[6.97, 7.2, 2.99]
|
||||
],
|
||||
[
|
||||
[-4.04, -3.31, 4.87],
|
||||
[-6.68, -5.68, 1.73],
|
||||
[-5.54, 4.32, 0.52]
|
||||
],
|
||||
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal);
|
||||
test_device!(
|
||||
conv1d_small,
|
||||
conv1d_small_cpu,
|
||||
conv1d_small_gpu,
|
||||
conv1d_small_metal
|
||||
);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal);
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||
test_device!(
|
||||
conv2d_non_square,
|
||||
conv2d_non_square_cpu,
|
||||
conv2d_non_square_gpu,
|
||||
conv2d_non_square_metal
|
||||
);
|
||||
test_device!(
|
||||
conv2d_small,
|
||||
conv2d_small_cpu,
|
||||
conv2d_small_gpu,
|
||||
conv2d_small_metal
|
||||
);
|
||||
test_device!(
|
||||
conv2d_smaller,
|
||||
conv2d_smaller_cpu,
|
||||
conv2d_smaller_gpu,
|
||||
conv2d_smaller_metal
|
||||
);
|
||||
test_device!(
|
||||
conv2d_grad,
|
||||
conv2d_grad_cpu,
|
||||
conv2d_grad_gpu,
|
||||
conv2_grad_metal
|
||||
conv2d_non_square_gpu
|
||||
);
|
||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
||||
test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
|
||||
|
Binary file not shown.
@ -205,231 +205,6 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0116, 1.0830, 1.0003, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch torch.erf
|
||||
//
|
||||
// import torch
|
||||
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
|
||||
// y = x.erf()
|
||||
// print(y)
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let y = x.erf()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.0001, 0.4151, 0.0, 1.1033],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch nn.GELU(approximate = 'none')
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
|
||||
// y = F.gelu(x, approximate='none')
|
||||
// print(y)
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let y = x.gelu_erf()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.9960, 0.8413, 3.9999, 0.0839]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0119, 1.0833, 1.0005, 0.6188],
|
||||
);
|
||||
|
||||
// Testing compared to pytorch elu
|
||||
//
|
||||
// import torch
|
||||
// import torch.nn.functional as F
|
||||
// x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
|
||||
// y = F.elu(x, alpha=2.0)
|
||||
// print(y)
|
||||
// loss = y.min
|
||||
// loss = y.sum()
|
||||
// loss.backward()
|
||||
// print(x.grad)
|
||||
let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||
let y = elu_x.elu(2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06.,
|
||||
07., 08., 09., 10., 11., 12.,
|
||||
13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 24.,
|
||||
25., 26., 27., 28., 29., 30.,
|
||||
31., 32., 33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// row 1
|
||||
// 1+2+7+8 = 18
|
||||
// 3+4+9+10 = 26
|
||||
// 5+6+11+12 = 34
|
||||
// row 2
|
||||
// 13+14+19+20 = 66
|
||||
// 15+16+21+22 = 74
|
||||
// 17+18+23+24 = 82
|
||||
// row 3
|
||||
// 25+26+31+32 = 114
|
||||
// 27+28+33+34 = 122
|
||||
// 29+30+35+36 = 130
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
||||
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06.,
|
||||
07., 08., 09., 10., 11., 12.,
|
||||
13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 24.,
|
||||
25., 26., 27., 28., 29., 30.,
|
||||
31., 32., 33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// row 1
|
||||
// 1+2+3+7+8+9+13+14+15 = 72
|
||||
// 4+5+6+10+11+12+16+17+18 = 99
|
||||
// row 2
|
||||
// 19+20+21+25+26+27+31+32+33 = 234
|
||||
// 22+23+24+28+29+30+34+35+36 = 243
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
|
||||
[[72_f32, 99.], [234., 261.]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
|
||||
|
||||
let y = x.interpolate2d(4, 4)?.reshape(32)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04.,
|
||||
05., 06., 07., 08.,
|
||||
09., 10., 11., 12.,
|
||||
13., 14., 15., 16.,
|
||||
17., 18., 19., 20.,
|
||||
21., 22., 23., 24.,
|
||||
25., 26., 27., 28.,
|
||||
29., 30., 31., 32.
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// m1r1
|
||||
// 1+2+5+6=14
|
||||
// 3+4+7+8=22
|
||||
// m1r2
|
||||
// 9+10+13+14=46
|
||||
// 11+12+15+16=54
|
||||
// m2r1
|
||||
// 17+18+21+22=78
|
||||
// 19+20+23+24=86
|
||||
// m2r2
|
||||
// 25+26+29+30=110
|
||||
// 27+28+31+32=118
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
|
||||
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(
|
||||
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let y = x.interpolate2d(4, 4)?.reshape(32)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04.,
|
||||
05., 06., 07., 08.,
|
||||
09., 10., 11., 12.,
|
||||
13., 14., 15., 16.,
|
||||
17., 18., 19., 20.,
|
||||
21., 22., 23., 24.,
|
||||
25., 26., 27., 28.,
|
||||
29., 30., 31., 32.
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
// gradient should be
|
||||
// m1r1
|
||||
// 1+2+5+6=14
|
||||
// 3+4+7+8=22
|
||||
// m1r2
|
||||
// 9+10+13+14=46
|
||||
// 11+12+15+16=54
|
||||
// m2r1
|
||||
// 17+18+21+22=78
|
||||
// 19+20+23+24=86
|
||||
// m2r2
|
||||
// 25+26+29+30=110
|
||||
// 27+28+31+32=118
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
|
||||
let grads = loss.backward()?;
|
||||
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
|
||||
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -475,29 +250,9 @@ fn binary_grad(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
simple_grad,
|
||||
simple_grad_cpu,
|
||||
simple_grad_gpu,
|
||||
simple_grad_metal
|
||||
);
|
||||
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal);
|
||||
test_device!(
|
||||
matmul_grad,
|
||||
matmul_grad_cpu,
|
||||
matmul_grad_gpu,
|
||||
matmul_grad_metal
|
||||
);
|
||||
test_device!(
|
||||
grad_descent,
|
||||
grad_descent_cpu,
|
||||
grad_descent_gpu,
|
||||
grad_descent_metal
|
||||
);
|
||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal);
|
||||
test_device!(
|
||||
binary_grad,
|
||||
binary_grad_cpu,
|
||||
binary_grad_gpu,
|
||||
binary_grad_metal
|
||||
);
|
||||
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
|
||||
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
|
||||
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
|
||||
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
|
||||
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
|
||||
test_device!(binary_grad, binary_grad_cpu, binary_grad_gpu);
|
||||
|
@ -91,32 +91,3 @@ fn index_3d() -> Result<()> {
|
||||
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slice_assign() -> Result<()> {
|
||||
let dev = Device::Cpu;
|
||||
|
||||
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
|
||||
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
|
||||
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
|
||||
assert_eq!(
|
||||
out.to_vec2::<u32>()?,
|
||||
&[
|
||||
[0, 1, 2, 3, 4],
|
||||
[5, 6, 7, 0, 1],
|
||||
[10, 11, 12, 2, 3],
|
||||
[15, 16, 17, 4, 5]
|
||||
]
|
||||
);
|
||||
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
|
||||
assert_eq!(
|
||||
out.to_vec2::<u32>()?,
|
||||
&[
|
||||
[0, 1, 2, 3, 4],
|
||||
[2, 3, 7, 8, 9],
|
||||
[4, 5, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ fn contiguous(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal);
|
||||
test_device!(contiguous, contiguous_cpu, contiguous_gpu);
|
||||
|
||||
#[test]
|
||||
fn strided_blocks() -> Result<()> {
|
||||
|
@ -98,17 +98,15 @@ fn upsample_nearest2d(dev: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal);
|
||||
test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
|
||||
test_device!(
|
||||
avg_pool2d_pytorch,
|
||||
avg_pool2d_pytorch_cpu,
|
||||
avg_pool2d_pytorch_gpu,
|
||||
avg_pool2d_pytorch_metal
|
||||
avg_pool2d_pytorch_gpu
|
||||
);
|
||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal);
|
||||
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
|
||||
test_device!(
|
||||
upsample_nearest2d,
|
||||
upsample_nearest2d_cpu,
|
||||
upsample_nearest2d_gpu,
|
||||
upsample_nearest2d_metal
|
||||
upsample_nearest2d_gpu
|
||||
);
|
||||
|
@ -1,37 +0,0 @@
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
|
||||
# Write a trivial tensor to a pt file
|
||||
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
|
||||
o = OrderedDict()
|
||||
o["test"] = a
|
||||
|
||||
# Write a trivial tensor to a pt file
|
||||
torch.save(o, "test.pt")
|
||||
|
||||
############################################################################################################
|
||||
# Write a trivial tensor to a pt file with a key
|
||||
torch.save({"model_state_dict": o}, "test_with_key.pt")
|
||||
|
||||
############################################################################################################
|
||||
# Create a tensor with fortran contiguous memory layout
|
||||
import numpy as np
|
||||
|
||||
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
|
||||
# For example, creating a 2x3x4 array
|
||||
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
|
||||
|
||||
# Verify the memory order
|
||||
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
|
||||
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
|
||||
|
||||
# Step 2: Convert the NumPy array to a PyTorch tensor
|
||||
tensor_fortran = torch.from_numpy(array_fortran)
|
||||
|
||||
# Verify the tensor layout
|
||||
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
|
||||
|
||||
# Step 3: Save the PyTorch tensor to a .pth file
|
||||
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
|
||||
|
||||
print("3D Tensor saved with Fortran layout.")
|
@ -1,31 +0,0 @@
|
||||
/// Regression test for pth files not loading on Windows.
|
||||
#[test]
|
||||
fn test_pth() {
|
||||
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
|
||||
tensors.get("test").unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pth_with_key() {
|
||||
let tensors =
|
||||
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
|
||||
.unwrap();
|
||||
tensors.get("test").unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pth_fortran_congiguous() {
|
||||
let tensors =
|
||||
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
|
||||
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
|
||||
|
||||
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
|
||||
|
||||
assert_eq!(
|
||||
tensor.to_vec3::<i64>().unwrap(),
|
||||
[
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
|
||||
]
|
||||
);
|
||||
}
|
@ -1,9 +1,7 @@
|
||||
use candle_core::{
|
||||
bail,
|
||||
quantized::{self, GgmlDType},
|
||||
test_device,
|
||||
test_utils::to_vec2_round,
|
||||
Device, Module, Result, Tensor,
|
||||
Device, Result, Tensor,
|
||||
};
|
||||
use quantized::{k_quants, GgmlType};
|
||||
use rand::prelude::*;
|
||||
@ -15,48 +13,16 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
|
||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
|
||||
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
|
||||
|
||||
fn test_matmul(
|
||||
device: &Device,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
dtype: GgmlDType,
|
||||
) -> Result<()> {
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 / (m * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..(k * n))
|
||||
.map(|v| v as f32 / (n * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
|
||||
let mm = lhs.matmul(&rhs)?;
|
||||
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&lhs)?;
|
||||
|
||||
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
let error = error / (b * m * n) as f32;
|
||||
assert!(
|
||||
error <= 0.02,
|
||||
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
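// Note, not part of the diff: the error metric in `test_matmul` above is a mean
// relative error, sum(|mm - res| / |mm|) over all elements divided by b * m * n, and
// the test fails once it exceeds 0.02 (the same value as GGML_MAX_DOT_PRODUCT_ERROR
// defined earlier in this file).
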
fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
#[test]
|
||||
fn quantized_matmul() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
@ -66,7 +32,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
341876.0, 994283.0, 1655709.0, 2301518.0
|
||||
]
|
||||
);
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||
assert_eq!(
|
||||
mm.to_vec2::<f32>()?,
|
||||
@ -77,49 +42,35 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
match device {
|
||||
Device::Metal(_) => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[84946.0, 214126.0, 344757.0, 473798.0],
|
||||
[213458.0, 604350.0, 1000469.0, 1387990.0],
|
||||
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
||||
]
|
||||
),
|
||||
_ => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
#[test]
|
||||
fn quantized_matmul_neg() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..k * n)
|
||||
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
@ -139,56 +90,32 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
match device {
|
||||
Device::Metal(_) => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243666.0, -19714.0, -285433.0, -550453.0],
|
||||
[23782.0, 21654.0, 19400.0, 18369.0],
|
||||
[-196102.0, 63022.0, 324233.0, 587191.0]
|
||||
]
|
||||
),
|
||||
_ => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
),
|
||||
}
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
quantized_matmul,
|
||||
quantized_matmul_cpu,
|
||||
quantized_matmul_cuda,
|
||||
quantized_matmul_metal
|
||||
);
|
||||
test_device!(
|
||||
quantized_matmul_neg,
|
||||
quantized_matmul_neg_cpu,
|
||||
quantized_matmul_neg_cuda,
|
||||
quantized_matmul_neg_metal
|
||||
);
|
||||
#[test]
|
||||
fn quantize_q4_0() -> Result<()> {
|
||||
use k_quants::BlockQ4_0;
|
||||
|
||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_0::zeros(); 4];
|
||||
BlockQ4_0::from_float(&src, &mut quant)?;
|
||||
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
dst.to_vec1::<f32>()?,
|
||||
dst,
|
||||
&[
|
||||
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
|
||||
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
|
||||
@ -204,21 +131,21 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
#[test]
|
||||
fn quantize_q4_1() -> Result<()> {
|
||||
use k_quants::BlockQ4_1;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_1::zeros(); 4];
|
||||
BlockQ4_1::from_float(&src, &mut quant)?;
|
||||
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
round_vector(&dst),
|
||||
&[
|
||||
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
|
||||
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
|
||||
@ -234,21 +161,21 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
#[test]
|
||||
fn quantize_q5_0() -> Result<()> {
|
||||
use k_quants::BlockQ5_0;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_0::zeros(); 4];
|
||||
BlockQ5_0::from_float(&src, &mut quant)?;
|
||||
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
round_vector(&dst),
|
||||
&[
|
||||
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
|
||||
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
|
||||
@ -264,21 +191,21 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
#[test]
|
||||
fn quantize_q5_1() -> Result<()> {
|
||||
use k_quants::BlockQ5_1;
|
||||
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_1::zeros(); 4];
|
||||
BlockQ5_1::from_float(&src, &mut quant)?;
|
||||
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
dst,
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
||||
@ -292,11 +219,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
124.0, 125.0, 126.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
|
||||
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
|
||||
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
assert!(
|
||||
size % crate::quantized::k_quants::QK_K == 0,
|
||||
"size must be a multiple of {}",
|
||||
@ -306,8 +235,10 @@ fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor>
|
||||
let src = (0..size)
|
||||
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let dst = vec![0f32; size];
|
||||
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
|
||||
Tensor::from_vec(src, (size,), device)
|
||||
(src, dst)
|
||||
}
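// Note, not part of the diff: both test-vector helpers above generate a linear ramp
// src[v] = (v - size / 2) * bound / (size / 2), i.e. values running from -bound at
// index 0 up to just under +bound, with exactly 0.0 at index size / 2, which is what
// the assert on [src[0], src[size / 2]] checks.
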
/// Round a vector
|
||||
@ -334,8 +265,7 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a vector similar to the ones used in GGML unit tests:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
(0..GGML_TEST_SIZE)
|
||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||
@ -354,16 +284,14 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
sum / a.len() as f32
|
||||
}
|
||||
|
||||
/// Similar to the GGML quantization unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
|
||||
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||
if error > max_error {
|
||||
bail!(
|
||||
candle_core::bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
error,
|
||||
max_error
|
||||
@ -372,19 +300,19 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q2K;
|
||||
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
|
||||
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(src, &mut quant)?;
|
||||
T::to_float(&quant, dst)?;
|
||||
Ok(quant)
|
||||
}
|
||||
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
#[test]
|
||||
fn quantize_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
|
||||
|
||||
// Test some specific values
|
||||
@ -398,30 +326,20 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q3K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
#[test]
|
||||
fn quantize_q3k() -> Result<()> {
|
||||
use k_quants::BlockQ3K;
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
|
||||
|
||||
// Test some specific values
|
||||
@ -435,30 +353,20 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q4K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
#[test]
|
||||
fn quantize_q4k() -> Result<()> {
|
||||
use k_quants::BlockQ4K;
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
|
||||
|
||||
// Test some specific values
|
||||
@ -472,31 +380,21 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q5K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
#[test]
|
||||
fn quantize_q5k() -> Result<()> {
|
||||
use k_quants::BlockQ5K;
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
@ -506,33 +404,24 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
|
||||
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q6K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
#[test]
|
||||
fn quantize_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
@ -546,31 +435,22 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q8K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
#[test]
|
||||
fn quantize_q8k() -> Result<()> {
|
||||
use k_quants::BlockQ8K;
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
@ -583,79 +463,15 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
quantize_q4_0,
|
||||
quantize_q4_0_cpu,
|
||||
quantize_q4_0_cuda,
|
||||
quantize_q4_0_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q4_1,
|
||||
quantize_q4_1_cpu,
|
||||
quantize_q4_1_cuda,
|
||||
quantize_q4_1_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5_0,
|
||||
quantize_q5_0_cpu,
|
||||
quantize_q5_0_cuda,
|
||||
quantize_q5_0_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5_1,
|
||||
quantize_q5_1_cpu,
|
||||
quantize_q5_1_cuda,
|
||||
quantize_q5_1_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q2k,
|
||||
quantize_q2k_cpu,
|
||||
quantize_q2k_cuda,
|
||||
quantize_q2k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q3k,
|
||||
quantize_q3k_cpu,
|
||||
quantize_q3k_cuda,
|
||||
quantize_q3k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q4k,
|
||||
quantize_q4k_cpu,
|
||||
quantize_q4k_cuda,
|
||||
quantize_q4k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5k,
|
||||
quantize_q5k_cpu,
|
||||
quantize_q5k_cuda,
|
||||
quantize_q5k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q6k,
|
||||
quantize_q6k_cpu,
|
||||
quantize_q6k_cuda,
|
||||
quantize_q6k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q8k,
|
||||
quantize_q8k_cpu,
|
||||
quantize_q8k_cuda,
|
||||
quantize_q8k_metal
|
||||
);
|
||||
|
||||
/// Very simple dot product implementation
|
||||
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||
@ -671,66 +487,54 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
GgmlDType::Q5K => 0.000740,
|
||||
GgmlDType::Q6K => 0.000952,
|
||||
GgmlDType::Q4_0 => 0.001143,
|
||||
GgmlDType::Q4_1 => 0.008,
|
||||
GgmlDType::Q4_1 => 0.007784,
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.00149,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
|
||||
// Not from the ggml repo.
|
||||
GgmlDType::Q8K => 0.00065,
|
||||
_ => bail!("No GGML results for quantization type {dtype:?}",),
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
/// Similar to the GGML matmul unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
let a = create_ggml_like_vector(0.0);
|
||||
let b = create_ggml_like_vector(1.0);
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
|
||||
// Another example that is more likely to trigger the overflow reported in #1526
|
||||
let a = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let b = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
|
||||
let length = a.len();
|
||||
|
||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||
T::from_float(a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(b, &mut b_quant)?;
|
||||
T::from_float(&a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(a, b);
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
|
||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||
bail!(
|
||||
candle_core::bail!(
|
||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||
)
|
||||
}
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
|
||||
candle_core::bail!(
|
||||
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||
);
|
||||
}
|
||||
|
||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||
// => we use a slightly higher error threshold
|
||||
const ERROR_LENIENCY: f32 = 0.00001;
|
||||
if error - ERROR_LENIENCY > ggml_error {
|
||||
bail!(
|
||||
candle_core::bail!(
|
||||
"Dot product error {} exceeds ggml reference error {}",
|
||||
error,
|
||||
ggml_error
|
||||
@ -739,16 +543,6 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_mm() -> Result<()> {
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||
fn get_random_tensors(
|
||||
m: usize,
|
||||
@ -772,112 +566,6 @@ fn get_random_tensors(
|
||||
Ok((lhs, rhs, mm))
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! quantized_matmul {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||
fn $fn_name(device: &Device) -> Result<()> {
|
||||
if device.is_cuda() {
|
||||
// TODO Enable Cuda GGML sometime maybe.
|
||||
return Ok(());
|
||||
}
|
||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
|
||||
};
|
||||
}
|
||||
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4_0_bis,
|
||||
quantized_matmul_q4_0_cpu,
|
||||
quantized_matmul_q4_0_cuda,
|
||||
quantized_matmul_q4_0_metal,
|
||||
GgmlDType::Q4_0
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4_1_bis,
|
||||
quantized_matmul_q4_1_cpu,
|
||||
quantized_matmul_q4_1_cuda,
|
||||
quantized_matmul_q4_1_metal,
|
||||
GgmlDType::Q4_1
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5_0_bis,
|
||||
quantized_matmul_q5_0_cpu,
|
||||
quantized_matmul_q5_0_cuda,
|
||||
quantized_matmul_q5_0_metal,
|
||||
GgmlDType::Q5_0
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5_1_bis,
|
||||
quantized_matmul_q5_1_cpu,
|
||||
quantized_matmul_q5_1_cuda,
|
||||
quantized_matmul_q5_1_metal,
|
||||
GgmlDType::Q5_1
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q8_0_bis,
|
||||
quantized_matmul_q8_0_cpu,
|
||||
quantized_matmul_q8_0_cuda,
|
||||
quantized_matmul_q8_0_metal,
|
||||
GgmlDType::Q8_0
|
||||
);
|
||||
// Not implemented in Ggml
|
||||
// quantized_matmul!(
|
||||
// quantized_matmul_q8_1_bis,
|
||||
// quantized_matmul_q8_1_cpu,
|
||||
// quantized_matmul_q8_1_cuda,
|
||||
// quantized_matmul_q8_1_metal,
|
||||
// GgmlDType::Q8_1
|
||||
// );
|
||||
// TODO This is bugged (also bugged in GGML
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q2k_bis,
|
||||
quantized_matmul_q2k_cpu,
|
||||
quantized_matmul_q2k_cuda,
|
||||
quantized_matmul_q2k_metal,
|
||||
GgmlDType::Q2K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q3k_bis,
|
||||
quantized_matmul_q3k_cpu,
|
||||
quantized_matmul_q3k_cuda,
|
||||
quantized_matmul_q3k_metal,
|
||||
GgmlDType::Q3K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4k_bis,
|
||||
quantized_matmul_q4k_cpu,
|
||||
quantized_matmul_q4k_cuda,
|
||||
quantized_matmul_q4k_metal,
|
||||
GgmlDType::Q4K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5k_bis,
|
||||
quantized_matmul_q5k_cpu,
|
||||
quantized_matmul_q5k_cuda,
|
||||
quantized_matmul_q5k_metal,
|
||||
GgmlDType::Q5K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q6k_bis,
|
||||
quantized_matmul_q6k_cpu,
|
||||
quantized_matmul_q6k_cuda,
|
||||
quantized_matmul_q6k_metal,
|
||||
GgmlDType::Q6K
|
||||
);
|
||||
// Not implemented on metal
|
||||
// quantized_matmul!(
|
||||
// quantized_matmul_q8k_bis,
|
||||
// quantized_matmul_q8k_cpu,
|
||||
// quantized_matmul_q8k_cuda,
|
||||
// quantized_matmul_q8k_metal,
|
||||
// GgmlDType::Q8K
|
||||
// );
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
@ -890,7 +578,7 @@ fn quantized_matmul_q2k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -916,7 +604,7 @@ fn quantized_matmul_q3k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -942,7 +630,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -968,7 +656,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -995,7 +683,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -1020,7 +708,7 @@ fn quantized_matmul_q8k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
|
||||
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
@ -29,34 +29,7 @@ fn ones(device: &Device) -> Result<()> {
|
||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn full(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
||||
[[42, 42, 42], [42, 42, 42]],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn arange(device: &Device) -> Result<()> {
|
||||
assert_eq!(
|
||||
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
|
||||
[0, 1, 2, 3, 4],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
|
||||
[0, 2, 4],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
|
||||
[0, 3],
|
||||
);
|
||||
assert_eq!(
|
||||
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
|
||||
[5, 4, 3, 2, 1],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -188,22 +161,6 @@ fn transpose(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn var(device: &Device) -> Result<()> {
|
||||
// Values taken from https://pytorch.org/docs/stable/generated/torch.var.html
|
||||
let data = &[
|
||||
[0.2035f32, 1.2959, 1.8101, -0.4644],
|
||||
[1.5027, -0.3270, 0.5905, 0.6538],
|
||||
[-1.5745, 1.3330, -0.5596, -0.6548],
|
||||
[0.1264, -0.5080, 1.6420, 0.1992],
|
||||
];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?,
|
||||
&[[1.0631], [0.559], [1.4893], [0.8258]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sum(device: &Device) -> Result<()> {
|
||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
@ -1078,61 +1035,33 @@ fn randn(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||
test_device!(max, max_cpu, max_gpu, max_metal);
|
||||
test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal);
|
||||
test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal);
|
||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||
test_device!(
|
||||
broadcast_matmul,
|
||||
broadcast_matmul_cpu,
|
||||
broadcast_matmul_gpu,
|
||||
broadcast_matmul_metal
|
||||
);
|
||||
test_device!(
|
||||
broadcasting,
|
||||
broadcasting_cpu,
|
||||
broadcasting_gpu,
|
||||
broadcasting_metal
|
||||
);
|
||||
test_device!(
|
||||
index_select,
|
||||
index_select_cpu,
|
||||
index_select_gpu,
|
||||
index_select_metal
|
||||
);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
||||
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
||||
test_device!(
|
||||
scatter_add,
|
||||
scatter_add_cpu,
|
||||
scatter_add_gpu,
|
||||
scatter_add_metal
|
||||
);
|
||||
test_device!(
|
||||
slice_scatter,
|
||||
slice_scatter_cpu,
|
||||
slice_scatter_gpu,
|
||||
slice_scatter_metal
|
||||
);
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||
test_device!(ones, ones_cpu, ones_gpu);
|
||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
|
||||
test_device!(cat, cat_cpu, cat_gpu);
|
||||
test_device!(sum, sum_cpu, sum_gpu);
|
||||
test_device!(min, min_cpu, min_gpu);
|
||||
test_device!(max, max_cpu, max_gpu);
|
||||
test_device!(argmax, argmax_cpu, argmax_gpu);
|
||||
test_device!(argmin, argmin_cpu, argmin_gpu);
|
||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
|
||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||
test_device!(cmp, cmp_cpu, cmp_gpu);
|
||||
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
|
||||
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
|
||||
test_device!(index_select, index_select_cpu, index_select_gpu);
|
||||
test_device!(index_add, index_add_cpu, index_add_gpu);
|
||||
test_device!(gather, gather_cpu, gather_gpu);
|
||||
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
|
||||
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
|
||||
test_device!(randn, randn_cpu, randn_gpu);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu);
|
||||
|
||||
// There was originally a bug on the CPU implementation for randn
|
||||
// https://github.com/huggingface/candle/issues/381
|
||||
@ -1160,108 +1089,3 @@ fn pad_with_same() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn i64_abs() -> Result<()> {
|
||||
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
|
||||
let t = t.abs()?;
|
||||
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tril_triu_eye() -> Result<()> {
|
||||
let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 0.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 0.0],
|
||||
[1.0, 1.0, 1.0, 1.0]
|
||||
],
|
||||
);
|
||||
let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
[0.0, 1.0, 1.0, 1.0],
|
||||
[0.0, 0.0, 1.0, 1.0],
|
||||
[0.0, 0.0, 0.0, 1.0]
|
||||
]
|
||||
);
|
||||
let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.to_vec2::<f32>()?,
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cumsum() -> Result<()> {
|
||||
let t = &[3f32, 1., 4., 1., 5.];
|
||||
let t = Tensor::new(t, &Device::Cpu)?;
|
||||
assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
|
||||
let t = t.unsqueeze(1)?;
|
||||
assert_eq!(
|
||||
t.cumsum(0)?.to_vec2::<f32>()?,
|
||||
[[3.0], [4.0], [8.0], [9.0], [14.0]]
|
||||
);
|
||||
assert_eq!(
|
||||
t.cumsum(1)?.to_vec2::<f32>()?,
|
||||
[[3.0], [1.0], [4.0], [1.0], [5.0]]
|
||||
);
|
||||
let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let t = Tensor::new(t, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
t.cumsum(1)?.to_vec2::<f32>()?,
|
||||
[[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
|
||||
);
|
||||
assert_eq!(
|
||||
t.cumsum(0)?.to_vec2::<f32>()?,
|
||||
[[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
|
||||
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
|
||||
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||
let a_vec: Vec<f64> = a.to_vec1()?;
|
||||
let b_vec: Vec<f64> = b.to_vec1()?;
|
||||
|
||||
assert_eq!(a_vec.len(), b_vec.len());
|
||||
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
|
||||
assert!((a - b).abs() < epsilon);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_sum_exp() -> Result<()> {
|
||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let output = input.log_sum_exp(D::Minus1)?;
|
||||
// The expectations obtained from pytorch.
|
||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||
assert_close(&output, &expected, 0.00001)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow() -> Result<()> {
|
||||
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let rhs = (&lhs - 2.)?;
|
||||
let res = lhs.pow(&rhs)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res, 4)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Binary file not shown.
Binary file not shown.
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -4,9 +4,7 @@
|
||||
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
|
||||
//! The binary version of the dataset is used.
|
||||
use crate::vision::Dataset;
|
||||
use candle::{DType, Device, Error, Result, Tensor};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use parquet::file::reader::{FileReader, SerializedFileReader};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read};
|
||||
|
||||
@ -62,58 +60,3 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
|
||||
let samples = parquet.metadata().file_metadata().num_rows() as usize;
|
||||
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 1_024);
|
||||
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
|
||||
for row in parquet.into_iter().flatten() {
|
||||
for (_name, field) in row.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
buffer_images.extend(image.to_rgb8().as_raw());
|
||||
}
|
||||
}
|
||||
} else if let parquet::record::Field::Long(label) = field {
|
||||
buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
|
||||
.to_dtype(DType::U8)?
|
||||
/ 255.)?;
|
||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||
Ok((images, labels))
|
||||
}
|
||||
|
||||
pub fn load() -> Result<Dataset> {
|
||||
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let dataset_id = "cifar10".to_string();
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo
|
||||
.get("plain_text/test/0000.parquet")
|
||||
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let train_parquet_filename = repo
|
||||
.get("plain_text/train/0000.parquet")
|
||||
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
|
||||
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
|
||||
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||
let (test_images, test_labels) = load_parquet(test_parquet)?;
|
||||
let (train_images, train_labels) = load_parquet(train_parquet)?;
|
||||
Ok(crate::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
@ -11,33 +11,28 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
|
||||
csv = "1.3.0"
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
hf-hub = { workspace = true, features = ["tokio"] }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
||||
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
symphonia = { version = "0.5.3", features = ["all"] }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
cpal= { version = "0.15.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
imageproc = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
@ -45,24 +40,21 @@ rusttype = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
@ -71,15 +63,3 @@ required-features = ["cuda", "nccl", "flash-attn"]
|
||||
[[example]]
|
||||
name = "reinforcement-learning"
|
||||
required-features = ["pyo3"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "whisper-microphone"
|
||||
required-features = ["microphone"]
|
||||
|
@ -4,28 +4,235 @@ use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct KernelDirectories {
|
||||
kernel_glob: &'static str,
|
||||
kernel_dir: &'static str,
|
||||
rust_target: &'static str,
|
||||
include_dirs: &'static [&'static str],
|
||||
}
|
||||
|
||||
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_glob: "examples/custom-ops/kernels/*.cu",
|
||||
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_dir: "examples/custom-ops/kernels/",
|
||||
rust_target: "examples/custom-ops/cuda_kernels.rs",
|
||||
include_dirs: &[],
|
||||
}];
|
||||
|
||||
impl KernelDirectories {
|
||||
fn maybe_build_ptx(
|
||||
&self,
|
||||
cu_file: &std::path::Path,
|
||||
ptx_file: &std::path::Path,
|
||||
compute_cap: usize,
|
||||
) -> Result<()> {
|
||||
let should_compile = if ptx_file.exists() {
|
||||
let ptx_modified = ptx_file.metadata()?.modified()?;
|
||||
let cu_modified = cu_file.metadata()?.modified()?;
|
||||
cu_modified.duration_since(ptx_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
|
||||
let include_dirs: Vec<String> =
|
||||
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
|
||||
command
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", out_dir.to_str().unwrap()])
|
||||
.arg(format!("-I/{}", self.kernel_dir))
|
||||
.args(include_dirs)
|
||||
.arg(cu_file);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open(ptx_file)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
|
||||
println!("cargo:rerun-if-changed={}", self.kernel_dir);
|
||||
let kernel_dir = PathBuf::from(self.kernel_dir);
|
||||
let out_dir = out_dir.join(self.kernel_dir);
|
||||
if !out_dir.exists() {
|
||||
std::fs::create_dir_all(&out_dir)?;
|
||||
}
|
||||
let mut cu_files = vec![];
|
||||
let mut cuh_files = vec![];
|
||||
for file in std::fs::read_dir(kernel_dir)?.flatten() {
|
||||
let file = file.path();
|
||||
match file.extension().and_then(|v| v.to_str()) {
|
||||
Some("cu") => cu_files.push(file),
|
||||
Some("cuh") => cuh_files.push(file),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut ptx_paths = vec![];
|
||||
for cu_file in cu_files.iter() {
|
||||
let file_stem = cu_file
|
||||
.file_stem()
|
||||
.with_context(|| format!("no stem {cu_file:?}"))?;
|
||||
let file_stem = file_stem.to_string_lossy().into_owned();
|
||||
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
|
||||
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
|
||||
ptx_paths.push(ptx_file);
|
||||
}
|
||||
|
||||
let regenerate_rs_file = true;
|
||||
if regenerate_rs_file {
|
||||
let mut file = std::fs::File::create(self.rust_target)?;
|
||||
for ptx_path in ptx_paths {
|
||||
let name = ptx_path
|
||||
.file_stem()
|
||||
.context("empty stem")?
|
||||
.to_string_lossy();
|
||||
file.write_all(b"#[rustfmt::skip]\n")?;
|
||||
let const_definition = format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
self.kernel_dir,
|
||||
);
|
||||
file.write_all(const_definition.as_bytes())?;
|
||||
file.write_all(b"\n")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
|
||||
let out_dir = PathBuf::from(out_dir);
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
for kdir in KERNEL_DIRS.iter() {
|
||||
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write(kdir.rust_target).unwrap()
|
||||
}
|
||||
set_cuda_include_dir()?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let compute_cap = compute_cap()?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let compute_cap = 0;
|
||||
for d in DIRS {
|
||||
d.process(&out_dir, compute_cap)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
// Grab compute code from nvidia-smi
|
||||
let mut compute_cap = {
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
cap.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let max_nvcc_code = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
if !codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
|
||||
);
|
||||
}
|
||||
*codes.last().unwrap()
|
||||
};
|
||||
|
||||
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||
// then choose the highest gpu code in nvcc
|
||||
if compute_cap > max_nvcc_code {
|
||||
println!(
|
||||
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
|
||||
);
|
||||
compute_cap = max_nvcc_code;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
compute_cap = compute_cap_str
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
|
||||
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||
}
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -2,10 +2,10 @@
|
||||
|
||||
Bert is a general large language model. In this example it can be used for two
|
||||
different tasks:
|
||||
|
||||
- Compute sentence embeddings for a prompt.
|
||||
- Compute similarities between a set of sentences.
|
||||
|
||||
|
||||
## Sentence embeddings
|
||||
|
||||
Bert is used to compute the sentence embeddings for a prompt. The model weights
|
||||
@ -24,48 +24,6 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
|
||||
> Tensor[[1, 7, 384], f32]
|
||||
```
|
||||
|
||||
### Custom models
|
||||
|
||||
You can specify different models, such as BGE, with the `--model-id` flag:
|
||||
|
||||
```bash
|
||||
cargo run --example bert --release -- \
|
||||
--model-id BAAI/bge-large-zh-v1.5 \
|
||||
--prompt "Here is a test sentence"
|
||||
Loaded and encoded 435.70775ms
|
||||
[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
|
||||
[-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
|
||||
[ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
|
||||
...
|
||||
[ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
|
||||
[ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
|
||||
[ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
|
||||
Tensor[[1, 9, 1024], f32]
|
||||
Took 176.744667ms
|
||||
```
|
||||
|
||||
### Gelu approximation
|
||||
|
||||
You can get a speedup by using an approximation of the gelu activation, with a
|
||||
small loss of precision, by passing the `--approximate-gelu` flag:
|
||||
|
||||
```bash
|
||||
$ cargo run --example bert --release -- \
|
||||
--model-id BAAI/bge-large-zh-v1.5 \
|
||||
--prompt "Here is a test sentence" \
|
||||
--approximate-gelu
|
||||
Loaded and encoded 244.388042ms
|
||||
[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
|
||||
[-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
|
||||
[ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
|
||||
...
|
||||
[ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
|
||||
[ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
|
||||
[ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
|
||||
Tensor[[1, 9, 1024], f32]
|
||||
Took 116.840791ms
|
||||
```
|
||||
|
||||
## Similarities
|
||||
|
||||
In this example, Bert is used to compute the sentence embeddings for a set of
|
||||
|
@ -3,7 +3,7 @@ extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
|
||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::Tensor;
|
||||
@ -45,10 +45,6 @@ struct Args {
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
|
||||
/// Use tanh based approximation for Gelu instead of erf implementation.
|
||||
#[arg(long, default_value = "false")]
|
||||
approximate_gelu: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -77,7 +73,7 @@ impl Args {
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: Config = serde_json::from_str(&config)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let vb = if self.use_pth {
|
||||
@ -85,9 +81,6 @@ impl Args {
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||
};
|
||||
if self.approximate_gelu {
|
||||
config.hidden_act = HiddenAct::GeluApproximate;
|
||||
}
|
||||
let model = BertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
|
@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let config = blip::Config::image_captioning_large();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (image_embeds, device, mut model) = if args.quantized {
|
||||
let device = Device::Cpu;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
|
||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||
(image_embeds, device, Model::Q(model))
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> {
|
||||
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,237 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::chatglm::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
if tokens.is_empty() {
|
||||
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the endoftext token"),
|
||||
};
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "THUDM/chatglm3-6b".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("lmz/candle-chatglm".to_string())
|
||||
.get("chatglm-tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::glm3_6b();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
# candle-convnext
|
||||
|
||||
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
|
||||
|
||||
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 84.09%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
|
||||
maillot : 0.74%
|
||||
crash helmet : 0.54%
|
||||
unicycle, monocycle : 0.44%
|
||||
|
||||
```
|
@ -1,102 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::convnext;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
Tiny,
|
||||
Small,
|
||||
Base,
|
||||
Large,
|
||||
XLarge,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::Tiny => "tiny",
|
||||
Self::Small => "small",
|
||||
Self::Base => "base",
|
||||
Self::Large => "large",
|
||||
Self::XLarge => "xlarge",
|
||||
};
|
||||
// The XLarge model only has an ImageNet-22K variant
|
||||
let variant = match self {
|
||||
Self::XLarge => "fb_in22k_ft_in1k",
|
||||
_ => "fb_in1k",
|
||||
};
|
||||
|
||||
format!("timm/convnext_{name}.{variant}")
|
||||
}
|
||||
|
||||
fn config(&self) -> convnext::Config {
|
||||
match self {
|
||||
Self::Tiny => convnext::Config::tiny(),
|
||||
Self::Small => convnext::Config::small(),
|
||||
Self::Base => convnext::Config::base(),
|
||||
Self::Large => convnext::Config::large(),
|
||||
Self::XLarge => convnext::Config::xlarge(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::Tiny)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
#[rustfmt::skip]
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
|
||||
|
@ -6,8 +6,7 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(unused)]
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
@ -1,22 +0,0 @@
|
||||
# candle-distilbert
|
||||
|
||||
DistilBert is a distiled version of the Bert model.
|
||||
|
||||
## Sentence embeddings
|
||||
|
||||
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
|
||||
are downloaded from the hub on the first run.
|
||||
|
||||
```bash
|
||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
|
||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
|
||||
> ...
|
||||
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
|
||||
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
|
||||
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
|
||||
> Tensor[[1, 7, 768], f32]
|
||||
|
||||
```
|
@ -1,135 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// When set, compute embeddings for this prompt.
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// Use the pytorch weights rather than the safetensors ones
|
||||
#[arg(long)]
|
||||
use_pth: bool,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
let default_model = "distilbert-base-uncased".to_string();
|
||||
let default_revision = "main".to_string();
|
||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let vb = if self.use_pth {
|
||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||
};
|
||||
let model = DistilBertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mask = get_mask(tokens.len(), device);
|
||||
|
||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
||||
|
||||
let ys = model.forward(&token_ids, &mask)?;
|
||||
println!("{ys}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||
}
|
@ -165,7 +165,14 @@ fn main() -> Result<()> {
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = repo.get(rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
|
@ -1,45 +0,0 @@
# candle-jina-bert

Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.

## Sentence embeddings

Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"

> [[[ 0.1595, -0.9885,  0.6494, ...,  0.3003, -0.6901, -1.2355],
>   [ 0.0374, -0.1798,  1.3359, ...,  0.6731,  0.2133, -1.6807],
>   [ 0.1700, -0.8534,  0.8924, ..., -0.1785, -0.0727, -1.5087],
>   ...
>   [-0.3113, -1.3665,  0.2027, ..., -0.2519,  0.1711, -1.5811],
>   [ 0.0907, -1.0492,  0.5382, ...,  0.0242, -0.7077, -1.0830],
>   [ 0.0369, -0.6343,  0.6105, ...,  0.0671,  0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```

## Similarities

In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.

```bash
cargo run --example jina-bert --release

> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```
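The scores reported above are plain cosine similarities between the mean-pooled sentence embeddings; written out explicitly:

$$
e_i = \frac{1}{T}\sum_{t=1}^{T} h_{i,t}, \qquad
\operatorname{score}(i, j) = \frac{e_i \cdot e_j}{\lVert e_i \rVert \, \lVert e_j \rVert}
$$

where $h_{i,t}$ is the model output for token $t$ of sentence $i$ and the sum runs over all $T$ tokens, padding included, matching the pooling and dot-product computation in the example code further down.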
@ -1,180 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle_transformers::models::jina_bert::{BertModel, Config};
|
||||
|
||||
use anyhow::Error as E;
|
||||
use candle::{DType, Module, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// When set, compute embeddings for this prompt.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
let model = match &self.model {
|
||||
Some(model_file) => std::path::PathBuf::from(model_file),
|
||||
None => Api::new()?
|
||||
.repo(Repo::new(
|
||||
"jinaai/jina-embeddings-v2-base-en".to_string(),
|
||||
RepoType::Model,
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let tokenizer = match &self.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => Api::new()?
|
||||
.repo(Repo::new(
|
||||
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||
RepoType::Model,
|
||||
))
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
let config = Config::v2_base();
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let model = BertModel::new(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
|
||||
if let Some(prompt) = args.prompt {
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
println!("Loaded and encoded {:?}", start.elapsed());
|
||||
for idx in 0..args.n {
|
||||
let start = std::time::Instant::now();
|
||||
let ys = model.forward(&token_ids)?;
|
||||
if idx == 0 {
|
||||
println!("{ys}");
|
||||
}
|
||||
println!("Took {:?}", start.elapsed());
|
||||
}
|
||||
} else {
|
||||
let sentences = [
|
||||
"The cat sits outside",
|
||||
"A man is playing guitar",
|
||||
"I love pasta",
|
||||
"The new movie is awesome",
|
||||
"The cat plays in the garden",
|
||||
"A woman watches TV",
|
||||
"The new movie is so great",
|
||||
"Do you like pizza?",
|
||||
];
|
||||
let n_sentences = sentences.len();
|
||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||
} else {
|
||||
let pp = tokenizers::PaddingParams {
|
||||
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||
..Default::default()
|
||||
};
|
||||
tokenizer.with_padding(Some(pp));
|
||||
}
|
||||
let tokens = tokenizer
|
||||
.encode_batch(sentences.to_vec(), true)
|
||||
.map_err(E::msg)?;
|
||||
let token_ids = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_ids().to_vec();
|
||||
Tensor::new(tokens.as_slice(), device)
|
||||
})
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0)?;
|
||||
println!("running inference on batch {:?}", token_ids.shape());
|
||||
let embeddings = model.forward(&token_ids)?;
|
||||
println!("generated embeddings {:?}", embeddings.shape());
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||
let embeddings = if args.normalize_embeddings {
|
||||
normalize_l2(&embeddings)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
println!("pooled embeddings {:?}", embeddings.shape());
|
||||
|
||||
let mut similarities = vec![];
|
||||
for i in 0..n_sentences {
|
||||
let e_i = embeddings.get(i)?;
|
||||
for j in (i + 1)..n_sentences {
|
||||
let e_j = embeddings.get(j)?;
|
||||
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
|
||||
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
|
||||
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
|
||||
similarities.push((cosine_similarity, i, j))
|
||||
}
|
||||
}
|
||||
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
|
||||
for &(score, i, j) in similarities[..5].iter() {
|
||||
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
|
||||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
||||
}
|
@ -13,7 +13,7 @@ extern crate accelerate_src;
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
@ -22,21 +22,11 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
use candle_transformers::models::llama as model;
|
||||
use model::{Llama, LlamaConfig};
|
||||
use model::{Config, Llama, LlamaConfig};
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
V1,
|
||||
V2,
|
||||
#[value(name = "solar-10.7b")]
|
||||
Solar10_7B,
|
||||
#[value(name = "tiny-llama-1.1b-chat")]
|
||||
TinyLlama1_1BChat,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -44,6 +34,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use npy instead of safetensors
|
||||
#[arg(long)]
|
||||
npy: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
@ -82,13 +76,17 @@ struct Args {
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "v2")]
|
||||
which: Which,
|
||||
#[arg(long)]
|
||||
v1: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// The folder name that contains safetensor weights and json files
|
||||
/// (same structure as huggingface online)
|
||||
#[arg(long)]
|
||||
local_weights: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
@ -120,34 +118,65 @@ fn main() -> Result<()> {
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let (llama, tokenizer_filename, cache) = match args.npy {
|
||||
Some(filename) => {
|
||||
let config = if args.v1 {
|
||||
Config::config_7b_v1(args.use_flash_attn)
|
||||
} else {
|
||||
Config::config_7b_v2(args.use_flash_attn)
|
||||
};
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
|
||||
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer, cache)
|
||||
}
|
||||
None => {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| {
|
||||
if args.v1 {
|
||||
"Narsil/amall-7b".to_string()
|
||||
} else {
|
||||
"meta-llama/Llama-2-7b-hf".to_string()
|
||||
}
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
let tokenizer_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||
_ => api.get("tokenizer.json")?,
|
||||
};
|
||||
|
||||
let filenames = match args.which {
|
||||
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
let config_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "config.json").into(),
|
||||
_ => api.get("config.json")?,
|
||||
};
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
match &args.local_weights {
|
||||
Some(path) => {
|
||||
filenames.push((path.to_owned() + rfilename).into());
|
||||
}
|
||||
_ => {
|
||||
let filename = api.get(rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
};
|
||||
}
|
||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||
};
|
||||
println!("building the model");
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
println!("building the model");
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
}
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
@ -165,14 +194,14 @@ fn main() -> Result<()> {
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
for index in 0..args.sample_len {
|
||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||
(1, index_pos)
|
||||
let context_size = if cache.use_kv_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, context_index)?;
|
||||
let logits = llama.forward(&input, index_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
|
@ -6,10 +6,9 @@ extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use candle_transformers::models::llama2_c as model;
|
||||
use candle_transformers::models::llama2_c_weights as weights;
|
||||
use candle_transformers::models::quantized_llama2_c as qmodel;
|
||||
mod model;
|
||||
mod training;
|
||||
mod weights;
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
@ -20,7 +19,6 @@ use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use model::{Config, Llama};
|
||||
use qmodel::QLlama;
|
||||
use weights::TransformerWeights;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -154,20 +152,6 @@ fn main() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
enum Model {
|
||||
Llama(Llama),
|
||||
QLlama(QLlama),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
|
||||
match self {
|
||||
Self::Llama(l) => Ok(l.forward(xs, pos)?),
|
||||
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
use std::io::BufRead;
|
||||
|
||||
@ -257,66 +241,24 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
|
||||
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
|
||||
let is_safetensors = config_path
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (model, config) = if is_gguf {
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||
let (_vocab_size, dim) = vb
|
||||
.get_no_shape("model.embed_tokens.weight")?
|
||||
.shape()
|
||||
.dims2()?;
|
||||
let config = match dim {
|
||||
64 => Config::tiny_260k(),
|
||||
288 => Config::tiny_15m(),
|
||||
512 => Config::tiny_42m(),
|
||||
768 => Config::tiny_110m(),
|
||||
_ => anyhow::bail!("no config for dim {dim}"),
|
||||
};
|
||||
let freq_cis_real = vb
|
||||
.get(
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_real",
|
||||
)?
|
||||
.dequantize(&device)?;
|
||||
let freq_cis_imag = vb
|
||||
.get(
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_imag",
|
||||
)?
|
||||
.dequantize(&device)?;
|
||||
|
||||
let fake_vb = candle_nn::VarBuilder::from_tensors(
|
||||
[
|
||||
("freq_cis_real".to_string(), freq_cis_real),
|
||||
("freq_cis_imag".to_string(), freq_cis_imag),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
candle::DType::F32,
|
||||
&device,
|
||||
);
|
||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
} else if is_safetensors {
|
||||
let config = Config::tiny_15m();
|
||||
let (vb, config) = if is_safetensors {
|
||||
let config = Config::tiny();
|
||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
(vb, config)
|
||||
} else {
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("{config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
(vb, config)
|
||||
};
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
|
||||
@ -331,7 +273,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0.. {
|
||||
if tokens.len() >= config.seq_len {
|
||||
if tokens.len() >= model.config.seq_len {
|
||||
break;
|
||||
}
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
|
@ -17,20 +17,7 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn tiny_260k() -> Self {
|
||||
Self {
|
||||
dim: 64,
|
||||
hidden_dim: 768,
|
||||
n_layers: 5,
|
||||
n_heads: 8,
|
||||
n_kv_heads: 4,
|
||||
vocab_size: 32000,
|
||||
seq_len: 512,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny_15m() -> Self {
|
||||
pub fn tiny() -> Self {
|
||||
Self {
|
||||
dim: 288,
|
||||
hidden_dim: 768,
|
||||
@ -42,32 +29,6 @@ impl Config {
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny_42m() -> Self {
|
||||
Self {
|
||||
dim: 512,
|
||||
hidden_dim: 768,
|
||||
n_layers: 8,
|
||||
n_heads: 8,
|
||||
n_kv_heads: 8,
|
||||
vocab_size: 32000,
|
||||
seq_len: 1024,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny_110m() -> Self {
|
||||
Self {
|
||||
dim: 768,
|
||||
hidden_dim: 768,
|
||||
n_layers: 12,
|
||||
n_heads: 12,
|
||||
n_kv_heads: 12,
|
||||
vocab_size: 32000,
|
||||
seq_len: 1024,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -75,9 +36,9 @@ pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
pub use_kv_cache: bool,
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
pub cos: Tensor,
|
||||
pub sin: Tensor,
|
||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
@ -114,7 +75,7 @@ impl Cache {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn mask(&self, t: usize) -> Result<Tensor> {
|
||||
fn mask(&self, t: usize) -> Result<Tensor> {
|
||||
let mut masks = self.masks.lock().unwrap();
|
||||
if let Some(mask) = masks.get(&t) {
|
||||
Ok(mask.clone())
|
@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
);
|
||||
let varmap = candle_nn::VarMap::new();
|
||||
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
|
||||
let config = Config::tiny_15m();
|
||||
let config = Config::tiny();
|
||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
|
||||
use candle::{DType, Device, IndexOp, Shape, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
use super::llama2_c::Config;
|
||||
use crate::model::Config;
|
||||
|
||||
pub struct TransformerWeights {
|
||||
// token embedding table
|
@ -143,7 +143,14 @@ fn main() -> Result<()> {
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
|
||||
if args.rank.is_none() {
|
||||
let children: Vec<_> = (0..args.num_shards)
|
||||
|
@ -1,15 +0,0 @@
# candle-mamba-minimal: minimal implementation of Mamba

This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).

Compared to the mamba example, this version can handle training but is much
slower.

## Running the example

```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016.

The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer.
```
@ -1,287 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
mod model;
|
||||
use model::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Module, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for _ in 0..sample_len {
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
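// Editor's note (not in the original file): apply_repeat_penalty lowers the logits of
// tokens that already occur in the last `repeat_last_n` positions, scaling them by the
// `repeat_penalty` factor so the sampler becomes less likely to repeat them.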
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
||||
enum Which {
|
||||
Mamba130m,
|
||||
Mamba370m,
|
||||
Mamba790m,
|
||||
Mamba1_4b,
|
||||
Mamba2_8b,
|
||||
Mamba2_8bSlimPj,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
|
||||
}
|
||||
}
|
||||
|
||||
fn revision(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m
|
||||
| Self::Mamba370m
|
||||
| Self::Mamba790m
|
||||
| Self::Mamba1_4b
|
||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
||||
Self::Mamba2_8b => "refs/pr/4",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "mamba130m")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id
|
||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
||||
RepoType::Model,
|
||||
args.revision
|
||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("EleutherAI/gpt-neox-20b".to_string())
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,204 +0,0 @@
|
||||
/// This follows the lines of:
|
||||
/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py
|
||||
/// Simple, minimal implementation of Mamba in one file of PyTorch.
|
||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{RmsNorm, VarBuilder};
|
||||
|
||||
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
d_model: usize,
|
||||
n_layer: usize,
|
||||
vocab_size: usize,
|
||||
pad_vocab_size_multiple: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn vocab_size(&self) -> usize {
|
||||
let pad = self.pad_vocab_size_multiple;
|
||||
(self.vocab_size + pad - 1) / pad * pad
|
||||
}
|
||||
|
||||
fn dt_rank(&self) -> usize {
|
||||
(self.d_model + 15) / 16
|
||||
}
|
||||
|
||||
fn d_conv(&self) -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
fn d_state(&self) -> usize {
|
||||
16
|
||||
}
|
||||
|
||||
fn d_inner(&self) -> usize {
|
||||
self.d_model * 2
|
||||
}
|
||||
}
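// Editor's note (worked example with illustrative values, not taken from this file):
// with pad_vocab_size_multiple = 8 and vocab_size = 50277, vocab_size() rounds up to
// 50280; with d_model = 768, dt_rank() = (768 + 15) / 16 = 48 and d_inner() = 1536.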
|
||||
|
||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MambaBlock {
|
||||
in_proj: Linear,
|
||||
conv1d: candle_nn::Conv1d,
|
||||
x_proj: Linear,
|
||||
dt_proj: Linear,
|
||||
a_log: Tensor,
|
||||
d: Tensor,
|
||||
out_proj: Linear,
|
||||
dt_rank: usize,
|
||||
}
|
||||
|
||||
impl MambaBlock {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let d_inner = cfg.d_inner();
|
||||
let d_conv = cfg.d_conv();
|
||||
let d_state = cfg.d_state();
|
||||
let dt_rank = cfg.dt_rank();
|
||||
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
|
||||
let conv_cfg = candle_nn::Conv1dConfig {
|
||||
groups: d_inner,
|
||||
padding: d_conv - 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?;
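// Editor's note (not in the original file): groups = d_inner makes this a depthwise
// convolution (one filter per channel), and padding = d_conv - 1 pads the input so that
// forward() can narrow the output back to seq_len, keeping the convolution causal.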
|
||||
let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
|
||||
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
|
||||
let a_log = vb.get((d_inner, d_state), "A_log")?;
|
||||
let d = vb.get(d_inner, "D")?;
|
||||
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
in_proj,
|
||||
conv1d,
|
||||
x_proj,
|
||||
dt_proj,
|
||||
a_log,
|
||||
d,
|
||||
out_proj,
|
||||
dt_rank,
|
||||
})
|
||||
}
|
||||
|
||||
fn ssm(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_d_in, n) = self.a_log.dims2()?;
|
||||
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
|
||||
let d = self.d.to_dtype(candle::DType::F32)?;
|
||||
let x_dbl = xs.apply(&self.x_proj)?;
|
||||
let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?;
|
||||
let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?;
|
||||
let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?;
|
||||
let delta = delta.contiguous()?.apply(&self.dt_proj)?;
|
||||
// softplus without threshold
|
||||
let delta = (delta.exp()? + 1.)?.log()?;
|
||||
let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?;
|
||||
Ok(ss)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275
|
||||
fn selective_scan(
|
||||
u: &Tensor,
|
||||
delta: &Tensor,
|
||||
a: &Tensor,
|
||||
b: &Tensor,
|
||||
c: &Tensor,
|
||||
d: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, l, d_in) = u.dims3()?;
|
||||
let n = a.dim(1)?;
|
||||
let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1
|
||||
let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?;
|
||||
let delta_b_u = delta
|
||||
.broadcast_mul(&b.reshape((b_sz, 1, l, n))?)?
|
||||
.broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?;
|
||||
let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?;
|
||||
let mut ys = Vec::with_capacity(l);
|
||||
for i in 0..l {
|
||||
xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?;
|
||||
let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?;
|
||||
ys.push(y)
|
||||
}
|
||||
let ys = Tensor::stack(ys.as_slice(), 1)?;
|
||||
ys + u.broadcast_mul(d)
|
||||
}
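// Editor's sketch of the recurrence implemented by the loop above (not part of the file):
//   deltaA_t = exp(delta_t * A)                     (element-wise, per channel and state)
//   x_t      = deltaA_t ⊙ x_{t-1} + (delta_t ⊙ B_t) * u_t
//   y_t      = C_t · x_t
//   output   = y + D ⊙ u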
|
||||
|
||||
impl Module for MambaBlock {
|
||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len, _dim) = xs.dims3()?;
|
||||
let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
|
||||
let (xs, res) = (&xs_and_res[0], &xs_and_res[1]);
|
||||
let xs = xs
|
||||
.t()?
|
||||
.apply(&self.conv1d)?
|
||||
.narrow(D::Minus1, 0, seq_len)?
|
||||
.t()?;
|
||||
let xs = candle_nn::ops::silu(&xs)?;
|
||||
let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?;
|
||||
ys.apply(&self.out_proj)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ResidualBlock {
|
||||
mixer: MambaBlock,
|
||||
norm: RmsNorm,
|
||||
}
|
||||
|
||||
impl ResidualBlock {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
|
||||
let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?;
|
||||
Ok(Self { mixer, norm })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ResidualBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.norm)?.apply(&self.mixer)? + xs
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Model {
|
||||
embedding: candle_nn::Embedding,
|
||||
layers: Vec<ResidualBlock>,
|
||||
norm_f: RmsNorm,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
|
||||
let mut layers = Vec::with_capacity(cfg.n_layer);
|
||||
let vb_l = vb.pp("layers");
|
||||
for layer_idx in 0..cfg.n_layer {
|
||||
let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
|
||||
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
|
||||
Ok(Self {
|
||||
embedding,
|
||||
layers,
|
||||
norm_f,
|
||||
lm_head,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Model {
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = input_ids.dims2()?;
|
||||
let mut xs = self.embedding.forward(input_ids)?;
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm_f)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
}
|
@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation

Candle implementation of *Mamba* [1], inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).

- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).

Compared to the mamba-minimal example, this version is far more efficient but
only works for inference.

## Running the example

```bash
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```
@ -1,299 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mamba::{Config, Model, State};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
config: Config,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
config: Config,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
config,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let mut state = State::new(1, &self.config, &self.device)?;
|
||||
let mut next_logits = None;
|
||||
for &t in tokens.iter() {
|
||||
let input = Tensor::new(&[t], &self.device)?;
|
||||
let logits = self.model.forward(&input, &mut state)?;
|
||||
next_logits = Some(logits);
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for _ in 0..sample_len {
|
||||
let logits = match next_logits.as_ref() {
|
||||
Some(logits) => logits,
|
||||
None => anyhow::bail!("cannot work on an empty prompt"),
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let input = Tensor::new(&[next_token], &self.device)?;
|
||||
next_logits = Some(self.model.forward(&input, &mut state)?)
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
||||
enum Which {
|
||||
Mamba130m,
|
||||
Mamba370m,
|
||||
Mamba790m,
|
||||
Mamba1_4b,
|
||||
Mamba2_8b,
|
||||
Mamba2_8bSlimPj,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
|
||||
}
|
||||
}
|
||||
|
||||
fn revision(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m
|
||||
| Self::Mamba370m
|
||||
| Self::Mamba790m
|
||||
| Self::Mamba1_4b
|
||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
||||
Self::Mamba2_8b => "refs/pr/4",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "mamba130m")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id
|
||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
||||
RepoType::Model,
|
||||
args.revision
|
||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("EleutherAI/gpt-neox-20b".to_string())
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
config,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,38 +0,0 @@
# candle-marian-mt

`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.

## Running an example

```bash
cargo run --example marian-mt --release -- \
  --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```

```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```

## Generating the tokenizer.json files

You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers`, `sentencepiece` and
`transformers` packages to be installed and uses the `convert_slow_tokenizer.py`
script from this directory.

```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```
File diff suppressed because it is too large
@ -1,152 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::marian;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
Base,
|
||||
Big,
|
||||
}
|
||||
|
||||
// TODO: Maybe add support for the conditional prompt.
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_dec: Option<String>,
|
||||
|
||||
/// Choose the variant of the model to run.
|
||||
#[arg(long, default_value = "big")]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use the quantized version of the model.
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Text to be translated
|
||||
#[arg(long)]
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use hf_hub::api::sync::Api;
|
||||
let args = Args::parse();
|
||||
|
||||
let config = match args.which {
|
||||
Which::Base => marian::Config::opus_mt_fr_en(),
|
||||
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
|
||||
};
|
||||
let tokenizer = {
|
||||
let tokenizer = match args.tokenizer {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Base => "tokenizer-marian-base-fr.json",
|
||||
Which::Big => "tokenizer-marian-fr.json",
|
||||
};
|
||||
Api::new()?
|
||||
.model("lmz/candle-marian".to_string())
|
||||
.get(name)?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
};
|
||||
|
||||
let tokenizer_dec = {
|
||||
let tokenizer = match args.tokenizer_dec {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Base => "tokenizer-marian-base-en.json",
|
||||
Which::Big => "tokenizer-marian-en.json",
|
||||
};
|
||||
Api::new()?
|
||||
.model("lmz/candle-marian".to_string())
|
||||
.get(name)?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
};
|
||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = {
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => match args.which {
|
||||
Which::Base => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/4".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
Which::Big => Api::new()?
|
||||
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||
.get("model.safetensors")?,
|
||||
},
|
||||
};
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||
};
|
||||
let mut model = marian::MTModel::new(&config, vb)?;
|
||||
|
||||
let mut logits_processor =
|
||||
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
||||
|
||||
let encoder_xs = {
|
||||
let mut tokens = tokenizer
|
||||
.encode(args.text, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
tokens.push(config.eos_token_id);
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
model.encoder().forward(&tokens, 0)?
|
||||
};
|
||||
|
||||
let mut token_ids = vec![config.decoder_start_token_id];
|
||||
for index in 0..1000 {
|
||||
let context_size = if index >= 1 { 1 } else { token_ids.len() };
|
||||
let start_pos = token_ids.len().saturating_sub(context_size);
|
||||
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
|
||||
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
token_ids.push(token);
|
||||
if let Some(t) = tokenizer_dec.next_token(token)? {
|
||||
use std::io::Write;
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
if token == config.eos_token_id || token == config.forced_eos_token_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
@ -155,8 +155,8 @@ struct Args {
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
#[arg(long, default_value = "lmz/candle-mistral")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
@ -207,18 +207,8 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => {
|
||||
if args.quantized {
|
||||
"lmz/candle-mistral".to_string()
|
||||
} else {
|
||||
"mistralai/Mistral-7B-v0.1".to_string()
|
||||
}
|
||||
}
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
@ -235,7 +225,10 @@ fn main() -> Result<()> {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
                vec![
                    repo.get("pytorch_model-00001-of-00002.safetensors")?,
                    repo.get("pytorch_model-00002-of-00002.safetensors")?,
                ]
            }
        }
    };
@ -244,14 +237,13 @@ fn main() -> Result<()> {

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);
    let device = candle_examples::device(args.cpu)?;
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
        let model = QMistral::new(&config, vb)?;
        (Model::Quantized(model), device)
        (Model::Quantized(model), Device::Cpu)
    } else {
        let device = candle_examples::device(args.cpu)?;
        let dtype = if device.is_cuda() {
            DType::BF16
        } else {
@ -1,25 +0,0 @@
# candle-mixtral: 8x7b LLM using a sparse mixture of experts.

Mixtral-8x7B-v0.1 is a pretrained generative sparse mixture-of-experts LLM with about 46.7 billion total parameters, of which roughly 13 billion are active per token.

- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub.

## Running the example

```bash
$ cargo run --example mixtral --release -- --prompt "def print_prime(n): "
def print_prime(n): # n is the number of prime numbers to be printed
    i = 2
    count = 0
    while (count < n):
        if (isPrime(i)):
            print(i)
            count += 1
        i += 1

def isPrime(n):
    for x in range(2, int(n**0.5)+1):
        if (n % x == 0):
...
```
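For orientation, the example's `main.rs` (shown in full below) drives generation through a `TextGeneration` helper; stripped to its essentials, and using only APIs that appear in that file, the load-and-sample path looks roughly like the sketch below. The fixed seed, the 100-token budget, and the absence of a repeat penalty are simplifications for illustration, not part of the example itself.

```rust
// Minimal sketch of the load-and-sample path of the mixtral example.
// The fixed seed and 100-token budget are illustrative simplifications.
use anyhow::{Error as E, Result};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::mixtral::{Config, Model};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

fn generate(prompt: &str) -> Result<()> {
    use std::io::Write;
    // Fetch the tokenizer and the sharded safetensors weights from the Hub.
    let repo = Api::new()?.repo(Repo::with_revision(
        "mistralai/Mixtral-8x7B-v0.1".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    let tokenizer = Tokenizer::from_file(repo.get("tokenizer.json")?).map_err(E::msg)?;
    let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;

    // Build the model, in bf16 on CUDA and f32 otherwise.
    let device = candle_examples::device(false)?;
    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let mut model = Model::new(&Config::v0_1_8x7b(false), vb)?;

    // Encode the prompt, then sample one token at a time, streaming the output.
    let mut tokens = tokenizer.encode(prompt, true).map_err(E::msg)?.get_ids().to_vec();
    let mut stream = TokenOutputStream::new(tokenizer);
    let eos_token = stream.get_token("</s>");
    let mut logits_processor = LogitsProcessor::new(299792458, None, None);
    for index in 0..100 {
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let start_pos = tokens.len().saturating_sub(context_size);
        let input = Tensor::new(&tokens[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, start_pos)?;
        let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        if let Some(t) = stream.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if Some(next_token) == eos_token {
            break;
        }
    }
    if let Some(rest) = stream.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
```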
@ -1,241 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::mixtral::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::v0_1_8x7b(args.use_flash_attn);
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;

use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@ -95,7 +95,7 @@ impl ConvNet {
            .flatten_from(1)?
            .apply(&self.fc1)?
            .relu()?;
        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
        self.dropout.forward(&xs, train)?.apply(&self.fc2)
    }
}
@ -1,22 +0,0 @@
# candle-mobileone

[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).

This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet            : 2.58%
unicycle, monocycle     : 1.70%
alp                     : 0.21%

```
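As a condensed sketch of what the example's `main.rs` (shown below) does end to end, and using only calls that appear in that file, the classification path looks roughly like this; the hard-coded `s2` variant and image path mirror the command above, and the device is pinned to CPU here purely to keep the sketch self-contained.

```rust
// Condensed sketch of the mobileone example: load an image, build the s2
// variant, run a forward pass and print the top-5 ImageNet classes.
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;

fn classify() -> anyhow::Result<()> {
    let device = candle_examples::device(true)?; // true = run on CPU
    // 224x224 image tensor with the usual ImageNet normalization.
    let image = candle_examples::imagenet::load_image224(
        "candle-examples/examples/yolo-v8/assets/bike.jpg",
    )?;

    // Fetch the timm checkpoint and build the classifier with 1000 output classes.
    let api = hf_hub::api::sync::Api::new()?;
    let model_file = api
        .model("timm/mobileone_s2.apple_in1k".to_string())
        .get("model.safetensors")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = mobileone::mobileone(&mobileone::Config::s2(), 1000, vb)?;

    // Softmax over the class axis, then report the five most likely classes.
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(class_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[class_idx],
            100. * pr
        );
    }
    Ok(())
}
```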
@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    S0,
    S1,
    S2,
    S3,
    S4,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::S0 => "s0",
            Self::S1 => "s1",
            Self::S2 => "s2",
            Self::S3 => "s3",
            Self::S4 => "s4",
        };
        format!("timm/mobileone_{}.apple_in1k", name)
    }

    fn config(&self) -> mobileone::Config {
        match self {
            Self::S0 => mobileone::Config::s0(),
            Self::S1 => mobileone::Config::s1(),
            Self::S2 => mobileone::Config::s2(),
            Self::S3 => mobileone::Config::s3(),
            Self::S4 => mobileone::Config::s4(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::S0)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@ -8,7 +8,6 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
    WeightNorm,
    TimeGroupNorm,
    None,
}
@ -269,7 +268,6 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
    causal: bool,
    conv: Conv1d,
    norm: Option<candle_nn::GroupNorm>,
}

impl EncodecConv1d {
@ -294,7 +292,7 @@ impl EncodecConv1d {
                },
                vb.pp("conv"),
            )?,
            NormType::None | NormType::TimeGroupNorm => conv1d(
            NormType::None => conv1d(
                in_c,
                out_c,
                kernel_size,
@ -307,17 +305,9 @@ impl EncodecConv1d {
                vb.pp("conv"),
            )?,
        };
        let norm = match cfg.norm_type {
            NormType::None | NormType::WeightNorm => None,
            NormType::TimeGroupNorm => {
                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
                Some(gn)
            }
        };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
            norm,
        })
    }
}
@ -326,10 +316,8 @@ impl Module for EncodecConv1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // TODO: padding, depending on causal.
        let xs = self.conv.forward(xs)?;
        match &self.norm {
            None => Ok(xs),
            Some(norm) => xs.apply(norm),
        }
        // If we add support for NormType "time_group_norm", we should add some normalization here.
        Ok(xs)
    }
}
@ -321,7 +321,7 @@ impl MusicgenDecoder {
        let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
        let mut xs = inputs_embeds.broadcast_add(&positions)?;
        let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
        for decoder_layer in self.layers.iter_mut() {
        for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
            xs = decoder_layer.forward(&xs, &attention_mask, None)?;
        }
        let xs = self.layer_norm.forward(&xs)?;
Some files were not shown because too many files have changed in this diff.