mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
11 Commits
llama-v3-m
...
metal5
Author | SHA1 | Date | |
---|---|---|---|
67d93b4f42 | |||
c35d7d50db | |||
9694671bbf | |||
3dbf65ef20 | |||
b2db5adf82 | |||
9ef040338d | |||
3aefc709c7 | |||
c8c603ce96 | |||
61ad8d91cc | |||
2cd1e59c9e | |||
9c4b4f0da0 |
74
.github/workflows/ci_cuda.yaml
vendored
74
.github/workflows/ci_cuda.yaml
vendored
@ -5,15 +5,49 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
start-runner:
|
||||||
|
name: Start self-hosted EC2 runner
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
# Don't run on forks, they won't have access to secrets anyway.
|
||||||
|
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||||
|
env:
|
||||||
|
AWS_REGION: us-east-1
|
||||||
|
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||||
|
EC2_INSTANCE_TYPE: g5.xlarge
|
||||||
|
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
|
||||||
|
EC2_SECURITY_GROUP: sg-030175c435ac141d6
|
||||||
|
outputs:
|
||||||
|
label: ${{ steps.start-ec2-runner.outputs.label }}
|
||||||
|
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
|
||||||
|
steps:
|
||||||
|
- name: Configure AWS credentials
|
||||||
|
uses: aws-actions/configure-aws-credentials@v1
|
||||||
|
with:
|
||||||
|
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
aws-region: ${{ env.AWS_REGION }}
|
||||||
|
- name: Start EC2 runner
|
||||||
|
id: start-ec2-runner
|
||||||
|
uses: philschmid/philschmid-ec2-github-runner@main
|
||||||
|
with:
|
||||||
|
mode: start
|
||||||
|
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||||
|
ec2-image-id: ${{ env.EC2_AMI_ID }}
|
||||||
|
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
|
||||||
|
subnet-id: ${{ env.EC2_SUBNET_ID }}
|
||||||
|
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
|
||||||
|
aws-resource-tags: > # optional, requires additional permissions
|
||||||
|
[
|
||||||
|
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
|
||||||
|
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
|
||||||
|
]
|
||||||
|
|
||||||
test-cuda:
|
test-cuda:
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
needs: start-runner # required to start the main job when the runner is ready
|
||||||
container:
|
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||||
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
|
||||||
options: --gpus 0
|
|
||||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
packages: write
|
packages: write
|
||||||
@ -24,10 +58,32 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
- name: Install dependencies
|
|
||||||
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
|
|
||||||
- name: Install Rust Stable
|
- name: Install Rust Stable
|
||||||
uses: actions-rust-lang/setup-rust-toolchain@v1
|
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||||
- uses: Swatinem/rust-cache@v2
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
|
||||||
- name: Test (cuda)
|
- name: Test (cuda)
|
||||||
run: cargo test --features cuda
|
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||||
|
stop-runner:
|
||||||
|
name: Stop self-hosted EC2 runner
|
||||||
|
needs:
|
||||||
|
- start-runner
|
||||||
|
- test-cuda
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
AWS_REGION: us-east-1
|
||||||
|
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
|
||||||
|
steps:
|
||||||
|
- name: Configure AWS credentials
|
||||||
|
uses: aws-actions/configure-aws-credentials@v1
|
||||||
|
with:
|
||||||
|
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||||
|
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||||
|
aws-region: ${{ env.AWS_REGION }}
|
||||||
|
- name: Stop EC2 runner
|
||||||
|
uses: philschmid/philschmid-ec2-github-runner@main
|
||||||
|
with:
|
||||||
|
mode: stop
|
||||||
|
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||||
|
label: ${{ needs.start-runner.outputs.label }}
|
||||||
|
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
||||||
|
30
Cargo.toml
30
Cargo.toml
@ -9,7 +9,6 @@ members = [
|
|||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
"candle-wasm-examples/*",
|
"candle-wasm-examples/*",
|
||||||
"candle-wasm-tests",
|
"candle-wasm-tests",
|
||||||
"tensor-tools",
|
|
||||||
]
|
]
|
||||||
exclude = [
|
exclude = [
|
||||||
"candle-flash-attn",
|
"candle-flash-attn",
|
||||||
@ -20,7 +19,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.5.0"
|
version = "0.3.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -29,38 +28,37 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
ab_glyph = "0.2.23"
|
|
||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
|
candle = { path = "./candle-core", package = "candle-core" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
|
candle-datasets = { path = "./candle-datasets" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
|
candle-flash-attn = { path = "./candle-flash-attn" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
|
candle-kernels = { path = "./candle-kernels" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
|
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.5.0" }
|
candle-nn = { path = "./candle-nn" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
|
candle-onnx = { path = "./candle-onnx" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
|
candle-transformers = { path = "./candle-transformers" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||||
fancy-regex = "0.13.0"
|
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.3.0"
|
hf-hub = "0.3.0"
|
||||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||||
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
|
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||||
imageproc = { version = "0.24.0", default-features = false }
|
imageproc = { version = "0.23.0", default-features = false }
|
||||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||||
libc = { version = "0.2.147" }
|
libc = { version = "0.2.147" }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||||
num_cpus = "1.15.0"
|
num_cpus = "1.15.0"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
parquet = { version = "51.0.0" }
|
parquet = { version = "45.0.0" }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
rand_distr = "0.4.3"
|
rand_distr = "0.4.3"
|
||||||
rayon = "1.7.0"
|
rayon = "1.7.0"
|
||||||
safetensors = "0.4.1"
|
rusttype = { version = "0.9", default-features = false }
|
||||||
|
safetensors = "0.3.1"
|
||||||
serde = { version = "1.0.171", features = ["derive"] }
|
serde = { version = "1.0.171", features = ["derive"] }
|
||||||
serde_plain = "1.0.2"
|
serde_plain = "1.0.2"
|
||||||
serde_json = "1.0.99"
|
serde_json = "1.0.99"
|
||||||
|
51
README.md
51
README.md
@ -63,25 +63,17 @@ We also provide a some command line based examples using state of the art models
|
|||||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
|
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
|
||||||
the SOLAR-10.7B variant.
|
the SOLAR-10.7B variant.
|
||||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||||
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
|
|
||||||
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
|
|
||||||
Griffin based models from Google that mix attention with a RNN like state.
|
|
||||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||||
pre-trained on 1T tokens of English and code datasets. Also supports
|
pre-trained on 1T tokens of English and code datasets.
|
||||||
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
|
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
||||||
- [Mamba](./candle-examples/examples/mamba/): an inference only
|
|
||||||
implementation of the Mamba state space model.
|
implementation of the Mamba state space model.
|
||||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||||
better performance than all publicly available 13b models as of 2023-09-28.
|
better performance than all publicly available 13b models as of 2023-09-28.
|
||||||
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
|
||||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
||||||
much faster inference.
|
much faster inference.
|
||||||
- [StarCoder](./candle-examples/examples/bigcode/) and
|
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||||
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
|
|
||||||
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
|
|
||||||
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
|
|
||||||
performance.
|
|
||||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||||
(English/Chinese) general LLMs with 6b and 34b parameters.
|
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||||
@ -111,12 +103,7 @@ We also provide a some command line based examples using state of the art models
|
|||||||
|
|
||||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
||||||
|
|
||||||
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
|
|
||||||
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
||||||
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
|
|
||||||
model using residual vector quantization.
|
|
||||||
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
|
|
||||||
text-to-speech.
|
|
||||||
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
||||||
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
||||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||||
@ -124,16 +111,11 @@ We also provide a some command line based examples using state of the art models
|
|||||||
evaluation, segmentation).
|
evaluation, segmentation).
|
||||||
- [VGG](./candle-examples/examples/vgg/),
|
- [VGG](./candle-examples/examples/vgg/),
|
||||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||||
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
generate captions for an image.
|
generate captions for an image.
|
||||||
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
|
|
||||||
model.
|
|
||||||
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
|
|
||||||
dedicated submodels for hand-writing and printed recognition.
|
|
||||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||||
model, generates the translated text from the input text.
|
model, generates the translated text from the input text.
|
||||||
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
|
|
||||||
that can answer real-world questions about images.
|
|
||||||
|
|
||||||
Run them using commands like:
|
Run them using commands like:
|
||||||
```
|
```
|
||||||
@ -177,11 +159,9 @@ And then head over to
|
|||||||
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
|
||||||
serving local LLMs including an OpenAI compatible API server.
|
serving local LLMs including an OpenAI compatible API server.
|
||||||
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
|
||||||
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
|
|
||||||
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
|
||||||
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
|
||||||
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
|
||||||
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
|
|
||||||
|
|
||||||
If you have an addition to this list, please submit a pull request.
|
If you have an addition to this list, please submit a pull request.
|
||||||
|
|
||||||
@ -202,18 +182,15 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Language Models.
|
- Language Models.
|
||||||
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
|
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
|
||||||
- Falcon.
|
- Falcon.
|
||||||
- StarCoder, StarCoder2.
|
- StarCoder.
|
||||||
- Phi 1, 1.5, and 2.
|
- Phi 1, 1.5, and 2.
|
||||||
- Mamba, Minimal Mamba
|
- Minimal Mamba
|
||||||
- Gemma 2b and 7b.
|
|
||||||
- Mistral 7b v0.1.
|
- Mistral 7b v0.1.
|
||||||
- Mixtral 8x7b v0.1.
|
- Mixtral 8x7b v0.1.
|
||||||
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
- StableLM-3B-4E1T.
|
||||||
- Replit-code-v1.5-3B.
|
- Replit-code-v1.5-3B.
|
||||||
- Bert.
|
- Bert.
|
||||||
- Yi-6B and Yi-34B.
|
- Yi-6B and Yi-34B.
|
||||||
- Qwen1.5, Qwen1.5 MoE.
|
|
||||||
- RWKV v5 and v6.
|
|
||||||
- Quantized LLMs.
|
- Quantized LLMs.
|
||||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||||
- Mistral 7b, and 7b instruct.
|
- Mistral 7b, and 7b instruct.
|
||||||
@ -223,22 +200,16 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Text to text.
|
- Text to text.
|
||||||
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
|
||||||
- Marian MT (Machine Translation).
|
- Marian MT (Machine Translation).
|
||||||
|
- Whisper (multi-lingual support).
|
||||||
- Text to image.
|
- Text to image.
|
||||||
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
- Stable Diffusion v1.5, v2.1, XL v1.0.
|
||||||
- Wurstchen v2.
|
- Wurstchen v2.
|
||||||
- Image to text.
|
- Image to text.
|
||||||
- BLIP.
|
- BLIP.
|
||||||
- TrOCR.
|
|
||||||
- Audio.
|
|
||||||
- Whisper, multi-lingual speech-to-text.
|
|
||||||
- EnCodec, audio compression model.
|
|
||||||
- MetaVoice-1B, text-to-speech model.
|
|
||||||
- Computer Vision Models.
|
- Computer Vision Models.
|
||||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
||||||
ConvNeXTv2, MobileOne, EfficientVit (MSRA).
|
|
||||||
- yolo-v3, yolo-v8.
|
- yolo-v3, yolo-v8.
|
||||||
- Segment-Anything Model (SAM).
|
- Segment-Anything Model (SAM).
|
||||||
- SegFormer.
|
|
||||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||||
- Serverless (on CPU), small and fast deployments.
|
- Serverless (on CPU), small and fast deployments.
|
||||||
- Quantization support using the llama.cpp quantized types.
|
- Quantization support using the llama.cpp quantized types.
|
||||||
@ -375,9 +346,9 @@ git submodule update --init
|
|||||||
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
||||||
```
|
```
|
||||||
|
|
||||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
|
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
|
||||||
```
|
```
|
||||||
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Linking error on windows when running rustdoc or mdbook tests
|
#### Linking error on windows when running rustdoc or mdbook tests
|
||||||
|
@ -2,10 +2,7 @@ mod benchmarks;
|
|||||||
|
|
||||||
use criterion::criterion_main;
|
use criterion::criterion_main;
|
||||||
criterion_main!(
|
criterion_main!(
|
||||||
benchmarks::affine::benches,
|
|
||||||
benchmarks::matmul::benches,
|
benchmarks::matmul::benches,
|
||||||
benchmarks::random::benches,
|
benchmarks::affine::benches,
|
||||||
benchmarks::where_cond::benches,
|
benchmarks::where_cond::benches
|
||||||
benchmarks::conv_transpose2d::benches,
|
|
||||||
benchmarks::qmatmul::benches,
|
|
||||||
);
|
);
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
|
||||||
use candle_core::{DType, Device, Tensor};
|
|
||||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
fn run(
|
|
||||||
x: &Tensor,
|
|
||||||
k: &Tensor,
|
|
||||||
padding: usize,
|
|
||||||
output_padding: usize,
|
|
||||||
stride: usize,
|
|
||||||
dilation: usize,
|
|
||||||
) {
|
|
||||||
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
|
||||||
let t = Tensor::arange(0.0f32, 10000.0, device)
|
|
||||||
.unwrap()
|
|
||||||
.reshape((1, 4, 50, 50))
|
|
||||||
.unwrap()
|
|
||||||
.to_dtype(dtype)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let kernel = Tensor::arange(0.0f32, 100.0, device)
|
|
||||||
.unwrap()
|
|
||||||
.reshape((4, 1, 5, 5))
|
|
||||||
.unwrap()
|
|
||||||
.to_dtype(dtype)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name(name));
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |b| {
|
|
||||||
b.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
let handler = BenchDeviceHandler::new().unwrap();
|
|
||||||
for device in handler.devices {
|
|
||||||
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
|
|
||||||
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
|
|
||||||
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -1,8 +1,5 @@
|
|||||||
pub(crate) mod affine;
|
pub(crate) mod affine;
|
||||||
pub(crate) mod conv_transpose2d;
|
|
||||||
pub(crate) mod matmul;
|
pub(crate) mod matmul;
|
||||||
pub(crate) mod qmatmul;
|
|
||||||
pub(crate) mod random;
|
|
||||||
pub(crate) mod where_cond;
|
pub(crate) mod where_cond;
|
||||||
|
|
||||||
use candle_core::{Device, Result};
|
use candle_core::{Device, Result};
|
||||||
|
@ -1,72 +0,0 @@
|
|||||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
|
||||||
use candle_core::{
|
|
||||||
quantized::{self, GgmlDType, QMatMul},
|
|
||||||
Device, Module, Tensor,
|
|
||||||
};
|
|
||||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
fn run(matmul: &QMatMul, x: &Tensor) {
|
|
||||||
matmul.forward(&x).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
|
|
||||||
let b = 1;
|
|
||||||
let m = 1;
|
|
||||||
let n = 1024;
|
|
||||||
let k = 1024;
|
|
||||||
|
|
||||||
let lhs = (0..(m * k))
|
|
||||||
.map(|v| v as f32 / (m * k) as f32)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let rhs = (0..(k * n))
|
|
||||||
.map(|v| v as f32 / (n * k) as f32)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
|
|
||||||
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
|
|
||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
|
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
|
|
||||||
|
|
||||||
let flops = b * m * n * k;
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
|
|
||||||
group.sample_size(200);
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |b| {
|
|
||||||
b.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
run(black_box(&matmul), black_box(&lhs));
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
let handler = BenchDeviceHandler::new().unwrap();
|
|
||||||
for device in handler.devices {
|
|
||||||
for dtype in vec![
|
|
||||||
GgmlDType::F32,
|
|
||||||
GgmlDType::F16,
|
|
||||||
GgmlDType::Q4_0,
|
|
||||||
GgmlDType::Q4_1,
|
|
||||||
GgmlDType::Q5_0,
|
|
||||||
GgmlDType::Q5_1,
|
|
||||||
GgmlDType::Q8_0,
|
|
||||||
GgmlDType::Q2K,
|
|
||||||
GgmlDType::Q3K,
|
|
||||||
GgmlDType::Q4K,
|
|
||||||
GgmlDType::Q5K,
|
|
||||||
GgmlDType::Q6K,
|
|
||||||
] {
|
|
||||||
run_bench(c, &device, dtype);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -1,63 +0,0 @@
|
|||||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
|
||||||
use candle_core::{DType, Device, Tensor};
|
|
||||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
fn rand_uniform(a: &Tensor) {
|
|
||||||
a.rand_like(-1.0, 123.0).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_normal(a: &Tensor) {
|
|
||||||
a.randn_like(100.0, 15.0).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
|
||||||
let b = 1;
|
|
||||||
|
|
||||||
let rows = 2048;
|
|
||||||
let cols = 2048;
|
|
||||||
|
|
||||||
let dtype = DType::F32;
|
|
||||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
|
||||||
|
|
||||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |benches| {
|
|
||||||
benches.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
rand_uniform(black_box(&tensor));
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
|
|
||||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
|
||||||
|
|
||||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
|
||||||
group.throughput(Throughput::Bytes(flops as u64));
|
|
||||||
group.bench_function("iter", move |benches| {
|
|
||||||
benches.iter_custom(|iters| {
|
|
||||||
let start = Instant::now();
|
|
||||||
for _i in 0..iters {
|
|
||||||
rand_normal(black_box(&tensor));
|
|
||||||
}
|
|
||||||
device.sync().unwrap();
|
|
||||||
start.elapsed()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
group.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
let handler = BenchDeviceHandler::new().unwrap();
|
|
||||||
for device in handler.devices {
|
|
||||||
run_random_bench(c, &device);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
@ -5,32 +5,25 @@ extern crate accelerate_src;
|
|||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Module, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
use candle_core::quantized::{QMatMul, QTensor};
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
|
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
|
||||||
let q_cpu = q.to_device(&Device::Cpu)?;
|
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
|
||||||
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
|
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||||
let q = QMatMul::from_qtensor(q)?;
|
println!("{out_t}");
|
||||||
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
|
let in_t = in_t.to_device(&Device::Cpu)?;
|
||||||
let res_q_cuda = q.forward(&x)?;
|
let k_t = k_t.to_device(&Device::Cpu)?;
|
||||||
println!("{res_q_cuda}");
|
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
|
||||||
|
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
|
||||||
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
|
.sqr()?
|
||||||
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
|
.sum_all()?;
|
||||||
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
|
|
||||||
let x_cpu = x.to_device(&Device::Cpu)?;
|
|
||||||
let res_q_cpu = q_cpu.forward(&x_cpu)?;
|
|
||||||
println!("{res_q_cpu}");
|
|
||||||
|
|
||||||
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
|
|
||||||
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
|
|
||||||
.abs()?
|
|
||||||
.flatten_all()?
|
|
||||||
.max(0)?;
|
|
||||||
println!("{diff}");
|
println!("{diff}");
|
||||||
|
|
||||||
|
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
|
||||||
|
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
|
||||||
|
let res = t.conv2d(&w, 1, 1, 1, 1)?;
|
||||||
|
println!("{res:?}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use candle::quantized::{gguf_file, GgmlDType, QTensor};
|
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||||
use candle::{Device, Result};
|
use candle_core::{Device, Result};
|
||||||
use clap::{Parser, Subcommand, ValueEnum};
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
@ -117,24 +117,6 @@ enum Command {
|
|||||||
verbose: bool,
|
verbose: bool,
|
||||||
},
|
},
|
||||||
|
|
||||||
Print {
|
|
||||||
file: std::path::PathBuf,
|
|
||||||
|
|
||||||
names: Vec<String>,
|
|
||||||
|
|
||||||
/// The file format to use, if unspecified infer from the file extension.
|
|
||||||
#[arg(long, value_enum)]
|
|
||||||
format: Option<Format>,
|
|
||||||
|
|
||||||
/// Print the whole content of each tensor.
|
|
||||||
#[arg(long)]
|
|
||||||
full: bool,
|
|
||||||
|
|
||||||
/// Line width for printing the tensors.
|
|
||||||
#[arg(long)]
|
|
||||||
line_width: Option<usize>,
|
|
||||||
},
|
|
||||||
|
|
||||||
Quantize {
|
Quantize {
|
||||||
/// The input file(s), in safetensors format.
|
/// The input file(s), in safetensors format.
|
||||||
in_file: Vec<std::path::PathBuf>,
|
in_file: Vec<std::path::PathBuf>,
|
||||||
@ -168,105 +150,6 @@ struct Args {
|
|||||||
command: Command,
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_print(
|
|
||||||
file: &std::path::PathBuf,
|
|
||||||
names: Vec<String>,
|
|
||||||
format: Option<Format>,
|
|
||||||
full: bool,
|
|
||||||
line_width: Option<usize>,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<()> {
|
|
||||||
if full {
|
|
||||||
candle::display::set_print_options_full();
|
|
||||||
}
|
|
||||||
if let Some(line_width) = line_width {
|
|
||||||
candle::display::set_line_width(line_width)
|
|
||||||
}
|
|
||||||
let format = match format {
|
|
||||||
Some(format) => format,
|
|
||||||
None => match Format::infer(file) {
|
|
||||||
Some(format) => format,
|
|
||||||
None => {
|
|
||||||
println!(
|
|
||||||
"{file:?}: cannot infer format from file extension, use the --format flag"
|
|
||||||
);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
match format {
|
|
||||||
Format::Npz => {
|
|
||||||
let tensors = candle::npy::NpzTensors::new(file)?;
|
|
||||||
for name in names.iter() {
|
|
||||||
println!("==== {name} ====");
|
|
||||||
match tensors.get(name)? {
|
|
||||||
Some(tensor) => println!("{tensor}"),
|
|
||||||
None => println!("not found"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Format::Safetensors => {
|
|
||||||
use candle::safetensors::Load;
|
|
||||||
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
|
||||||
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
|
|
||||||
for name in names.iter() {
|
|
||||||
println!("==== {name} ====");
|
|
||||||
match tensors.get(name) {
|
|
||||||
Some(tensor_view) => {
|
|
||||||
let tensor = tensor_view.load(device)?;
|
|
||||||
println!("{tensor}")
|
|
||||||
}
|
|
||||||
None => println!("not found"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Format::Pth => {
|
|
||||||
let pth_file = candle::pickle::PthTensors::new(file, None)?;
|
|
||||||
for name in names.iter() {
|
|
||||||
println!("==== {name} ====");
|
|
||||||
match pth_file.get(name)? {
|
|
||||||
Some(tensor) => {
|
|
||||||
println!("{tensor}")
|
|
||||||
}
|
|
||||||
None => println!("not found"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Format::Pickle => {
|
|
||||||
candle::bail!("pickle format is not supported for print")
|
|
||||||
}
|
|
||||||
Format::Ggml => {
|
|
||||||
let mut file = std::fs::File::open(file)?;
|
|
||||||
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
|
||||||
for name in names.iter() {
|
|
||||||
println!("==== {name} ====");
|
|
||||||
match content.tensors.get(name) {
|
|
||||||
Some(tensor) => {
|
|
||||||
let tensor = tensor.dequantize(device)?;
|
|
||||||
println!("{tensor}")
|
|
||||||
}
|
|
||||||
None => println!("not found"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Format::Gguf => {
|
|
||||||
let mut file = std::fs::File::open(file)?;
|
|
||||||
let content = gguf_file::Content::read(&mut file)?;
|
|
||||||
for name in names.iter() {
|
|
||||||
println!("==== {name} ====");
|
|
||||||
match content.tensor(&mut file, name, device) {
|
|
||||||
Ok(tensor) => {
|
|
||||||
let tensor = tensor.dequantize(device)?;
|
|
||||||
println!("{tensor}")
|
|
||||||
}
|
|
||||||
Err(_) => println!("not found"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_ls(
|
fn run_ls(
|
||||||
file: &std::path::PathBuf,
|
file: &std::path::PathBuf,
|
||||||
format: Option<Format>,
|
format: Option<Format>,
|
||||||
@ -287,7 +170,7 @@ fn run_ls(
|
|||||||
};
|
};
|
||||||
match format {
|
match format {
|
||||||
Format::Npz => {
|
Format::Npz => {
|
||||||
let tensors = candle::npy::NpzTensors::new(file)?;
|
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||||
let mut names = tensors.names();
|
let mut names = tensors.names();
|
||||||
names.sort();
|
names.sort();
|
||||||
for name in names {
|
for name in names {
|
||||||
@ -299,12 +182,12 @@ fn run_ls(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Safetensors => {
|
Format::Safetensors => {
|
||||||
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
|
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
|
||||||
let mut tensors = tensors.tensors();
|
let mut tensors = tensors.tensors();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, view) in tensors.iter() {
|
for (name, view) in tensors.iter() {
|
||||||
let dtype = view.dtype();
|
let dtype = view.dtype();
|
||||||
let dtype = match candle::DType::try_from(dtype) {
|
let dtype = match candle_core::DType::try_from(dtype) {
|
||||||
Ok(dtype) => format!("{dtype:?}"),
|
Ok(dtype) => format!("{dtype:?}"),
|
||||||
Err(_) => format!("{dtype:?}"),
|
Err(_) => format!("{dtype:?}"),
|
||||||
};
|
};
|
||||||
@ -313,7 +196,7 @@ fn run_ls(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Pth => {
|
Format::Pth => {
|
||||||
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
|
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
|
||||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
for tensor_info in tensors.iter() {
|
for tensor_info in tensors.iter() {
|
||||||
println!(
|
println!(
|
||||||
@ -330,7 +213,7 @@ fn run_ls(
|
|||||||
Format::Pickle => {
|
Format::Pickle => {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let mut reader = std::io::BufReader::new(file);
|
let mut reader = std::io::BufReader::new(file);
|
||||||
let mut stack = candle::pickle::Stack::empty();
|
let mut stack = candle_core::pickle::Stack::empty();
|
||||||
stack.read_loop(&mut reader)?;
|
stack.read_loop(&mut reader)?;
|
||||||
for (i, obj) in stack.stack().iter().enumerate() {
|
for (i, obj) in stack.stack().iter().enumerate() {
|
||||||
println!("{i} {obj:?}");
|
println!("{i} {obj:?}");
|
||||||
@ -338,7 +221,7 @@ fn run_ls(
|
|||||||
}
|
}
|
||||||
Format::Ggml => {
|
Format::Ggml => {
|
||||||
let mut file = std::fs::File::open(file)?;
|
let mut file = std::fs::File::open(file)?;
|
||||||
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
|
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, qtensor) in tensors.iter() {
|
for (name, qtensor) in tensors.iter() {
|
||||||
@ -374,7 +257,7 @@ fn run_quantize_safetensors(
|
|||||||
let mut out_file = std::fs::File::create(out_file)?;
|
let mut out_file = std::fs::File::create(out_file)?;
|
||||||
let mut tensors = std::collections::HashMap::new();
|
let mut tensors = std::collections::HashMap::new();
|
||||||
for in_file in in_files.iter() {
|
for in_file in in_files.iter() {
|
||||||
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
|
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
|
||||||
tensors.extend(in_tensors)
|
tensors.extend(in_tensors)
|
||||||
}
|
}
|
||||||
println!("tensors: {}", tensors.len());
|
println!("tensors: {}", tensors.len());
|
||||||
@ -416,7 +299,7 @@ fn run_dequantize(
|
|||||||
let tensor = tensor.dequantize(device)?;
|
let tensor = tensor.dequantize(device)?;
|
||||||
tensors.insert(tensor_name.to_string(), tensor);
|
tensors.insert(tensor_name.to_string(), tensor);
|
||||||
}
|
}
|
||||||
candle::safetensors::save(&tensors, out_file)?;
|
candle_core::safetensors::save(&tensors, out_file)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -428,11 +311,11 @@ fn run_quantize(
|
|||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if in_files.is_empty() {
|
if in_files.is_empty() {
|
||||||
candle::bail!("no specified input files")
|
candle_core::bail!("no specified input files")
|
||||||
}
|
}
|
||||||
if let Some(extension) = out_file.extension() {
|
if let Some(extension) = out_file.extension() {
|
||||||
if extension == "safetensors" {
|
if extension == "safetensors" {
|
||||||
candle::bail!("the generated file cannot use the safetensors extension")
|
candle_core::bail!("the generated file cannot use the safetensors extension")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(extension) = in_files[0].extension() {
|
if let Some(extension) = in_files[0].extension() {
|
||||||
@ -442,7 +325,7 @@ fn run_quantize(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if in_files.len() != 1 {
|
if in_files.len() != 1 {
|
||||||
candle::bail!("only a single in-file can be used when quantizing gguf files")
|
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open the out file early so as to fail directly on missing directories etc.
|
// Open the out file early so as to fail directly on missing directories etc.
|
||||||
@ -494,13 +377,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
run_ls(file, format.clone(), verbose, &device)?
|
run_ls(file, format.clone(), verbose, &device)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Command::Print {
|
|
||||||
file,
|
|
||||||
names,
|
|
||||||
format,
|
|
||||||
full,
|
|
||||||
line_width,
|
|
||||||
} => run_print(&file, names, format, full, line_width, &device)?,
|
|
||||||
Command::Quantize {
|
Command::Quantize {
|
||||||
in_file,
|
in_file,
|
||||||
out_file,
|
out_file,
|
@ -380,16 +380,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
|||||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
|
||||||
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
|
||||||
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
@ -412,28 +402,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = -v
|
|
||||||
}
|
|
||||||
vs_exp_inplace(ys);
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = v / (1.0 + *y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = -v
|
|
||||||
}
|
|
||||||
vd_exp_inplace(ys);
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = v / (1.0 + *y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! binary_op {
|
macro_rules! binary_op {
|
||||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -98,19 +98,6 @@ pub trait BackendStorage: Sized {
|
|||||||
) -> Result<Self>;
|
) -> Result<Self>;
|
||||||
|
|
||||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
|
|
||||||
fn copy2d(
|
|
||||||
&self,
|
|
||||||
_: &mut Self,
|
|
||||||
_d1: usize,
|
|
||||||
_d2: usize,
|
|
||||||
_src_stride1: usize,
|
|
||||||
_dst_stride1: usize,
|
|
||||||
_src_offset: usize,
|
|
||||||
_dst_offset: usize,
|
|
||||||
) -> Result<()>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||||
@ -127,22 +114,11 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
|||||||
|
|
||||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
/// This function is unsafe as it doesn't initialize the underlying data store.
|
|
||||||
/// The caller should ensure that the data is properly initialized as early as possible
|
|
||||||
/// after this call.
|
|
||||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn set_seed(&self, _: u64) -> Result<()>;
|
fn set_seed(&self, _: u64) -> Result<()>;
|
||||||
|
|
||||||
/// Synchronize should block until all the operations on the device are completed.
|
|
||||||
fn synchronize(&self) -> Result<()>;
|
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
/// Methods for backpropagation of gradients.
|
|
||||||
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
|
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
|
||||||
use crate::{Error, Result, Tensor, TensorId};
|
use crate::{Error, Result, Tensor, TensorId};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@ -112,10 +111,9 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
Op::Unary(_node, UnaryOp::Ceil)
|
Op::Unary(_node, UnaryOp::Ceil)
|
||||||
| Op::Unary(_node, UnaryOp::Floor)
|
| Op::Unary(_node, UnaryOp::Floor)
|
||||||
| Op::Unary(_node, UnaryOp::Round)
|
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||||
| Op::Unary(_node, UnaryOp::Sign) => nodes,
|
|
||||||
Op::Reshape(node)
|
Op::Reshape(node)
|
||||||
| Op::UpsampleNearest1D { arg: node, .. }
|
| Op::UpsampleNearest1D(node)
|
||||||
| Op::UpsampleNearest2D { arg: node, .. }
|
| Op::UpsampleNearest2D { arg: node, .. }
|
||||||
| Op::AvgPool2D { arg: node, .. }
|
| Op::AvgPool2D { arg: node, .. }
|
||||||
| Op::MaxPool2D { arg: node, .. }
|
| Op::MaxPool2D { arg: node, .. }
|
||||||
@ -177,7 +175,7 @@ impl Tensor {
|
|||||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||||
// derivatives but these are out of scope at the moment.
|
// derivatives but these are out of scope at the moment.
|
||||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||||
let grad = if do_not_detach { grad } else { grad.detach() };
|
let grad = if do_not_detach { grad } else { grad.detach()? };
|
||||||
if let Some(op) = node.op() {
|
if let Some(op) = node.op() {
|
||||||
match op {
|
match op {
|
||||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||||
@ -252,7 +250,6 @@ impl Tensor {
|
|||||||
out_padding,
|
out_padding,
|
||||||
*stride,
|
*stride,
|
||||||
*dilation,
|
*dilation,
|
||||||
/* groups */ 1,
|
|
||||||
)?;
|
)?;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
@ -312,32 +309,9 @@ impl Tensor {
|
|||||||
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
|
||||||
op: "conv-transpose1d",
|
op: "conv-transpose1d",
|
||||||
})?,
|
})?,
|
||||||
Op::ConvTranspose2D {
|
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||||
arg,
|
op: "conv-transpose2d",
|
||||||
kernel,
|
})?,
|
||||||
padding,
|
|
||||||
stride,
|
|
||||||
dilation,
|
|
||||||
output_padding: _output_padding,
|
|
||||||
} => {
|
|
||||||
let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
|
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
|
||||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
|
||||||
|
|
||||||
let grad_kernel = grad
|
|
||||||
.transpose(0, 1)?
|
|
||||||
.conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
|
|
||||||
.transpose(0, 1)?;
|
|
||||||
let sum_grad = grads.or_insert(kernel)?;
|
|
||||||
let (_, _, k0, k1) = kernel.dims4()?;
|
|
||||||
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
|
|
||||||
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
|
|
||||||
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
|
|
||||||
} else {
|
|
||||||
grad_kernel
|
|
||||||
};
|
|
||||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
|
||||||
}
|
|
||||||
Op::AvgPool2D {
|
Op::AvgPool2D {
|
||||||
arg,
|
arg,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
@ -373,18 +347,9 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
}
|
}
|
||||||
Op::UpsampleNearest1D { arg, target_size } => {
|
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
||||||
let (_n, c, size) = arg.dims3()?;
|
op: "upsample-nearest1d",
|
||||||
if target_size % size != 0 {
|
})?,
|
||||||
crate::bail!("backward not supported for non integer upscaling factors")
|
|
||||||
}
|
|
||||||
let scale = target_size / size;
|
|
||||||
|
|
||||||
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
|
|
||||||
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
|
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
|
||||||
*sum_grad = conv_sum;
|
|
||||||
}
|
|
||||||
Op::UpsampleNearest2D {
|
Op::UpsampleNearest2D {
|
||||||
arg,
|
arg,
|
||||||
target_h,
|
target_h,
|
||||||
@ -489,6 +454,7 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad)?;
|
*sum_grad = sum_grad.add(&grad)?;
|
||||||
}
|
}
|
||||||
|
Op::Cmp(_args, _) => {}
|
||||||
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
|
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
|
||||||
let node = broadcast_back(arg, node, reduced_dims)?;
|
let node = broadcast_back(arg, node, reduced_dims)?;
|
||||||
let grad = broadcast_back(arg, &grad, reduced_dims)?;
|
let grad = broadcast_back(arg, &grad, reduced_dims)?;
|
||||||
@ -578,18 +544,20 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
Op::Unary(_, UnaryOp::Floor)
|
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
|
||||||
| Op::Unary(_, UnaryOp::Round)
|
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
|
||||||
| Op::Reduce(_, ReduceOp::ArgMin, _)
|
|
||||||
| Op::Reduce(_, ReduceOp::ArgMax, _)
|
|
||||||
| Op::Unary(_, UnaryOp::Sign)
|
|
||||||
| Op::Cmp(_, _) => {}
|
|
||||||
Op::Reshape(arg) => {
|
Op::Reshape(arg) => {
|
||||||
let arg_grad = grad.reshape(arg.dims())?;
|
let arg_grad = grad.reshape(arg.dims())?;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
|
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
|
||||||
|
Op::Unary(_, UnaryOp::Floor) => {
|
||||||
|
Err(Error::BackwardNotSupported { op: "floor" })?
|
||||||
|
}
|
||||||
|
Op::Unary(_, UnaryOp::Round) => {
|
||||||
|
Err(Error::BackwardNotSupported { op: "round" })?
|
||||||
|
}
|
||||||
Op::Unary(arg, UnaryOp::Gelu) => {
|
Op::Unary(arg, UnaryOp::Gelu) => {
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
let cube = arg.powf(3.)?;
|
let cube = arg.powf(3.)?;
|
||||||
@ -621,13 +589,6 @@ impl Tensor {
|
|||||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||||
}
|
}
|
||||||
Op::Unary(arg, UnaryOp::Silu) => {
|
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
|
||||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
|
||||||
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
|
|
||||||
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
|
|
||||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
|
||||||
}
|
|
||||||
Op::Elu(arg, alpha) => {
|
Op::Elu(arg, alpha) => {
|
||||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
@ -712,38 +673,30 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GradStore(HashMap<TensorId, Tensor>);
|
pub struct GradStore(HashMap<TensorId, Tensor>);
|
||||||
|
|
||||||
impl GradStore {
|
impl GradStore {
|
||||||
/// Create a new gradient store
|
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
GradStore(HashMap::new())
|
GradStore(HashMap::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the gradient tensor corresponding to the given tensor id
|
|
||||||
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
|
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
|
||||||
self.0.get(&id)
|
self.0.get(&id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the gradient tensor associated with the given tensor
|
|
||||||
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
|
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
|
||||||
self.0.get(&tensor.id())
|
self.0.get(&tensor.id())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
|
|
||||||
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
|
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
|
||||||
self.0.remove(&tensor.id())
|
self.0.remove(&tensor.id())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
|
|
||||||
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
|
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
|
||||||
self.0.insert(tensor.id(), grad)
|
self.0.insert(tensor.id(), grad)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
|
|
||||||
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
|
|
||||||
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
|
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
|
||||||
use std::collections::hash_map::Entry;
|
use std::collections::hash_map::Entry;
|
||||||
let grad = match self.0.entry(tensor.id()) {
|
let grad = match self.0.entry(tensor.id()) {
|
||||||
|
@ -187,16 +187,36 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv_transpose1d_single_group(
|
/// Applies a 1D transposed convolution over the input tensor.
|
||||||
|
pub fn conv_transpose1d(
|
||||||
&self,
|
&self,
|
||||||
kernel: &Self,
|
kernel: &Self,
|
||||||
params: &ParamsConvTranspose1D,
|
padding: usize,
|
||||||
|
output_padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
let (b_size, c_in, l_in) = self.dims3()?;
|
||||||
|
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||||
|
if c_in != c_in_k {
|
||||||
|
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||||
|
}
|
||||||
|
let params = ParamsConvTranspose1D {
|
||||||
|
b_size,
|
||||||
|
l_in,
|
||||||
|
k_size,
|
||||||
|
c_out,
|
||||||
|
c_in,
|
||||||
|
padding,
|
||||||
|
output_padding,
|
||||||
|
stride,
|
||||||
|
dilation,
|
||||||
|
};
|
||||||
let storage = self.storage().conv_transpose1d(
|
let storage = self.storage().conv_transpose1d(
|
||||||
self.layout(),
|
self.layout(),
|
||||||
&kernel.storage(),
|
&kernel.storage(),
|
||||||
kernel.layout(),
|
kernel.layout(),
|
||||||
params,
|
&params,
|
||||||
)?;
|
)?;
|
||||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
|
||||||
arg,
|
arg,
|
||||||
@ -210,49 +230,6 @@ impl Tensor {
|
|||||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Applies a 1D transposed convolution over the input tensor.
|
|
||||||
pub fn conv_transpose1d(
|
|
||||||
&self,
|
|
||||||
kernel: &Self,
|
|
||||||
padding: usize,
|
|
||||||
output_padding: usize,
|
|
||||||
stride: usize,
|
|
||||||
dilation: usize,
|
|
||||||
groups: usize,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
|
||||||
let (b_size, c_in, l_in) = self.dims3()?;
|
|
||||||
if c_in != c_in_k {
|
|
||||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
|
||||||
}
|
|
||||||
if c_in % groups != 0 {
|
|
||||||
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
|
|
||||||
}
|
|
||||||
let params = ParamsConvTranspose1D {
|
|
||||||
b_size,
|
|
||||||
l_in,
|
|
||||||
k_size,
|
|
||||||
c_out,
|
|
||||||
c_in: c_in / groups,
|
|
||||||
padding,
|
|
||||||
output_padding,
|
|
||||||
stride,
|
|
||||||
dilation,
|
|
||||||
};
|
|
||||||
if groups == 1 {
|
|
||||||
self.conv_transpose1d_single_group(kernel, &params)
|
|
||||||
} else {
|
|
||||||
let blocks = self.chunk(groups, 1)?;
|
|
||||||
let kernel = kernel.chunk(groups, 0)?;
|
|
||||||
let blocks = blocks
|
|
||||||
.iter()
|
|
||||||
.zip(&kernel)
|
|
||||||
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
Tensor::cat(&blocks, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
|
||||||
let storage =
|
let storage =
|
||||||
self.storage()
|
self.storage()
|
||||||
|
@ -4,13 +4,7 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
|
|||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
mod utils;
|
|
||||||
pub use utils::{
|
|
||||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
|
|
||||||
};
|
|
||||||
|
|
||||||
const USE_IM2COL_CONV1D: bool = true;
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
const USE_IM2COL_CONV1D_TR: bool = true;
|
|
||||||
const USE_IM2COL_CONV2D: bool = true;
|
const USE_IM2COL_CONV2D: bool = true;
|
||||||
|
|
||||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||||
@ -29,6 +23,102 @@ pub enum CpuStorage {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CpuDevice;
|
pub struct CpuDevice;
|
||||||
|
|
||||||
|
pub trait Map1 {
|
||||||
|
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
||||||
|
|
||||||
|
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||||
|
match vs {
|
||||||
|
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
|
||||||
|
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
||||||
|
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
|
||||||
|
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
||||||
|
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
||||||
|
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
|
||||||
|
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map1Any {
|
||||||
|
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
||||||
|
&self,
|
||||||
|
vs: &[T],
|
||||||
|
layout: &Layout,
|
||||||
|
wrap: W,
|
||||||
|
) -> Result<CpuStorage>;
|
||||||
|
|
||||||
|
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||||
|
match vs {
|
||||||
|
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
|
||||||
|
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
|
||||||
|
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
|
||||||
|
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
|
||||||
|
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
|
||||||
|
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
|
||||||
|
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type C = CpuStorage;
|
||||||
|
pub trait Map2 {
|
||||||
|
const OP: &'static str;
|
||||||
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
||||||
|
|
||||||
|
fn map(
|
||||||
|
&self,
|
||||||
|
v1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
v2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
) -> Result<CpuStorage> {
|
||||||
|
match (v1, v2) {
|
||||||
|
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
||||||
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: v1.dtype(),
|
||||||
|
rhs: v2.dtype(),
|
||||||
|
op: Self::OP,
|
||||||
|
}
|
||||||
|
.bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2U8 {
|
||||||
|
const OP: &'static str;
|
||||||
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||||
|
|
||||||
|
fn map(
|
||||||
|
&self,
|
||||||
|
v1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
v2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
) -> Result<CpuStorage> {
|
||||||
|
match (v1, v2) {
|
||||||
|
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||||
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: v1.dtype(),
|
||||||
|
rhs: v2.dtype(),
|
||||||
|
op: Self::OP,
|
||||||
|
}
|
||||||
|
.bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Cmp(CmpOp);
|
struct Cmp(CmpOp);
|
||||||
impl Map2U8 for Cmp {
|
impl Map2U8 for Cmp {
|
||||||
const OP: &'static str = "cmp";
|
const OP: &'static str = "cmp";
|
||||||
@ -275,6 +365,275 @@ impl<'a> Map1 for ReduceSum<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||||
|
vs: &[T],
|
||||||
|
layout: &Layout,
|
||||||
|
mut f: F,
|
||||||
|
) -> Vec<U> {
|
||||||
|
match layout.strided_blocks() {
|
||||||
|
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
||||||
|
[start_offset..start_offset + len]
|
||||||
|
.iter()
|
||||||
|
.map(|&v| f(v))
|
||||||
|
.collect(),
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
||||||
|
// Specialize the case where block_len is one to avoid the second loop.
|
||||||
|
if block_len == 1 {
|
||||||
|
for index in block_start_index {
|
||||||
|
let v = unsafe { vs.get_unchecked(index) };
|
||||||
|
result.push(f(*v))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for index in block_start_index {
|
||||||
|
for offset in 0..block_len {
|
||||||
|
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||||
|
result.push(f(*v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||||
|
vs: &[T],
|
||||||
|
layout: &Layout,
|
||||||
|
mut f: F,
|
||||||
|
mut f_vec: FV,
|
||||||
|
) -> Vec<U> {
|
||||||
|
match layout.strided_blocks() {
|
||||||
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||||
|
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(len) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
let el_count = layout.shape().elem_count();
|
||||||
|
// Specialize the case where block_len is one to avoid the second loop.
|
||||||
|
if block_len == 1 {
|
||||||
|
let mut result = Vec::with_capacity(el_count);
|
||||||
|
for index in block_start_index {
|
||||||
|
let v = unsafe { vs.get_unchecked(index) };
|
||||||
|
result.push(f(*v))
|
||||||
|
}
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||||
|
let mut dst_index = 0;
|
||||||
|
for src_index in block_start_index {
|
||||||
|
let vs = &vs[src_index..src_index + block_len];
|
||||||
|
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
||||||
|
f_vec(vs, ys);
|
||||||
|
dst_index += block_len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function maps over two strided index sequences.
|
||||||
|
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
lhs: &[T],
|
||||||
|
rhs: &[T],
|
||||||
|
mut f: F,
|
||||||
|
) -> Vec<U> {
|
||||||
|
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||||
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
||||||
|
.iter()
|
||||||
|
.zip(rhs[o_r1..o_r2].iter())
|
||||||
|
.map(|(&l, &r)| f(l, r))
|
||||||
|
.collect(),
|
||||||
|
(Some((o_l1, o_l2)), None) => {
|
||||||
|
// TODO: Maybe we want to avoid going through the layout twice.
|
||||||
|
match rhs_l.offsets_b() {
|
||||||
|
Some(ob) => {
|
||||||
|
let mut i_in_block = 0;
|
||||||
|
let mut i_right_broadcast = 0;
|
||||||
|
lhs[o_l1..o_l2]
|
||||||
|
.iter()
|
||||||
|
.map(|&l| {
|
||||||
|
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
||||||
|
i_right_broadcast += 1;
|
||||||
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
|
i_in_block += 1;
|
||||||
|
i_right_broadcast = 0;
|
||||||
|
}
|
||||||
|
if i_in_block >= ob.len {
|
||||||
|
i_in_block = 0
|
||||||
|
}
|
||||||
|
f(l, *r)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, Some((o_r1, o_r2))) => {
|
||||||
|
// TODO: Maybe we want to avoid going through the layout twice.
|
||||||
|
match lhs_l.offsets_b() {
|
||||||
|
Some(ob) => {
|
||||||
|
let mut i_in_block = 0;
|
||||||
|
let mut i_right_broadcast = 0;
|
||||||
|
rhs[o_r1..o_r2]
|
||||||
|
.iter()
|
||||||
|
.map(|&r| {
|
||||||
|
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
||||||
|
i_right_broadcast += 1;
|
||||||
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
|
i_in_block += 1;
|
||||||
|
i_right_broadcast = 0;
|
||||||
|
}
|
||||||
|
if i_in_block >= ob.len {
|
||||||
|
i_in_block = 0
|
||||||
|
}
|
||||||
|
f(*l, r)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similar to binary_map but with vectorized variants.
|
||||||
|
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
lhs: &[T],
|
||||||
|
rhs: &[T],
|
||||||
|
mut f: F,
|
||||||
|
mut f_vec: FV,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let el_count = lhs_l.shape().elem_count();
|
||||||
|
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||||
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
|
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
||||||
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
|
let mut dst_i = 0;
|
||||||
|
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
||||||
|
f_vec(
|
||||||
|
&lhs[src_i..src_i + ob.len],
|
||||||
|
rhs,
|
||||||
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||||
|
);
|
||||||
|
dst_i += ob.len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
Some(ob) => {
|
||||||
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys = lhs[o_l1..o_l2].to_vec();
|
||||||
|
for idx_l in 0..ob.left_broadcast {
|
||||||
|
let start = idx_l * ob.len * ob.right_broadcast;
|
||||||
|
for (i, &r) in rhs.iter().enumerate() {
|
||||||
|
let start = start + i * ob.right_broadcast;
|
||||||
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||||
|
*v = f(*v, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
||||||
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
|
let mut dst_i = 0;
|
||||||
|
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
||||||
|
f_vec(
|
||||||
|
lhs,
|
||||||
|
&rhs[src_i..src_i + ob.len],
|
||||||
|
&mut ys_to_set[dst_i..dst_i + ob.len],
|
||||||
|
);
|
||||||
|
dst_i += ob.len;
|
||||||
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
Some(ob) => {
|
||||||
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
|
let mut ys = rhs[o_r1..o_r2].to_vec();
|
||||||
|
for idx_l in 0..ob.left_broadcast {
|
||||||
|
let start = idx_l * ob.len * ob.right_broadcast;
|
||||||
|
for (i, &l) in lhs.iter().enumerate() {
|
||||||
|
let start = start + i * ob.right_broadcast;
|
||||||
|
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
||||||
|
*v = f(l, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ys
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
_ => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Affine(f64, f64);
|
struct Affine(f64, f64);
|
||||||
|
|
||||||
impl Map1 for Affine {
|
impl Map1 for Affine {
|
||||||
@ -663,26 +1022,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copies `d1` rows of `d2` contiguous elements from `src` into `dst`.
///
/// Row `i` is read from `src[src_offset + i * src_stride1 ..][..d2]` and written
/// to `dst[dst_offset + i * dst_stride1 ..][..d2]`.
///
/// # Panics
/// Panics if any row range falls outside the corresponding slice.
// The argument list mirrors the 2D-copy parameters one to one.
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for row in 0..d1 {
        let dst_begin = row * dst_stride1 + dst_offset;
        let src_begin = row * src_stride1 + src_offset;
        let dst_row = &mut dst[dst_begin..dst_begin + d2];
        let src_row = &src[src_begin..src_begin + d2];
        dst_row.copy_from_slice(src_row);
    }
}
|
|
||||||
|
|
||||||
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
||||||
match src_l.strided_blocks() {
|
match src_l.strided_blocks() {
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
@ -917,34 +1256,6 @@ impl Map1 for Im2Col {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Col2Im1D {
|
|
||||||
stride: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Map1 for Col2Im1D {
|
|
||||||
fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
|
|
||||||
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
|
|
||||||
let stride = self.stride;
|
|
||||||
let l_out = (l_in - 1) * stride + k_size;
|
|
||||||
let mut im = vec![T::zero(); b_size * c_out * l_out];
|
|
||||||
let (dst_s0, dst_s1) = (c_out * l_out, l_out);
|
|
||||||
let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
|
|
||||||
for l_in_i in 0..l_in {
|
|
||||||
for k_i in 0..k_size {
|
|
||||||
let l_out_i = l_in_i * stride + k_i;
|
|
||||||
for b_i in 0..b_size {
|
|
||||||
for c_i in 0..c_out {
|
|
||||||
let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
|
|
||||||
let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
|
|
||||||
im[dst_idx] += col[src_idx]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(im)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||||
|
|
||||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||||
@ -952,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
|
|||||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||||
let p = self.0;
|
let p = self.0;
|
||||||
let inp = &inp[inp_l.start_offset()..];
|
let inp = &inp[inp_l.start_offset()..];
|
||||||
let k = &k[k_l.start_offset()..];
|
|
||||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||||
let l_out = p.l_out();
|
let l_out = p.l_out();
|
||||||
@ -1204,30 +1514,6 @@ impl MatMul {
|
|||||||
}))
|
}))
|
||||||
.bt()
|
.bt()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes the per-batch element skips `(a_skip, b_skip)` for the lhs and rhs
/// operands from their batch (all but the last two) strides.
///
/// For each operand the batch strides are matched as follows:
/// * two batch dims whose strides are consistent
///   (`outer == inner * dims()[1]`): the inner stride;
/// * two batch dims where `dims()[0] == 1`: the inner stride;
/// * two batch dims where `dims()[1] == 1`: the outer stride;
/// * one batch dim: its stride;
/// * no batch dim: `m * k` for the lhs, `n * k` for the rhs.
///
/// Note that `rank` is taken from the lhs strides and used to slice both.
///
/// # Errors
/// Returns the matmul striding error ("non-contiguous lhs" / "non-contiguous
/// rhs") for any other batch stride pattern.
fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
    let lhs_stride = lhs_l.stride();
    let rhs_stride = rhs_l.stride();
    let rank = lhs_stride.len();
    let (_b, m, n, k) = self.0;

    // Shared stride analysis; `dense` holds the two factors of the skip used
    // when there are no batch dimensions at all.
    let batch_skip = |layout: &Layout,
                      batch_strides: &[usize],
                      dense: (usize, usize),
                      msg: &'static str|
     -> Result<usize> {
        let skip: usize = match *batch_strides {
            [outer, inner] if outer == inner * layout.dims()[1] => inner,
            [_, inner] if layout.dims()[0] == 1 => inner,
            [outer, _] if layout.dims()[1] == 1 => outer,
            [only] => only,
            [] => dense.0 * dense.1,
            _ => Err(self.striding_error(lhs_l, rhs_l, msg))?,
        };
        Ok(skip)
    };

    let a_skip = batch_skip(lhs_l, &lhs_stride[..rank - 2], (m, k), "non-contiguous lhs")?;
    let b_skip = batch_skip(rhs_l, &rhs_stride[..rank - 2], (n, k), "non-contiguous rhs")?;
    Ok((a_skip, b_skip))
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Map2 for MatMul {
|
impl Map2 for MatMul {
|
||||||
@ -1261,7 +1547,18 @@ impl Map2 for MatMul {
|
|||||||
let rhs_cs = rhs_stride[rank - 1];
|
let rhs_cs = rhs_stride[rank - 1];
|
||||||
let rhs_rs = rhs_stride[rank - 2];
|
let rhs_rs = rhs_stride[rank - 2];
|
||||||
|
|
||||||
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => m * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
||||||
|
};
|
||||||
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => n * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
||||||
|
};
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
let dst_shape: Shape = (m, n).into();
|
||||||
@ -1321,8 +1618,20 @@ impl Map2 for MatMul {
|
|||||||
|
|
||||||
let lhs_stride = lhs_l.stride();
|
let lhs_stride = lhs_l.stride();
|
||||||
let rhs_stride = rhs_l.stride();
|
let rhs_stride = rhs_l.stride();
|
||||||
|
let rank = lhs_stride.len();
|
||||||
|
|
||||||
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => m * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
||||||
|
};
|
||||||
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => n * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
||||||
|
};
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
@ -1330,7 +1639,7 @@ impl Map2 for MatMul {
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
(n as i32, b'N')
|
(n as i32, b'N')
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, b'T')
|
(k as i32, b'T')
|
||||||
@ -1338,7 +1647,7 @@ impl Map2 for MatMul {
|
|||||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
(k as i32, b'N')
|
(k as i32, b'N')
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, b'T')
|
(m as i32, b'T')
|
||||||
@ -1412,8 +1721,20 @@ impl Map2 for MatMul {
|
|||||||
|
|
||||||
let lhs_stride = lhs_l.stride();
|
let lhs_stride = lhs_l.stride();
|
||||||
let rhs_stride = rhs_l.stride();
|
let rhs_stride = rhs_l.stride();
|
||||||
|
let rank = lhs_stride.len();
|
||||||
|
|
||||||
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
|
let a_skip: usize = match lhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => m * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
|
||||||
|
};
|
||||||
|
let b_skip: usize = match rhs_stride[..rank - 2] {
|
||||||
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => n * k,
|
||||||
|
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
|
||||||
|
};
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||||
@ -1421,7 +1742,7 @@ impl Map2 for MatMul {
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
|
|
||||||
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
(n as i32, b'N')
|
(n as i32, b'N')
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, b'T')
|
(k as i32, b'T')
|
||||||
@ -1429,7 +1750,7 @@ impl Map2 for MatMul {
|
|||||||
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
(k as i32, b'N')
|
(k as i32, b'N')
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, b'T')
|
(m as i32, b'T')
|
||||||
@ -2101,48 +2422,6 @@ impl BackendStorage for CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy2d(
|
|
||||||
&self,
|
|
||||||
dst: &mut Self,
|
|
||||||
d1: usize,
|
|
||||||
d2: usize,
|
|
||||||
src_s: usize,
|
|
||||||
dst_s: usize,
|
|
||||||
src_o: usize,
|
|
||||||
dst_o: usize,
|
|
||||||
) -> Result<()> {
|
|
||||||
match (self, dst) {
|
|
||||||
(Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
|
||||||
(Self::U32(src), Self::U32(dst)) => {
|
|
||||||
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
|
||||||
}
|
|
||||||
(Self::I64(src), Self::I64(dst)) => {
|
|
||||||
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
|
||||||
}
|
|
||||||
(Self::BF16(src), Self::BF16(dst)) => {
|
|
||||||
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
|
||||||
}
|
|
||||||
(Self::F16(src), Self::F16(dst)) => {
|
|
||||||
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
|
||||||
}
|
|
||||||
(Self::F32(src), Self::F32(dst)) => {
|
|
||||||
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
|
||||||
}
|
|
||||||
(Self::F64(src), Self::F64(dst)) => {
|
|
||||||
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
|
|
||||||
}
|
|
||||||
(_, dst) => {
|
|
||||||
return Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: self.dtype(),
|
|
||||||
rhs: dst.dtype(),
|
|
||||||
op: "copy2d",
|
|
||||||
}
|
|
||||||
.bt());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
match (self, dst) {
|
match (self, dst) {
|
||||||
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||||
@ -2211,10 +2490,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = unsafe {
|
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||||
self.device()
|
|
||||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
|
||||||
};
|
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -2222,7 +2498,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
};
|
};
|
||||||
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
|
||||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
@ -2234,52 +2510,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
kernel_l: &Layout,
|
kernel_l: &Layout,
|
||||||
params: &crate::conv::ParamsConvTranspose1D,
|
params: &crate::conv::ParamsConvTranspose1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let can_use_col2im = kernel_l.is_contiguous()
|
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
|
||||||
&& params.dilation == 1
|
|
||||||
&& params.padding == 0
|
|
||||||
&& params.output_padding == 0;
|
|
||||||
if USE_IM2COL_CONV1D_TR && can_use_col2im {
|
|
||||||
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
|
||||||
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
|
||||||
if !kernel_l.is_contiguous() {
|
|
||||||
crate::bail!(
|
|
||||||
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if c_in != c_in2 {
|
|
||||||
crate::bail!(
|
|
||||||
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
|
||||||
l.shape(),
|
|
||||||
kernel_l.shape()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
let col = {
|
|
||||||
// This merges the last two dimensions of the kernel together.
|
|
||||||
let kernel_l_mm = Layout::new(
|
|
||||||
(b_size, c_in, k_size * c_out).into(),
|
|
||||||
vec![0, k_size * c_out, 1],
|
|
||||||
kernel_l.start_offset(),
|
|
||||||
);
|
|
||||||
self.matmul(
|
|
||||||
kernel,
|
|
||||||
(
|
|
||||||
b_size,
|
|
||||||
/* m */ l_in,
|
|
||||||
/* n */ c_out * k_size,
|
|
||||||
/* k */ c_in,
|
|
||||||
),
|
|
||||||
&l.transpose(1, 2)?,
|
|
||||||
&kernel_l_mm,
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
|
|
||||||
Col2Im1D {
|
|
||||||
stride: params.stride,
|
|
||||||
}
|
|
||||||
.map(&col, &col_l)
|
|
||||||
} else {
|
|
||||||
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
@ -2313,10 +2544,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = unsafe {
|
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||||
self.device()
|
|
||||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
|
||||||
};
|
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -2326,7 +2554,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
|
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.transpose(1, 3)?;
|
.transpose(1, 3)?;
|
||||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
@ -2346,7 +2574,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||||
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||||
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2355,7 +2583,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||||
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||||
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2372,7 +2600,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2449,10 +2677,6 @@ impl BackendDevice for CpuDevice {
|
|||||||
Ok(s.clone())
|
Ok(s.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds the device storage from an owned [`CpuStorage`].
///
/// The value is handed back as is, without cloning.
fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
    Ok(s)
}
|
|
||||||
|
|
||||||
fn new(_: usize) -> Result<Self> {
|
fn new(_: usize) -> Result<Self> {
|
||||||
Ok(Self)
|
Ok(Self)
|
||||||
}
|
}
|
||||||
@ -2554,53 +2778,6 @@ impl BackendDevice for CpuDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::uninit_vec)]
|
|
||||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
// The code below is highly unsafe but hopefully not directly unsound as we only consider
|
|
||||||
// types that are Copy, not Drop, and for which all bit patterns are proper values.
|
|
||||||
// It's still pretty risky, see the following for more details:
|
|
||||||
// https://github.com/rust-lang/rust-clippy/issues/4483
|
|
||||||
let storage = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::U8(v)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::U32(v)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::I64(v)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::BF16(v)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::F16(v)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::F32(v)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let mut v = Vec::with_capacity(elem_count);
|
|
||||||
v.set_len(elem_count);
|
|
||||||
CpuStorage::F64(v)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(storage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let storage = match dtype {
|
let storage = match dtype {
|
||||||
@ -2628,10 +2805,6 @@ impl BackendDevice for CpuDevice {
|
|||||||
};
|
};
|
||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn synchronize(&self) -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
@ -1,350 +0,0 @@
|
|||||||
/// Helper functions to write CPU kernels.
|
|
||||||
use crate::backend::BackendStorage;
|
|
||||||
use crate::{Error, Layout, Result, WithDType};
|
|
||||||
|
|
||||||
type C = super::CpuStorage;
|
|
||||||
pub trait Map1 {
|
|
||||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
|
||||||
|
|
||||||
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
|
|
||||||
match vs {
|
|
||||||
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
|
|
||||||
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
|
|
||||||
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
|
|
||||||
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
|
|
||||||
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
|
|
||||||
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
|
|
||||||
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map1Any {
|
|
||||||
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
|
|
||||||
|
|
||||||
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
|
|
||||||
match vs {
|
|
||||||
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
|
|
||||||
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
|
|
||||||
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
|
|
||||||
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
|
|
||||||
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
|
|
||||||
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
|
|
||||||
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map2 {
|
|
||||||
const OP: &'static str;
|
|
||||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
|
||||||
|
|
||||||
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
|
|
||||||
match (v1, v2) {
|
|
||||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
|
|
||||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: v1.dtype(),
|
|
||||||
rhs: v2.dtype(),
|
|
||||||
op: Self::OP,
|
|
||||||
}
|
|
||||||
.bt()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map2U8 {
|
|
||||||
const OP: &'static str;
|
|
||||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
|
||||||
|
|
||||||
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
|
|
||||||
match (v1, v2) {
|
|
||||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
|
||||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: v1.dtype(),
|
|
||||||
rhs: v2.dtype(),
|
|
||||||
op: Self::OP,
|
|
||||||
}
|
|
||||||
.bt()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
|
||||||
lhs_l: &Layout,
|
|
||||||
rhs_l: &Layout,
|
|
||||||
lhs: &[T],
|
|
||||||
rhs: &[T],
|
|
||||||
mut f: F,
|
|
||||||
) -> Vec<U> {
|
|
||||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
|
||||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
|
||||||
.iter()
|
|
||||||
.zip(rhs[o_r1..o_r2].iter())
|
|
||||||
.map(|(&l, &r)| f(l, r))
|
|
||||||
.collect(),
|
|
||||||
(Some((o_l1, o_l2)), None) => {
|
|
||||||
// TODO: Maybe we want to avoid going through the layout twice.
|
|
||||||
match rhs_l.offsets_b() {
|
|
||||||
Some(ob) => {
|
|
||||||
let mut i_in_block = 0;
|
|
||||||
let mut i_right_broadcast = 0;
|
|
||||||
lhs[o_l1..o_l2]
|
|
||||||
.iter()
|
|
||||||
.map(|&l| {
|
|
||||||
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
|
||||||
i_right_broadcast += 1;
|
|
||||||
if i_right_broadcast >= ob.right_broadcast {
|
|
||||||
i_in_block += 1;
|
|
||||||
i_right_broadcast = 0;
|
|
||||||
}
|
|
||||||
if i_in_block >= ob.len {
|
|
||||||
i_in_block = 0
|
|
||||||
}
|
|
||||||
f(l, *r)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(None, Some((o_r1, o_r2))) => {
|
|
||||||
// TODO: Maybe we want to avoid going through the layout twice.
|
|
||||||
match lhs_l.offsets_b() {
|
|
||||||
Some(ob) => {
|
|
||||||
let mut i_in_block = 0;
|
|
||||||
let mut i_right_broadcast = 0;
|
|
||||||
rhs[o_r1..o_r2]
|
|
||||||
.iter()
|
|
||||||
.map(|&r| {
|
|
||||||
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
|
||||||
i_right_broadcast += 1;
|
|
||||||
if i_right_broadcast >= ob.right_broadcast {
|
|
||||||
i_in_block += 1;
|
|
||||||
i_right_broadcast = 0;
|
|
||||||
}
|
|
||||||
if i_in_block >= ob.len {
|
|
||||||
i_in_block = 0
|
|
||||||
}
|
|
||||||
f(*l, r)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Similar to binary_map but with vectorized variants.
|
|
||||||
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
|
|
||||||
lhs_l: &Layout,
|
|
||||||
rhs_l: &Layout,
|
|
||||||
lhs: &[T],
|
|
||||||
rhs: &[T],
|
|
||||||
mut f: F,
|
|
||||||
mut f_vec: FV,
|
|
||||||
) -> Vec<T> {
|
|
||||||
let el_count = lhs_l.shape().elem_count();
|
|
||||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
|
||||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
||||||
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
|
||||||
Some(ob) if ob.right_broadcast == 1 => {
|
|
||||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
||||||
let mut dst_i = 0;
|
|
||||||
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
|
||||||
f_vec(
|
|
||||||
&lhs[src_i..src_i + ob.len],
|
|
||||||
rhs,
|
|
||||||
&mut ys_to_set[dst_i..dst_i + ob.len],
|
|
||||||
);
|
|
||||||
dst_i += ob.len;
|
|
||||||
}
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
Some(ob) => {
|
|
||||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys = lhs[o_l1..o_l2].to_vec();
|
|
||||||
for idx_l in 0..ob.left_broadcast {
|
|
||||||
let start = idx_l * ob.len * ob.right_broadcast;
|
|
||||||
for (i, &r) in rhs.iter().enumerate() {
|
|
||||||
let start = start + i * ob.right_broadcast;
|
|
||||||
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
|
||||||
*v = f(*v, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
|
||||||
Some(ob) if ob.right_broadcast == 1 => {
|
|
||||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
|
||||||
let mut dst_i = 0;
|
|
||||||
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
|
||||||
f_vec(
|
|
||||||
lhs,
|
|
||||||
&rhs[src_i..src_i + ob.len],
|
|
||||||
&mut ys_to_set[dst_i..dst_i + ob.len],
|
|
||||||
);
|
|
||||||
dst_i += ob.len;
|
|
||||||
}
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
Some(ob) => {
|
|
||||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
|
||||||
let mut ys = rhs[o_r1..o_r2].to_vec();
|
|
||||||
for idx_l in 0..ob.left_broadcast {
|
|
||||||
let start = idx_l * ob.len * ob.right_broadcast;
|
|
||||||
for (i, &l) in lhs.iter().enumerate() {
|
|
||||||
let start = start + i * ob.right_broadcast;
|
|
||||||
for v in ys[start..start + ob.right_broadcast].iter_mut() {
|
|
||||||
*v = f(l, *v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
None => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
},
|
|
||||||
_ => lhs_l
|
|
||||||
.strided_index()
|
|
||||||
.zip(rhs_l.strided_index())
|
|
||||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
|
||||||
vs: &[T],
|
|
||||||
layout: &Layout,
|
|
||||||
mut f: F,
|
|
||||||
) -> Vec<U> {
|
|
||||||
match layout.strided_blocks() {
|
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
|
||||||
[start_offset..start_offset + len]
|
|
||||||
.iter()
|
|
||||||
.map(|&v| f(v))
|
|
||||||
.collect(),
|
|
||||||
crate::StridedBlocks::MultipleBlocks {
|
|
||||||
block_start_index,
|
|
||||||
block_len,
|
|
||||||
} => {
|
|
||||||
let mut result = Vec::with_capacity(layout.shape().elem_count());
|
|
||||||
// Specialize the case where block_len is one to avoid the second loop.
|
|
||||||
if block_len == 1 {
|
|
||||||
for index in block_start_index {
|
|
||||||
let v = unsafe { vs.get_unchecked(index) };
|
|
||||||
result.push(f(*v))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for index in block_start_index {
|
|
||||||
for offset in 0..block_len {
|
|
||||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
|
||||||
result.push(f(*v))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
|
||||||
vs: &[T],
|
|
||||||
layout: &Layout,
|
|
||||||
mut f: F,
|
|
||||||
mut f_vec: FV,
|
|
||||||
) -> Vec<U> {
|
|
||||||
match layout.strided_blocks() {
|
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
|
||||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
|
||||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(len) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
crate::StridedBlocks::MultipleBlocks {
|
|
||||||
block_start_index,
|
|
||||||
block_len,
|
|
||||||
} => {
|
|
||||||
let el_count = layout.shape().elem_count();
|
|
||||||
// Specialize the case where block_len is one to avoid the second loop.
|
|
||||||
if block_len == 1 {
|
|
||||||
let mut result = Vec::with_capacity(el_count);
|
|
||||||
for index in block_start_index {
|
|
||||||
let v = unsafe { vs.get_unchecked(index) };
|
|
||||||
result.push(f(*v))
|
|
||||||
}
|
|
||||||
result
|
|
||||||
} else {
|
|
||||||
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
|
||||||
let mut dst_index = 0;
|
|
||||||
for src_index in block_start_index {
|
|
||||||
let vs = &vs[src_index..src_index + block_len];
|
|
||||||
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
|
||||||
f_vec(vs, ys);
|
|
||||||
dst_index += block_len;
|
|
||||||
}
|
|
||||||
// SAFETY: values are all set by f_vec.
|
|
||||||
unsafe { ys.set_len(el_count) };
|
|
||||||
ys
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -5,41 +5,395 @@ pub use candle_kernels as kernels;
|
|||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||||
use cudarc::driver::{
|
use cudarc::driver::{
|
||||||
CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
|
||||||
|
ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
#[cfg(feature = "cudnn")]
|
/// cudarc related errors
|
||||||
pub mod cudnn;
|
#[derive(thiserror::Error, Debug)]
|
||||||
mod device;
|
pub enum CudaError {
|
||||||
mod error;
|
#[error(transparent)]
|
||||||
mod utils;
|
Cuda(#[from] cudarc::driver::DriverError),
|
||||||
pub use device::{CudaDevice, DeviceId};
|
|
||||||
pub use error::{CudaError, WrapErr};
|
|
||||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
|
||||||
|
|
||||||
enum SlicePtrOrNull<T> {
|
#[error(transparent)]
|
||||||
Ptr(CudaSlice<T>),
|
Compiler(#[from] cudarc::nvrtc::CompileError),
|
||||||
Null,
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Curand(#[from] cudarc::curand::result::CurandError),
|
||||||
|
|
||||||
|
#[error("missing kernel '{module_name}'")]
|
||||||
|
MissingKernel { module_name: String },
|
||||||
|
|
||||||
|
#[error("unsupported dtype {dtype:?} for {op}")]
|
||||||
|
UnsupportedDtype { dtype: DType, op: &'static str },
|
||||||
|
|
||||||
|
#[error("internal error '{0}'")]
|
||||||
|
InternalError(&'static str),
|
||||||
|
|
||||||
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
|
MatMulNonContiguous {
|
||||||
|
lhs_stride: Vec<usize>,
|
||||||
|
rhs_stride: Vec<usize>,
|
||||||
|
mnk: (usize, usize, usize),
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||||
|
UnexpectedDType {
|
||||||
|
msg: &'static str,
|
||||||
|
expected: DType,
|
||||||
|
got: DType,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("{cuda} when loading {module_name}")]
|
||||||
|
Load {
|
||||||
|
cuda: cudarc::driver::DriverError,
|
||||||
|
module_name: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
|
impl From<CudaError> for crate::Error {
|
||||||
fn as_kernel_param(&self) -> *mut std::ffi::c_void {
|
fn from(val: CudaError) -> Self {
|
||||||
match self {
|
crate::Error::Cuda(Box::new(val)).bt()
|
||||||
SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
|
|
||||||
SlicePtrOrNull::Null => 0usize.as_kernel_param(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SlicePtrOrNull<usize> {
|
/// Unique identifier for cuda devices.
|
||||||
fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||||
let ds = if l.is_contiguous() {
|
pub struct DeviceId(usize);
|
||||||
SlicePtrOrNull::Null
|
|
||||||
} else {
|
impl DeviceId {
|
||||||
SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
|
fn new() -> Self {
|
||||||
|
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||||
|
use std::sync::atomic;
|
||||||
|
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||||
|
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CudaRng(cudarc::curand::CudaRng);
|
||||||
|
unsafe impl Send for CudaRng {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CudaDevice {
|
||||||
|
id: DeviceId,
|
||||||
|
device: Arc<cudarc::driver::CudaDevice>,
|
||||||
|
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
|
curand: Arc<Mutex<CudaRng>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for CudaDevice {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "CudaDevice({:?})", self.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for CudaDevice {
|
||||||
|
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.device
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait WrapErr<O> {
|
||||||
|
fn w(self) -> std::result::Result<O, crate::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
||||||
|
fn w(self) -> std::result::Result<O, crate::Error> {
|
||||||
|
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaDevice {
|
||||||
|
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
||||||
|
self.device.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> DeviceId {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
||||||
|
let params = (&data, v as u8, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||||
|
let params = (&data, v as u32, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||||
|
let params = (&data, v as i64, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||||
|
let params = (&data, bf16::from_f64(v), elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||||
|
let params = (&data, f16::from_f64(v), elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||||
|
let params = (&data, v as f32, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||||
|
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||||
|
let params = (&data, v, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Ok(ds)
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||||
|
if !self.has_func(module_name, module_name) {
|
||||||
|
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||||
|
// done once per kernel name.
|
||||||
|
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||||
|
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||||
|
.map_err(|cuda| CudaError::Load {
|
||||||
|
cuda,
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
})
|
||||||
|
.w()?;
|
||||||
|
}
|
||||||
|
self.get_func(module_name, module_name)
|
||||||
|
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||||
|
// able to only build the error value if needed.
|
||||||
|
.ok_or(CudaError::MissingKernel {
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
})
|
||||||
|
.w()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendDevice for CudaDevice {
|
||||||
|
type Storage = CudaStorage;
|
||||||
|
|
||||||
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
|
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
||||||
|
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||||
|
Ok(Self {
|
||||||
|
id: DeviceId::new(),
|
||||||
|
device,
|
||||||
|
blas: Arc::new(blas),
|
||||||
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||||
|
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||||
|
// state will be identical and the same random numbers will be generated.
|
||||||
|
let mut curand = self.curand.lock().unwrap();
|
||||||
|
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
crate::DeviceLocation::Cuda {
|
||||||
|
gpu_id: self.device.ordinal(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn same_device(&self, rhs: &Self) -> bool {
|
||||||
|
self.id == rhs.id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 => {
|
||||||
|
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
DType::U32 => {
|
||||||
|
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
DType::I64 => {
|
||||||
|
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let curand = self.curand.lock().unwrap();
|
||||||
|
let slice = match dtype {
|
||||||
|
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||||
|
// cudarc changes.
|
||||||
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||||
|
Err(CudaError::UnsupportedDtype {
|
||||||
|
dtype,
|
||||||
|
op: "rand_uniform",
|
||||||
|
})
|
||||||
|
.w()?
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||||
|
curand.0.fill_with_uniform(&mut data).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||||
|
curand.0.fill_with_uniform(&mut data).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let slice = if lo == 0. && up == 1.0 {
|
||||||
|
slice
|
||||||
|
} else {
|
||||||
|
let layout = Layout::contiguous(shape);
|
||||||
|
Affine(up - lo, lo).map(&slice, self, &layout)?
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
||||||
|
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||||
|
// cudarc changes.
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let curand = self.curand.lock().unwrap();
|
||||||
|
// curand can only generate an odd number of values.
|
||||||
|
// https://github.com/huggingface/candle/issues/734
|
||||||
|
let elem_count_round = if elem_count % 2 == 1 {
|
||||||
|
elem_count + 1
|
||||||
|
} else {
|
||||||
|
elem_count
|
||||||
|
};
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||||
|
Err(CudaError::UnsupportedDtype {
|
||||||
|
dtype,
|
||||||
|
op: "rand_normal",
|
||||||
|
})
|
||||||
|
.w()?
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
||||||
|
curand
|
||||||
|
.0
|
||||||
|
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||||
|
.w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
||||||
|
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
|
self.const_impl(1., shape, dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
|
let slice = match storage {
|
||||||
|
CpuStorage::U8(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U8(data)
|
||||||
|
}
|
||||||
|
CpuStorage::U32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
|
CpuStorage::I64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::I64(data)
|
||||||
|
}
|
||||||
|
CpuStorage::BF16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::BF16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F16(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F16(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F32(data)
|
||||||
|
}
|
||||||
|
CpuStorage::F64(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage).w()?;
|
||||||
|
CudaStorageSlice::F64(data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: self.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,6 +407,133 @@ pub enum CudaStorageSlice {
|
|||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
|
type S = CudaStorageSlice;
|
||||||
|
|
||||||
|
pub trait Map1 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||||
|
let out = match s {
|
||||||
|
S::U8(s) => S::U8(self.f(s, d, l)?),
|
||||||
|
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||||
|
S::I64(s) => S::I64(self.f(s, d, l)?),
|
||||||
|
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||||
|
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||||
|
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||||
|
S::F64(s) => S::F64(self.f(s, d, l)?),
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||||
|
let out = match (s1, s2) {
|
||||||
|
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2InPlace {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
dst_shape: &Shape,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
src_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()>;
|
||||||
|
|
||||||
|
fn map(
|
||||||
|
&self,
|
||||||
|
dst: &mut S,
|
||||||
|
dst_s: &Shape,
|
||||||
|
src: &S,
|
||||||
|
src_l: &Layout,
|
||||||
|
d: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
match (dst, src) {
|
||||||
|
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map1Any {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
wrap: W,
|
||||||
|
) -> Result<S>;
|
||||||
|
|
||||||
|
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||||
|
let out = match s {
|
||||||
|
S::U8(s) => self.f(s, d, l, S::U8)?,
|
||||||
|
S::U32(s) => self.f(s, d, l, S::U32)?,
|
||||||
|
S::I64(s) => self.f(s, d, l, S::I64)?,
|
||||||
|
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
||||||
|
S::F16(s) => self.f(s, d, l, S::F16)?,
|
||||||
|
S::F32(s) => self.f(s, d, l, S::F32)?,
|
||||||
|
S::F64(s) => self.f(s, d, l, S::F64)?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Map2Any {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<S>;
|
||||||
|
|
||||||
|
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||||
|
let out = match (s1, s2) {
|
||||||
|
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
||||||
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Clone;
|
struct Clone;
|
||||||
impl Map1 for Clone {
|
impl Map1 for Clone {
|
||||||
@ -83,7 +564,7 @@ impl Map1 for Affine {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -115,7 +596,7 @@ impl Map1 for Elu {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -238,7 +719,7 @@ impl Map1 for Powf {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -371,7 +852,7 @@ impl<U: UnaryOpT> Map1 for U {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -668,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
|
||||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
||||||
&self,
|
|
||||||
inp: &CudaSlice<T>,
|
|
||||||
inp_l: &Layout,
|
|
||||||
k: &CudaSlice<T>,
|
|
||||||
k_l: &Layout,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<CudaSlice<T>> {
|
|
||||||
// Kernel shape: (c_in_k, c_out, l_k)
|
|
||||||
// Input shape: (b_size, c_in, l_in)
|
|
||||||
let p = &self.0;
|
|
||||||
let l_out = p.l_out();
|
|
||||||
let dst_el = p.c_out * l_out * p.b_size;
|
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
|
||||||
let k = &k.slice(k_l.start_offset()..);
|
|
||||||
let shape = inp_l.shape();
|
|
||||||
let dims = shape.dims();
|
|
||||||
let el = shape.elem_count();
|
|
||||||
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
|
|
||||||
let ds = if dims.len() == 3 {
|
|
||||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
|
||||||
} else {
|
|
||||||
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
|
||||||
};
|
|
||||||
let ds = dev.htod_copy(ds).w()?;
|
|
||||||
let params = (
|
|
||||||
el,
|
|
||||||
l_out,
|
|
||||||
p.stride,
|
|
||||||
p.padding,
|
|
||||||
p.output_padding,
|
|
||||||
p.dilation,
|
|
||||||
&ds,
|
|
||||||
inp,
|
|
||||||
k,
|
|
||||||
&out,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
@ -921,14 +1353,9 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
let dims_and_strides = dev
|
||||||
SlicePtrOrNull::Null
|
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||||
} else {
|
.w()?;
|
||||||
SlicePtrOrNull::Ptr(
|
|
||||||
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
|
||||||
.w()?,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
|
||||||
@ -955,14 +1382,9 @@ impl Map2Any for Cmp {
|
|||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
let dims_and_strides = dev
|
||||||
SlicePtrOrNull::Null
|
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||||
} else {
|
.w()?;
|
||||||
SlicePtrOrNull::Ptr(
|
|
||||||
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
|
||||||
.w()?,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let name = match self.0 {
|
let name = match self.0 {
|
||||||
@ -1070,30 +1492,26 @@ fn gemm_config<T>(
|
|||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
// The a tensor has dims batching, k, n (rhs)
|
// The a tensor has dims batching, k, n (rhs)
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
// there is a single element.
|
|
||||||
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
|
||||||
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||||
} else {
|
} else {
|
||||||
Err(CudaError::MatMulNonContiguous {
|
Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_l.clone(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_l.clone(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
// there is a single element.
|
|
||||||
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
|
||||||
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||||
} else {
|
} else {
|
||||||
Err(CudaError::MatMulNonContiguous {
|
Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_l.clone(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_l.clone(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
@ -1114,25 +1532,21 @@ fn gemm_config<T>(
|
|||||||
|
|
||||||
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
||||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||||
[_, stride] if lhs_l.dims()[0] == 1 => stride,
|
|
||||||
[stride, _] if lhs_l.dims()[1] == 1 => stride,
|
|
||||||
[stride] => stride,
|
[stride] => stride,
|
||||||
[] => m * k,
|
[] => m * k,
|
||||||
_ => Err(CudaError::MatMulNonContiguous {
|
_ => Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_l.clone(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_l.clone(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
||||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
[_, stride] if rhs_l.dims()[0] == 1 => stride,
|
|
||||||
[stride, _] if rhs_l.dims()[1] == 1 => stride,
|
|
||||||
[stride] => stride,
|
[stride] => stride,
|
||||||
[] => n * k,
|
[] => n * k,
|
||||||
_ => Err(CudaError::MatMulNonContiguous {
|
_ => Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_l.clone(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
rhs_stride: rhs_l.clone(),
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
@ -1177,7 +1591,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
||||||
let start_o = layout.start_offset();
|
let start_o = layout.start_offset();
|
||||||
// This returns an i64 rather than a &i64, this is useful to get around some temporary
|
// This returns an i64 rather than a &i64, this is useful to get around some temporary
|
||||||
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
|
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
|
||||||
@ -1381,10 +1795,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = unsafe {
|
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||||
self.device()
|
|
||||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
|
||||||
};
|
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -1392,22 +1803,19 @@ impl BackendStorage for CudaStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
};
|
};
|
||||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv_transpose1d(
|
fn conv_transpose1d(
|
||||||
&self,
|
&self,
|
||||||
l: &Layout,
|
_: &Layout,
|
||||||
kernel: &Self,
|
_: &Self,
|
||||||
kernel_l: &Layout,
|
_: &Layout,
|
||||||
params: &crate::conv::ParamsConvTranspose1D,
|
_: &crate::conv::ParamsConvTranspose1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
todo!()
|
||||||
let slice =
|
|
||||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
|
||||||
Ok(Self { slice, device })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cudnn"))]
|
#[cfg(not(feature = "cudnn"))]
|
||||||
@ -1449,10 +1857,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = unsafe {
|
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
|
||||||
self.device()
|
|
||||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
|
||||||
};
|
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -1462,7 +1867,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.transpose(1, 3)?;
|
.transpose(1, 3)?;
|
||||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
|
||||||
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
@ -1599,7 +2004,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
dim: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
|
||||||
self.copy_strided_src(&mut acc, 0, l)?;
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
@ -1614,7 +2019,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
dim: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
|
||||||
self.copy_strided_src(&mut acc, 0, l)?;
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
@ -1688,72 +2093,6 @@ impl BackendStorage for CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy2d(
|
|
||||||
&self,
|
|
||||||
dst: &mut Self,
|
|
||||||
d1: usize,
|
|
||||||
d2: usize,
|
|
||||||
src_s: usize,
|
|
||||||
dst_s: usize,
|
|
||||||
src_o: usize,
|
|
||||||
dst_o: usize,
|
|
||||||
) -> Result<()> {
|
|
||||||
let dev = &self.device;
|
|
||||||
let d1 = d1 as u32;
|
|
||||||
let d2 = d2 as u32;
|
|
||||||
// Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
|
|
||||||
// argument with a null pointer.
|
|
||||||
if d1 == 0 || d2 == 0 {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let dst_s = dst_s as u32;
|
|
||||||
let src_s = src_s as u32;
|
|
||||||
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
|
|
||||||
(S::U8(s), S::U8(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_u8",
|
|
||||||
),
|
|
||||||
(S::U32(s), S::U32(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_u32",
|
|
||||||
),
|
|
||||||
(S::I64(s), S::I64(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_i64",
|
|
||||||
),
|
|
||||||
(S::BF16(s), S::BF16(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_bf16",
|
|
||||||
),
|
|
||||||
(S::F16(s), S::F16(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_f16",
|
|
||||||
),
|
|
||||||
(S::F32(s), S::F32(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_f32",
|
|
||||||
),
|
|
||||||
(S::F64(s), S::F64(d)) => (
|
|
||||||
*s.slice(src_o..).device_ptr(),
|
|
||||||
*d.slice(dst_o..).device_ptr(),
|
|
||||||
"copy2d_f64",
|
|
||||||
),
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
|
|
||||||
};
|
|
||||||
let func = dev.get_or_load_func(kname, kernels::FILL)?;
|
|
||||||
let cfg = LaunchConfig::for_num_elems(d1 * d2);
|
|
||||||
let params = (src, dst, d1, d2, src_s, dst_s);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
let src_shape = src_l.shape();
|
let src_shape = src_l.shape();
|
||||||
let dims = src_shape.dims();
|
let dims = src_shape.dims();
|
||||||
@ -1763,7 +2102,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
}
|
}
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
|
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
|
||||||
match (&self.slice, &mut dst.slice) {
|
match (&self.slice, &mut dst.slice) {
|
||||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
@ -1,415 +0,0 @@
|
|||||||
use crate::backend::BackendDevice;
|
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
|
||||||
pub use candle_kernels as kernels;
|
|
||||||
pub use cudarc;
|
|
||||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
|
||||||
use half::{bf16, f16};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
|
||||||
|
|
||||||
/// Unique identifier for cuda devices.
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
||||||
pub struct DeviceId(usize);
|
|
||||||
|
|
||||||
impl DeviceId {
|
|
||||||
fn new() -> Self {
|
|
||||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
|
||||||
use std::sync::atomic;
|
|
||||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
|
||||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CudaRng(cudarc::curand::CudaRng);
|
|
||||||
unsafe impl Send for CudaRng {}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct CudaDevice {
|
|
||||||
id: DeviceId,
|
|
||||||
device: Arc<cudarc::driver::CudaDevice>,
|
|
||||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
|
||||||
curand: Arc<Mutex<CudaRng>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for CudaDevice {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "CudaDevice({:?})", self.id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::ops::Deref for CudaDevice {
|
|
||||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CudaDevice {
|
|
||||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
|
||||||
self.device.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
|
||||||
let params = (&data, v as u8, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
|
||||||
let params = (&data, v as u32, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
|
||||||
let params = (&data, v as i64, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
|
||||||
let params = (&data, bf16::from_f64(v), elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
|
||||||
let params = (&data, f16::from_f64(v), elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
|
||||||
let params = (&data, v as f32, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
|
||||||
let params = (&data, v, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
|
||||||
if !self.has_func(module_name, module_name) {
|
|
||||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
|
||||||
// done once per kernel name.
|
|
||||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
|
||||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
|
||||||
.map_err(|cuda| CudaError::Load {
|
|
||||||
cuda,
|
|
||||||
module_name: module_name.to_string(),
|
|
||||||
})
|
|
||||||
.w()?;
|
|
||||||
}
|
|
||||||
self.get_func(module_name, module_name)
|
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
|
||||||
// able to only build the error value if needed.
|
|
||||||
.ok_or(CudaError::MissingKernel {
|
|
||||||
module_name: module_name.to_string(),
|
|
||||||
})
|
|
||||||
.w()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BackendDevice for CudaDevice {
|
|
||||||
type Storage = CudaStorage;
|
|
||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
|
||||||
Ok(Self {
|
|
||||||
id: DeviceId::new(),
|
|
||||||
device,
|
|
||||||
blas: Arc::new(blas),
|
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
|
||||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
|
||||||
// state will be identical and the same random numbers will be generated.
|
|
||||||
let mut curand = self.curand.lock().unwrap();
|
|
||||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
|
||||||
crate::DeviceLocation::Cuda {
|
|
||||||
gpu_id: self.device.ordinal(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn same_device(&self, rhs: &Self) -> bool {
|
|
||||||
self.id == rhs.id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let curand = self.curand.lock().unwrap();
|
|
||||||
let slice = match dtype {
|
|
||||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
|
||||||
// cudarc changes.
|
|
||||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
|
||||||
Err(CudaError::UnsupportedDtype {
|
|
||||||
dtype,
|
|
||||||
op: "rand_uniform",
|
|
||||||
})
|
|
||||||
.w()?
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
||||||
curand.0.fill_with_uniform(&mut data).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
||||||
curand.0.fill_with_uniform(&mut data).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let slice = if lo == 0. && up == 1.0 {
|
|
||||||
slice
|
|
||||||
} else {
|
|
||||||
use super::utils::Map1;
|
|
||||||
let layout = Layout::contiguous(shape);
|
|
||||||
super::Affine(up - lo, lo).map(&slice, self, &layout)?
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
|
||||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
|
||||||
// cudarc changes.
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let curand = self.curand.lock().unwrap();
|
|
||||||
// curand can only generate an even number of values, so odd counts are rounded up.
|
|
||||||
// https://github.com/huggingface/candle/issues/734
|
|
||||||
let elem_count_round = if elem_count % 2 == 1 {
|
|
||||||
elem_count + 1
|
|
||||||
} else {
|
|
||||||
elem_count
|
|
||||||
};
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
|
||||||
Err(CudaError::UnsupportedDtype {
|
|
||||||
dtype,
|
|
||||||
op: "rand_normal",
|
|
||||||
})
|
|
||||||
.w()?
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
|
||||||
curand
|
|
||||||
.0
|
|
||||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
|
||||||
.w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
|
||||||
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
self.const_impl(1., shape, dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
let data = self.alloc::<u8>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
let data = self.alloc::<u32>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
let data = self.alloc::<i64>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
let data = self.alloc::<bf16>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
let data = self.alloc::<f16>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let data = self.alloc::<f32>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let data = self.alloc::<f64>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
|
||||||
let slice = match storage {
|
|
||||||
CpuStorage::U8(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
CpuStorage::U32(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
CpuStorage::I64(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
CpuStorage::BF16(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
CpuStorage::F16(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
CpuStorage::F32(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
CpuStorage::F64(storage) => {
|
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
|
||||||
let slice = match storage {
|
|
||||||
CpuStorage::U8(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
CpuStorage::U32(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
CpuStorage::I64(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
CpuStorage::BF16(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
CpuStorage::F16(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
CpuStorage::F32(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
CpuStorage::F64(storage) => {
|
|
||||||
let data = self.htod_copy(storage).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn synchronize(&self) -> Result<()> {
|
|
||||||
self.device.synchronize().map_err(crate::Error::wrap)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,62 +0,0 @@
|
|||||||
use crate::{DType, Layout};
|
|
||||||
|
|
||||||
/// cudarc related errors
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
|
||||||
pub enum CudaError {
|
|
||||||
#[error(transparent)]
|
|
||||||
Cuda(#[from] cudarc::driver::DriverError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Compiler(#[from] cudarc::nvrtc::CompileError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Cublas(#[from] cudarc::cublas::result::CublasError),
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Curand(#[from] cudarc::curand::result::CurandError),
|
|
||||||
|
|
||||||
#[error("missing kernel '{module_name}'")]
|
|
||||||
MissingKernel { module_name: String },
|
|
||||||
|
|
||||||
#[error("unsupported dtype {dtype:?} for {op}")]
|
|
||||||
UnsupportedDtype { dtype: DType, op: &'static str },
|
|
||||||
|
|
||||||
#[error("internal error '{0}'")]
|
|
||||||
InternalError(&'static str),
|
|
||||||
|
|
||||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
|
||||||
MatMulNonContiguous {
|
|
||||||
lhs_stride: Layout,
|
|
||||||
rhs_stride: Layout,
|
|
||||||
mnk: (usize, usize, usize),
|
|
||||||
},
|
|
||||||
|
|
||||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
|
||||||
UnexpectedDType {
|
|
||||||
msg: &'static str,
|
|
||||||
expected: DType,
|
|
||||||
got: DType,
|
|
||||||
},
|
|
||||||
|
|
||||||
#[error("{cuda} when loading {module_name}")]
|
|
||||||
Load {
|
|
||||||
cuda: cudarc::driver::DriverError,
|
|
||||||
module_name: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<CudaError> for crate::Error {
|
|
||||||
fn from(val: CudaError) -> Self {
|
|
||||||
crate::Error::Cuda(Box::new(val)).bt()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait WrapErr<O> {
|
|
||||||
fn w(self) -> std::result::Result<O, crate::Error>;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
|
||||||
fn w(self) -> std::result::Result<O, crate::Error> {
|
|
||||||
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,134 +0,0 @@
|
|||||||
//! Helper functions to plug cuda kernels in candle.
|
|
||||||
use crate::{Layout, Result, Shape, WithDType};
|
|
||||||
pub use cudarc;
|
|
||||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
|
||||||
|
|
||||||
use super::{CudaDevice, CudaError, WrapErr};
|
|
||||||
|
|
||||||
pub type S = super::CudaStorageSlice;
|
|
||||||
|
|
||||||
pub trait Map1 {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
||||||
&self,
|
|
||||||
src: &CudaSlice<T>,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
layout: &Layout,
|
|
||||||
) -> Result<CudaSlice<T>>;
|
|
||||||
|
|
||||||
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
|
||||||
let out = match s {
|
|
||||||
S::U8(s) => S::U8(self.f(s, d, l)?),
|
|
||||||
S::U32(s) => S::U32(self.f(s, d, l)?),
|
|
||||||
S::I64(s) => S::I64(self.f(s, d, l)?),
|
|
||||||
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
|
||||||
S::F16(s) => S::F16(self.f(s, d, l)?),
|
|
||||||
S::F32(s) => S::F32(self.f(s, d, l)?),
|
|
||||||
S::F64(s) => S::F64(self.f(s, d, l)?),
|
|
||||||
};
|
|
||||||
Ok(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map2 {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
||||||
&self,
|
|
||||||
src1: &CudaSlice<T>,
|
|
||||||
layout1: &Layout,
|
|
||||||
src2: &CudaSlice<T>,
|
|
||||||
layout2: &Layout,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<CudaSlice<T>>;
|
|
||||||
|
|
||||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
|
||||||
let out = match (s1, s2) {
|
|
||||||
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
|
||||||
};
|
|
||||||
Ok(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map2InPlace {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
||||||
&self,
|
|
||||||
dst: &mut CudaSlice<T>,
|
|
||||||
dst_shape: &Shape,
|
|
||||||
src: &CudaSlice<T>,
|
|
||||||
src_l: &Layout,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<()>;
|
|
||||||
|
|
||||||
fn map(
|
|
||||||
&self,
|
|
||||||
dst: &mut S,
|
|
||||||
dst_s: &Shape,
|
|
||||||
src: &S,
|
|
||||||
src_l: &Layout,
|
|
||||||
d: &CudaDevice,
|
|
||||||
) -> Result<()> {
|
|
||||||
match (dst, src) {
|
|
||||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map1Any {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
|
||||||
&self,
|
|
||||||
src: &CudaSlice<T>,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
layout: &Layout,
|
|
||||||
wrap: W,
|
|
||||||
) -> Result<S>;
|
|
||||||
|
|
||||||
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
|
||||||
let out = match s {
|
|
||||||
S::U8(s) => self.f(s, d, l, S::U8)?,
|
|
||||||
S::U32(s) => self.f(s, d, l, S::U32)?,
|
|
||||||
S::I64(s) => self.f(s, d, l, S::I64)?,
|
|
||||||
S::BF16(s) => self.f(s, d, l, S::BF16)?,
|
|
||||||
S::F16(s) => self.f(s, d, l, S::F16)?,
|
|
||||||
S::F32(s) => self.f(s, d, l, S::F32)?,
|
|
||||||
S::F64(s) => self.f(s, d, l, S::F64)?,
|
|
||||||
};
|
|
||||||
Ok(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Map2Any {
|
|
||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
||||||
&self,
|
|
||||||
src1: &CudaSlice<T>,
|
|
||||||
layout1: &Layout,
|
|
||||||
src2: &CudaSlice<T>,
|
|
||||||
layout2: &Layout,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<S>;
|
|
||||||
|
|
||||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
|
||||||
let out = match (s1, s2) {
|
|
||||||
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
|
|
||||||
};
|
|
||||||
Ok(out)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,377 +0,0 @@
|
|||||||
use crate::op::{BackpropOp, Op};
|
|
||||||
use crate::tensor::from_storage;
|
|
||||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Unary ops that can be defined in user-land.
|
|
||||||
pub trait CustomOp1 {
|
|
||||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_storage: &MetalStorage,
|
|
||||||
_layout: &Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
|
||||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
|
||||||
/// The function should return the gradient of the argument.
|
|
||||||
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
|
||||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait CustomOp2 {
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(
|
|
||||||
&self,
|
|
||||||
s1: &CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
s2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
) -> Result<(CpuStorage, Shape)>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(
|
|
||||||
&self,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(CudaStorage, Shape)> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bwd(
|
|
||||||
&self,
|
|
||||||
_arg1: &Tensor,
|
|
||||||
_arg2: &Tensor,
|
|
||||||
_res: &Tensor,
|
|
||||||
_grad_res: &Tensor,
|
|
||||||
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
|
||||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait CustomOp3 {
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(
|
|
||||||
&self,
|
|
||||||
s1: &CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
s2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
s3: &CpuStorage,
|
|
||||||
l3: &Layout,
|
|
||||||
) -> Result<(CpuStorage, Shape)>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(
|
|
||||||
&self,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(CudaStorage, Shape)> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bwd(
|
|
||||||
&self,
|
|
||||||
_arg1: &Tensor,
|
|
||||||
_arg2: &Tensor,
|
|
||||||
_arg3: &Tensor,
|
|
||||||
_res: &Tensor,
|
|
||||||
_grad_res: &Tensor,
|
|
||||||
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
|
|
||||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Tensor {
|
|
||||||
/// Applies a unary custom op without backward support
|
|
||||||
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a binary custom op without backward support
|
|
||||||
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
|
||||||
let (storage, shape) =
|
|
||||||
self.storage()
|
|
||||||
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a ternary custom op without backward support
|
|
||||||
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op3(
|
|
||||||
self.layout(),
|
|
||||||
&t2.storage(),
|
|
||||||
t2.layout(),
|
|
||||||
&t3.storage(),
|
|
||||||
t3.layout(),
|
|
||||||
c,
|
|
||||||
)?;
|
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a unary custom op.
|
|
||||||
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
|
||||||
let (storage, shape) = self
|
|
||||||
.storage()
|
|
||||||
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
|
||||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
|
||||||
self.apply_op1_arc(Arc::new(Box::new(c)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a binary custom op.
|
|
||||||
pub fn apply_op2_arc(
|
|
||||||
&self,
|
|
||||||
rhs: &Self,
|
|
||||||
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op2(
|
|
||||||
self.layout(),
|
|
||||||
&rhs.storage(),
|
|
||||||
rhs.layout(),
|
|
||||||
c.as_ref().as_ref(),
|
|
||||||
)?;
|
|
||||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
|
||||||
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a ternary custom op.
|
|
||||||
pub fn apply_op3_arc(
|
|
||||||
&self,
|
|
||||||
t2: &Self,
|
|
||||||
t3: &Self,
|
|
||||||
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let (storage, shape) = self.storage().apply_op3(
|
|
||||||
self.layout(),
|
|
||||||
&t2.storage(),
|
|
||||||
t2.layout(),
|
|
||||||
&t3.storage(),
|
|
||||||
t3.layout(),
|
|
||||||
c.as_ref().as_ref(),
|
|
||||||
)?;
|
|
||||||
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
|
||||||
Op::CustomOp3(t1, t2, t3, c.clone())
|
|
||||||
});
|
|
||||||
Ok(from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
|
||||||
&self,
|
|
||||||
t2: &Self,
|
|
||||||
t3: &Self,
|
|
||||||
c: C,
|
|
||||||
) -> Result<Self> {
|
|
||||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// In place ops.
|
|
||||||
|
|
||||||
/// Unary ops that can be defined in user-land.
|
|
||||||
/// These ops work in place and as such back-prop is unsupported.
|
|
||||||
pub trait InplaceOp1 {
|
|
||||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait InplaceOp2 {
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
|
|
||||||
-> Result<()>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_: &mut MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<()> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait InplaceOp3 {
|
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cpu_fwd(
|
|
||||||
&self,
|
|
||||||
s1: &mut CpuStorage,
|
|
||||||
l1: &Layout,
|
|
||||||
s2: &CpuStorage,
|
|
||||||
l2: &Layout,
|
|
||||||
s3: &CpuStorage,
|
|
||||||
l3: &Layout,
|
|
||||||
) -> Result<()>;
|
|
||||||
|
|
||||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn cuda_fwd(
|
|
||||||
&self,
|
|
||||||
_: &mut CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &CudaStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<()> {
|
|
||||||
Err(crate::Error::Cuda(
|
|
||||||
format!("no cuda implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
|
||||||
/// offsets etc so the associated layout should be used to access it.
|
|
||||||
fn metal_fwd(
|
|
||||||
&self,
|
|
||||||
_: &mut MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
_: &MetalStorage,
|
|
||||||
_: &Layout,
|
|
||||||
) -> Result<()> {
|
|
||||||
Err(crate::Error::Metal(
|
|
||||||
format!("no metal implementation for {}", self.name()).into(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Tensor {
|
|
||||||
/// Applies a unary custom op in place.
|
|
||||||
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
|
|
||||||
self.storage_mut().inplace_op1(self.layout(), c)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a binary custom op in place (for the first tensor).
|
|
||||||
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
|
|
||||||
self.storage_mut()
|
|
||||||
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Applies a ternary custom op in place (for the first tensor).
|
|
||||||
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
|
|
||||||
self.storage_mut().inplace_op3(
|
|
||||||
self.layout(),
|
|
||||||
&t2.storage(),
|
|
||||||
t2.layout(),
|
|
||||||
&t3.storage(),
|
|
||||||
t3.layout(),
|
|
||||||
c,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
@ -289,34 +289,17 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
|
||||||
match self {
|
|
||||||
Device::Cpu => {
|
|
||||||
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
|
|
||||||
Ok(Storage::Cpu(storage))
|
|
||||||
}
|
|
||||||
Device::Cuda(device) => {
|
|
||||||
let storage = device.alloc_uninit(shape, dtype)?;
|
|
||||||
Ok(Storage::Cuda(storage))
|
|
||||||
}
|
|
||||||
Device::Metal(device) => {
|
|
||||||
let storage = device.alloc_uninit(shape, dtype)?;
|
|
||||||
Ok(Storage::Metal(storage))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
let storage = array.to_cpu_storage();
|
let storage = array.to_cpu_storage();
|
||||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(device) => {
|
Device::Metal(device) => {
|
||||||
let storage = array.to_cpu_storage();
|
let storage = array.to_cpu_storage();
|
||||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -327,22 +310,14 @@ impl Device {
|
|||||||
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
let storage = S::to_cpu_storage_owned(data);
|
let storage = S::to_cpu_storage_owned(data);
|
||||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Cuda(storage))
|
Ok(Storage::Cuda(storage))
|
||||||
}
|
}
|
||||||
Device::Metal(device) => {
|
Device::Metal(device) => {
|
||||||
let storage = S::to_cpu_storage_owned(data);
|
let storage = S::to_cpu_storage_owned(data);
|
||||||
let storage = device.storage_from_cpu_storage_owned(storage)?;
|
let storage = device.storage_from_cpu_storage(&storage)?;
|
||||||
Ok(Storage::Metal(storage))
|
Ok(Storage::Metal(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn synchronize(&self) -> Result<()> {
|
|
||||||
match self {
|
|
||||||
Self::Cpu => Ok(()),
|
|
||||||
Self::Cuda(d) => d.synchronize(),
|
|
||||||
Self::Metal(d) => d.synchronize(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -65,13 +65,12 @@ impl std::fmt::Debug for Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Options for Tensor pretty printing
|
/// Options for Tensor pretty printing
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct PrinterOptions {
|
pub struct PrinterOptions {
|
||||||
pub precision: usize,
|
precision: usize,
|
||||||
pub threshold: usize,
|
threshold: usize,
|
||||||
pub edge_items: usize,
|
edge_items: usize,
|
||||||
pub line_width: usize,
|
line_width: usize,
|
||||||
pub sci_mode: Option<bool>,
|
sci_mode: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
|
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
|
||||||
@ -90,10 +89,6 @@ impl PrinterOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
|
|
||||||
&PRINT_OPTS
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_print_options(options: PrinterOptions) {
|
pub fn set_print_options(options: PrinterOptions) {
|
||||||
*PRINT_OPTS.lock().unwrap() = options
|
*PRINT_OPTS.lock().unwrap() = options
|
||||||
}
|
}
|
||||||
@ -122,26 +117,6 @@ pub fn set_print_options_full() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_line_width(line_width: usize) {
|
|
||||||
PRINT_OPTS.lock().unwrap().line_width = line_width
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_precision(precision: usize) {
|
|
||||||
PRINT_OPTS.lock().unwrap().precision = precision
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_edge_items(edge_items: usize) {
|
|
||||||
PRINT_OPTS.lock().unwrap().edge_items = edge_items
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_threshold(threshold: usize) {
|
|
||||||
PRINT_OPTS.lock().unwrap().threshold = threshold
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_sci_mode(sci_mode: Option<bool>) {
|
|
||||||
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
|
|
||||||
}
|
|
||||||
|
|
||||||
struct FmtSize {
|
struct FmtSize {
|
||||||
current_size: usize,
|
current_size: usize,
|
||||||
}
|
}
|
||||||
|
@ -23,15 +23,7 @@ pub enum DType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub struct DTypeParseError(String);
|
pub struct DTypeParseError;
|
||||||
|
|
||||||
impl std::fmt::Display for DTypeParseError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "cannot parse '{}' as a dtype", self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for DTypeParseError {}
|
|
||||||
|
|
||||||
impl std::str::FromStr for DType {
|
impl std::str::FromStr for DType {
|
||||||
type Err = DTypeParseError;
|
type Err = DTypeParseError;
|
||||||
@ -44,7 +36,7 @@ impl std::str::FromStr for DType {
|
|||||||
"f16" => Ok(Self::F16),
|
"f16" => Ok(Self::F16),
|
||||||
"f32" => Ok(Self::F32),
|
"f32" => Ok(Self::F32),
|
||||||
"f64" => Ok(Self::F64),
|
"f64" => Ok(Self::F64),
|
||||||
_ => Err(DTypeParseError(s.to_string())),
|
_ => Err(DTypeParseError),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -154,19 +154,6 @@ impl crate::backend::BackendStorage for CudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy2d(
|
|
||||||
&self,
|
|
||||||
_: &mut Self,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
) -> Result<()> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
@ -210,18 +197,10 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
@ -229,8 +208,4 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn synchronize(&self) -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -166,19 +166,6 @@ impl crate::backend::BackendStorage for MetalStorage {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy2d(
|
|
||||||
&self,
|
|
||||||
_: &mut Self,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
_: usize,
|
|
||||||
) -> Result<()> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
@ -222,18 +209,10 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
@ -241,8 +220,4 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn synchronize(&self) -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ impl Layout {
|
|||||||
self.shape.is_fortran_contiguous(&self.stride)
|
self.shape.is_fortran_contiguous(&self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||||
let dims = self.shape().dims();
|
let dims = self.shape().dims();
|
||||||
if dim >= dims.len() {
|
if dim >= dims.len() {
|
||||||
Err(Error::DimOutOfRange {
|
Err(Error::DimOutOfRange {
|
||||||
@ -99,7 +99,7 @@ impl Layout {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
||||||
let rank = self.shape.rank();
|
let rank = self.shape.rank();
|
||||||
if rank <= dim1 || rank <= dim2 {
|
if rank <= dim1 || rank <= dim2 {
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
@ -120,7 +120,7 @@ impl Layout {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||||
let is_permutation =
|
let is_permutation =
|
||||||
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
||||||
if !is_permutation {
|
if !is_permutation {
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ## Features
|
//! ## Features
|
||||||
//!
|
//!
|
||||||
//! - Simple syntax (looks and feels like PyTorch)
|
//! - Simple syntax (looks and like PyTorch)
|
||||||
//! - CPU and Cuda backends (and M1 support)
|
//! - CPU and Cuda backends (and M1 support)
|
||||||
//! - Enable serverless (CPU) small and fast deployments
|
//! - Enable serverless (CPU) small and fast deployments
|
||||||
//! - Model training
|
//! - Model training
|
||||||
@ -37,13 +37,14 @@
|
|||||||
mod accelerate;
|
mod accelerate;
|
||||||
pub mod backend;
|
pub mod backend;
|
||||||
pub mod backprop;
|
pub mod backprop;
|
||||||
pub mod conv;
|
mod conv;
|
||||||
mod convert;
|
mod convert;
|
||||||
pub mod cpu;
|
pub mod cpu;
|
||||||
pub mod cpu_backend;
|
pub mod cpu_backend;
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
pub mod cuda_backend;
|
pub mod cuda_backend;
|
||||||
mod custom_op;
|
#[cfg(feature = "cudnn")]
|
||||||
|
pub mod cudnn;
|
||||||
mod device;
|
mod device;
|
||||||
pub mod display;
|
pub mod display;
|
||||||
mod dtype;
|
mod dtype;
|
||||||
@ -57,7 +58,7 @@ pub mod metal_backend;
|
|||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
mod mkl;
|
mod mkl;
|
||||||
pub mod npy;
|
pub mod npy;
|
||||||
pub mod op;
|
mod op;
|
||||||
pub mod pickle;
|
pub mod pickle;
|
||||||
pub mod quantized;
|
pub mod quantized;
|
||||||
pub mod safetensors;
|
pub mod safetensors;
|
||||||
@ -66,21 +67,17 @@ pub mod shape;
|
|||||||
mod storage;
|
mod storage;
|
||||||
mod strided_index;
|
mod strided_index;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
mod tensor_cat;
|
|
||||||
pub mod test_utils;
|
pub mod test_utils;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
mod variable;
|
mod variable;
|
||||||
|
|
||||||
#[cfg(feature = "cudnn")]
|
|
||||||
pub use cuda_backend::cudnn;
|
|
||||||
|
|
||||||
pub use cpu_backend::CpuStorage;
|
pub use cpu_backend::CpuStorage;
|
||||||
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
|
||||||
pub use device::{Device, DeviceLocation, NdArray};
|
pub use device::{Device, DeviceLocation, NdArray};
|
||||||
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use indexer::IndexOp;
|
pub use indexer::IndexOp;
|
||||||
pub use layout::Layout;
|
pub use layout::Layout;
|
||||||
|
pub use op::{CustomOp1, CustomOp2, CustomOp3};
|
||||||
pub use shape::{Shape, D};
|
pub use shape::{Shape, D};
|
||||||
pub use storage::Storage;
|
pub use storage::Storage;
|
||||||
pub use strided_index::{StridedBlocks, StridedIndex};
|
pub use strided_index::{StridedBlocks, StridedIndex};
|
||||||
@ -132,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: Module> Module for Option<&M> {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
match self {
|
|
||||||
None => Ok(xs.clone()),
|
|
||||||
Some(m) => m.forward(xs),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A trait defining a module with forward method using a single tensor argument and a flag to
|
// A trait defining a module with forward method using a single tensor argument and a flag to
|
||||||
// separate the training and evaluation behaviors.
|
// separate the training and evaluation behaviors.
|
||||||
pub trait ModuleT {
|
pub trait ModuleT {
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,287 +0,0 @@
|
|||||||
use crate::{DType, Result};
|
|
||||||
use candle_metal_kernels::Kernels;
|
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::ffi::c_void;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
|
|
||||||
|
|
||||||
use super::MetalError;
|
|
||||||
|
|
||||||
/// Unique identifier for cuda devices.
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
||||||
pub struct DeviceId(usize);
|
|
||||||
|
|
||||||
impl DeviceId {
|
|
||||||
pub(crate) fn new() -> Self {
|
|
||||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
|
||||||
use std::sync::atomic;
|
|
||||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
|
||||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
|
||||||
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct MetalDevice {
|
|
||||||
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
|
|
||||||
/// the device itself.
|
|
||||||
pub(crate) id: DeviceId,
|
|
||||||
|
|
||||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
|
||||||
pub(crate) device: metal::Device,
|
|
||||||
|
|
||||||
/// Single command queue for the entire device.
|
|
||||||
pub(crate) command_queue: CommandQueue,
|
|
||||||
/// One command buffer at a time.
|
|
||||||
/// The scheduler works by allowing multiple
|
|
||||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
|
||||||
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
|
||||||
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
|
||||||
/// to start to work).
|
|
||||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
|
||||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
|
||||||
/// command buffer2 starts (or there are metal bugs there)
|
|
||||||
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
|
|
||||||
/// Keeps track of the current amount of compute command encoders on the current
|
|
||||||
/// command buffer
|
|
||||||
/// Arc, RwLock because of the interior mutability.
|
|
||||||
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
|
|
||||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
|
||||||
pub(crate) compute_per_buffer: usize,
|
|
||||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
|
||||||
/// Heavily used by [`candle_metal_kernels`]
|
|
||||||
pub(crate) kernels: Arc<Kernels>,
|
|
||||||
/// Simple allocator struct.
|
|
||||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
|
||||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
|
||||||
/// (could be linked to FFI communication overhead).
|
|
||||||
///
|
|
||||||
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
|
|
||||||
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
|
|
||||||
/// to be reused. However, in order for this to work, we need to guarantee the order of
|
|
||||||
/// operation, so that this buffer is not being used by another kernel at the same time.
|
|
||||||
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
|
||||||
///
|
|
||||||
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
|
||||||
/// (strong_count = 1).
|
|
||||||
pub(crate) buffers: AllocatedBuffers,
|
|
||||||
/// Seed for random number generation.
|
|
||||||
pub(crate) seed: Arc<Mutex<Buffer>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for MetalDevice {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "MetalDevice({:?})", self.id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::ops::Deref for MetalDevice {
|
|
||||||
type Target = metal::DeviceRef;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MetalDevice {
|
|
||||||
pub fn id(&self) -> DeviceId {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn metal_device(&self) -> &metal::Device {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn command_queue(&self) -> &CommandQueue {
|
|
||||||
&self.command_queue
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
|
||||||
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
|
|
||||||
let mut command_buffer = command_buffer_lock.to_owned();
|
|
||||||
let mut index = self
|
|
||||||
.command_buffer_index
|
|
||||||
.try_write()
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
if *index > self.compute_per_buffer {
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
|
||||||
*command_buffer_lock = command_buffer.clone();
|
|
||||||
*index = 0;
|
|
||||||
|
|
||||||
self.drop_unused_buffers()?;
|
|
||||||
}
|
|
||||||
*index += 1;
|
|
||||||
Ok(command_buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn wait_until_completed(&self) -> Result<()> {
|
|
||||||
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
|
|
||||||
match command_buffer.status() {
|
|
||||||
metal::MTLCommandBufferStatus::Committed
|
|
||||||
| metal::MTLCommandBufferStatus::Scheduled
|
|
||||||
| metal::MTLCommandBufferStatus::Completed => {
|
|
||||||
panic!("Already committed");
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
command_buffer.commit();
|
|
||||||
command_buffer.wait_until_completed();
|
|
||||||
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn kernels(&self) -> &Kernels {
|
|
||||||
&self.kernels
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn device(&self) -> &metal::Device {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new buffer (not necessarily zeroed).
|
|
||||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
|
||||||
/// This means the buffer data cannot be read on the CPU directly.
|
|
||||||
///
|
|
||||||
/// [`name`] is only used to keep track of the resource origin in case of bugs
|
|
||||||
pub fn new_buffer(
|
|
||||||
&self,
|
|
||||||
element_count: usize,
|
|
||||||
dtype: DType,
|
|
||||||
name: &str,
|
|
||||||
) -> Result<Arc<Buffer>> {
|
|
||||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
|
||||||
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new buffer (not necessarily zeroed).
|
|
||||||
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
|
||||||
/// This means the buffer can be read on the CPU but will require manual
|
|
||||||
/// synchronization when the CPU memory is modified
|
|
||||||
/// Used as a bridge to gather data back from the GPU
|
|
||||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
|
|
||||||
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new buffer from data.
|
|
||||||
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
|
||||||
///
|
|
||||||
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
|
|
||||||
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
|
|
||||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
|
||||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
|
||||||
let new_buffer = self.device.new_buffer_with_data(
|
|
||||||
data.as_ptr() as *const c_void,
|
|
||||||
size,
|
|
||||||
MTLResourceOptions::StorageModeManaged,
|
|
||||||
);
|
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
|
||||||
let subbuffers = buffers
|
|
||||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
|
||||||
.or_insert(vec![]);
|
|
||||||
|
|
||||||
let new_buffer = Arc::new(new_buffer);
|
|
||||||
subbuffers.push(new_buffer.clone());
|
|
||||||
Ok(new_buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
|
||||||
let buffer = self.allocate_buffer(
|
|
||||||
size_in_bytes as NSUInteger,
|
|
||||||
MTLResourceOptions::StorageModePrivate,
|
|
||||||
"allocate_zeros",
|
|
||||||
)?;
|
|
||||||
let command_buffer = self.command_buffer()?;
|
|
||||||
command_buffer.set_label("zeros");
|
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
|
||||||
blit.fill_buffer(
|
|
||||||
&buffer,
|
|
||||||
metal::NSRange {
|
|
||||||
location: 0,
|
|
||||||
length: buffer.length(),
|
|
||||||
},
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
blit.end_encoding();
|
|
||||||
Ok(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn find_available_buffer(
|
|
||||||
&self,
|
|
||||||
size: NSUInteger,
|
|
||||||
option: MTLResourceOptions,
|
|
||||||
buffers: &RwLockWriteGuard<BufferMap>,
|
|
||||||
) -> Option<Arc<Buffer>> {
|
|
||||||
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
|
||||||
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
|
||||||
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
|
||||||
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
|
||||||
for sub in subbuffers {
|
|
||||||
if Arc::strong_count(sub) == 1 {
|
|
||||||
best_buffer = Some(sub);
|
|
||||||
best_buffer_size = *buffer_size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
best_buffer.cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn drop_unused_buffers(&self) -> Result<()> {
|
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
|
||||||
for subbuffers in buffers.values_mut() {
|
|
||||||
let newbuffers = subbuffers
|
|
||||||
.iter()
|
|
||||||
.filter(|s| Arc::strong_count(*s) > 1)
|
|
||||||
.map(Arc::clone)
|
|
||||||
.collect();
|
|
||||||
*subbuffers = newbuffers;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The critical allocator algorithm
|
|
||||||
fn allocate_buffer(
|
|
||||||
&self,
|
|
||||||
size: NSUInteger,
|
|
||||||
option: MTLResourceOptions,
|
|
||||||
_name: &str,
|
|
||||||
) -> Result<Arc<Buffer>> {
|
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
|
||||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
|
||||||
// Cloning also ensures we increment the strong count
|
|
||||||
return Ok(b.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let size = buf_size(size);
|
|
||||||
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
|
||||||
|
|
||||||
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
|
||||||
let new_buffer = Arc::new(new_buffer);
|
|
||||||
subbuffers.push(new_buffer.clone());
|
|
||||||
|
|
||||||
Ok(new_buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a metal GPU capture trace on [`path`].
|
|
||||||
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
|
||||||
let capture = metal::CaptureManager::shared();
|
|
||||||
let descriptor = metal::CaptureDescriptor::new();
|
|
||||||
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
|
||||||
descriptor.set_capture_device(self);
|
|
||||||
descriptor.set_output_url(path);
|
|
||||||
|
|
||||||
capture
|
|
||||||
.start_capture(&descriptor)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn buf_size(size: NSUInteger) -> NSUInteger {
|
|
||||||
size.saturating_sub(1).next_power_of_two() as NSUInteger
|
|
||||||
}
|
|
@ -333,16 +333,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
|||||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
|
||||||
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
|
||||||
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||||
@ -365,28 +355,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = -v
|
|
||||||
}
|
|
||||||
vs_exp_inplace(ys);
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = v / (1.0 + *y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = -v
|
|
||||||
}
|
|
||||||
vd_exp_inplace(ys);
|
|
||||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
|
||||||
*y = v / (1.0 + *y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! binary_op {
|
macro_rules! binary_op {
|
||||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::Tensor;
|
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use num_traits::float::Float;
|
use num_traits::float::Float;
|
||||||
|
|
||||||
@ -61,12 +61,10 @@ pub enum UnaryOp {
|
|||||||
GeluErf,
|
GeluErf,
|
||||||
Erf,
|
Erf,
|
||||||
Relu,
|
Relu,
|
||||||
Silu,
|
|
||||||
Tanh,
|
Tanh,
|
||||||
Floor,
|
Floor,
|
||||||
Ceil,
|
Ceil,
|
||||||
Round,
|
Round,
|
||||||
Sign,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -133,10 +131,7 @@ pub enum Op {
|
|||||||
stride: (usize, usize),
|
stride: (usize, usize),
|
||||||
},
|
},
|
||||||
|
|
||||||
UpsampleNearest1D {
|
UpsampleNearest1D(Tensor),
|
||||||
arg: Tensor,
|
|
||||||
target_size: usize,
|
|
||||||
},
|
|
||||||
UpsampleNearest2D {
|
UpsampleNearest2D {
|
||||||
arg: Tensor,
|
arg: Tensor,
|
||||||
target_h: usize,
|
target_h: usize,
|
||||||
@ -162,23 +157,168 @@ pub enum Op {
|
|||||||
Permute(Tensor, Vec<usize>),
|
Permute(Tensor, Vec<usize>),
|
||||||
Elu(Tensor, f64),
|
Elu(Tensor, f64),
|
||||||
Powf(Tensor, f64),
|
Powf(Tensor, f64),
|
||||||
CustomOp1(
|
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
|
||||||
Tensor,
|
|
||||||
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
|
|
||||||
),
|
|
||||||
CustomOp2(
|
CustomOp2(
|
||||||
Tensor,
|
Tensor,
|
||||||
Tensor,
|
Tensor,
|
||||||
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
|
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||||
),
|
),
|
||||||
CustomOp3(
|
CustomOp3(
|
||||||
Tensor,
|
Tensor,
|
||||||
Tensor,
|
Tensor,
|
||||||
Tensor,
|
Tensor,
|
||||||
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
|
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Unary ops that can be defined in user-land.
|
||||||
|
pub trait CustomOp1 {
|
||||||
|
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_storage: &MetalStorage,
|
||||||
|
_layout: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||||
|
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||||
|
/// The function should return the gradient of the argument.
|
||||||
|
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||||
|
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CustomOp2 {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
) -> Result<(CpuStorage, Shape)>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(CudaStorage, Shape)> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bwd(
|
||||||
|
&self,
|
||||||
|
_arg1: &Tensor,
|
||||||
|
_arg2: &Tensor,
|
||||||
|
_res: &Tensor,
|
||||||
|
_grad_res: &Tensor,
|
||||||
|
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
||||||
|
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CustomOp3 {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &CpuStorage,
|
||||||
|
l3: &Layout,
|
||||||
|
) -> Result<(CpuStorage, Shape)>;
|
||||||
|
|
||||||
|
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &CudaStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(CudaStorage, Shape)> {
|
||||||
|
Err(crate::Error::Cuda(
|
||||||
|
format!("no cuda implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||||
|
/// offsets etc so the associated layout should be used to access it.
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
_: &MetalStorage,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<(MetalStorage, Shape)> {
|
||||||
|
Err(crate::Error::Metal(
|
||||||
|
format!("no metal implementation for {}", self.name()).into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bwd(
|
||||||
|
&self,
|
||||||
|
_arg1: &Tensor,
|
||||||
|
_arg2: &Tensor,
|
||||||
|
_arg3: &Tensor,
|
||||||
|
_res: &Tensor,
|
||||||
|
_grad_res: &Tensor,
|
||||||
|
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
|
||||||
|
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait UnaryOpT {
|
pub trait UnaryOpT {
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
const KERNEL: &'static str;
|
const KERNEL: &'static str;
|
||||||
@ -250,12 +390,10 @@ pub(crate) struct Gelu;
|
|||||||
pub(crate) struct GeluErf;
|
pub(crate) struct GeluErf;
|
||||||
pub(crate) struct Erf;
|
pub(crate) struct Erf;
|
||||||
pub(crate) struct Relu;
|
pub(crate) struct Relu;
|
||||||
pub(crate) struct Silu;
|
|
||||||
pub(crate) struct Tanh;
|
pub(crate) struct Tanh;
|
||||||
pub(crate) struct Floor;
|
pub(crate) struct Floor;
|
||||||
pub(crate) struct Ceil;
|
pub(crate) struct Ceil;
|
||||||
pub(crate) struct Round;
|
pub(crate) struct Round;
|
||||||
pub(crate) struct Sign;
|
|
||||||
|
|
||||||
macro_rules! bin_op {
|
macro_rules! bin_op {
|
||||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||||
@ -459,13 +597,6 @@ unary_op!(Recip, "recip", v, v.recip());
|
|||||||
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
|
||||||
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
|
||||||
|
|
||||||
// Hardcode the value for sqrt(2/pi)
|
|
||||||
// https://github.com/huggingface/candle/issues/1982
|
|
||||||
#[allow(clippy::excessive_precision)]
|
|
||||||
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
|
|
||||||
#[allow(clippy::excessive_precision)]
|
|
||||||
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
|
|
||||||
|
|
||||||
/// Tanh based approximation of the `gelu` operation
|
/// Tanh based approximation of the `gelu` operation
|
||||||
/// GeluErf is the more precise one.
|
/// GeluErf is the more precise one.
|
||||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||||
@ -478,7 +609,7 @@ impl UnaryOpT for Gelu {
|
|||||||
* v
|
* v
|
||||||
* (bf16::ONE
|
* (bf16::ONE
|
||||||
+ bf16::tanh(
|
+ bf16::tanh(
|
||||||
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
|
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
|
||||||
* v
|
* v
|
||||||
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||||
))
|
))
|
||||||
@ -489,18 +620,22 @@ impl UnaryOpT for Gelu {
|
|||||||
* v
|
* v
|
||||||
* (f16::ONE
|
* (f16::ONE
|
||||||
+ f16::tanh(
|
+ f16::tanh(
|
||||||
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
|
(f16::from_f32_const(2.0) / f16::PI).sqrt()
|
||||||
* v
|
* v
|
||||||
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn f32(v: f32) -> f32 {
|
fn f32(v: f32) -> f32 {
|
||||||
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
|
0.5 * v
|
||||||
|
* (1.0
|
||||||
|
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||||
}
|
}
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn f64(v: f64) -> f64 {
|
fn f64(v: f64) -> f64 {
|
||||||
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
|
0.5 * v
|
||||||
|
* (1.0
|
||||||
|
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||||
}
|
}
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn u8(_: u8) -> u8 {
|
fn u8(_: u8) -> u8 {
|
||||||
@ -589,77 +724,6 @@ impl UnaryOpT for Erf {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Silu operation
|
|
||||||
impl UnaryOpT for Silu {
|
|
||||||
const NAME: &'static str = "silu";
|
|
||||||
const V: Self = Silu;
|
|
||||||
#[inline(always)]
|
|
||||||
fn bf16(v: bf16) -> bf16 {
|
|
||||||
v / (bf16::ONE + (-v).exp())
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn f16(v: f16) -> f16 {
|
|
||||||
v / (f16::ONE + (-v).exp())
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn f32(v: f32) -> f32 {
|
|
||||||
v / (1.0 + (-v).exp())
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn f64(v: f64) -> f64 {
|
|
||||||
v / (1.0 + (-v).exp())
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn u8(_: u8) -> u8 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn u32(_: u32) -> u32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn i64(_: i64) -> i64 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
const KERNEL: &'static str = "usilu";
|
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
const F32_VEC: bool = true;
|
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
#[inline(always)]
|
|
||||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
|
||||||
crate::mkl::vs_silu(xs, ys)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
const F64_VEC: bool = true;
|
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
#[inline(always)]
|
|
||||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
|
||||||
crate::mkl::vd_silu(xs, ys)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
const F32_VEC: bool = true;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
#[inline(always)]
|
|
||||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
|
||||||
crate::accelerate::vs_silu(xs, ys)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
const F64_VEC: bool = true;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
#[inline(always)]
|
|
||||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
|
||||||
crate::accelerate::vd_silu(xs, ys)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UnaryOpT for Abs {
|
impl UnaryOpT for Abs {
|
||||||
const NAME: &'static str = "abs";
|
const NAME: &'static str = "abs";
|
||||||
const KERNEL: &'static str = "uabs";
|
const KERNEL: &'static str = "uabs";
|
||||||
@ -927,37 +991,3 @@ impl std::ops::Deref for BackpropOp {
|
|||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UnaryOpT for Sign {
|
|
||||||
const NAME: &'static str = "sign";
|
|
||||||
const KERNEL: &'static str = "usign";
|
|
||||||
const V: Self = Sign;
|
|
||||||
#[inline(always)]
|
|
||||||
fn bf16(v: bf16) -> bf16 {
|
|
||||||
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn f16(v: f16) -> f16 {
|
|
||||||
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn f32(v: f32) -> f32 {
|
|
||||||
f32::from(v > 0.) - f32::from(v < 0.)
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn f64(v: f64) -> f64 {
|
|
||||||
f64::from(v > 0.) - f64::from(v < 0.)
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn u8(v: u8) -> u8 {
|
|
||||||
u8::min(1, v)
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn u32(v: u32) -> u32 {
|
|
||||||
u32::min(1, v)
|
|
||||||
}
|
|
||||||
#[inline(always)]
|
|
||||||
fn i64(v: i64) -> i64 {
|
|
||||||
(v > 0) as i64 - (v < 0) as i64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -42,7 +42,7 @@ pub enum OpCode {
|
|||||||
Stop = b'.',
|
Stop = b'.',
|
||||||
NewObj = 0x81,
|
NewObj = 0x81,
|
||||||
EmptyList = b']',
|
EmptyList = b']',
|
||||||
BinFloat = b'G',
|
BinFloat = b'g',
|
||||||
Append = b'a',
|
Append = b'a',
|
||||||
Appends = b'e',
|
Appends = b'e',
|
||||||
}
|
}
|
||||||
@ -217,13 +217,6 @@ impl Object {
|
|||||||
let args = args.remove(1);
|
let args = args.remove(1);
|
||||||
(callable, args)
|
(callable, args)
|
||||||
}
|
}
|
||||||
Object::Class {
|
|
||||||
module_name,
|
|
||||||
class_name,
|
|
||||||
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
|
|
||||||
let mut args = args.tuple()?;
|
|
||||||
args.remove(0).reduce()?
|
|
||||||
}
|
|
||||||
_ => (callable, args),
|
_ => (callable, args),
|
||||||
};
|
};
|
||||||
match callable {
|
match callable {
|
||||||
@ -234,11 +227,13 @@ impl Object {
|
|||||||
_ => return Ok(None),
|
_ => return Ok(None),
|
||||||
};
|
};
|
||||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||||
|
let mut path = dir_name.to_path_buf();
|
||||||
|
path.push(file_path);
|
||||||
Ok(Some(TensorInfo {
|
Ok(Some(TensorInfo {
|
||||||
name,
|
name,
|
||||||
dtype,
|
dtype,
|
||||||
layout,
|
layout,
|
||||||
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
|
path: path.to_string_lossy().into_owned(),
|
||||||
storage_size,
|
storage_size,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@ -350,10 +345,8 @@ impl Stack {
|
|||||||
module_name,
|
module_name,
|
||||||
class_name,
|
class_name,
|
||||||
} => {
|
} => {
|
||||||
if module_name == "collections"
|
if module_name == "collections" && class_name == "OrderedDict" {
|
||||||
&& (class_name == "OrderedDict" || class_name == "defaultdict")
|
// TODO: have a separate ordered dict.
|
||||||
{
|
|
||||||
// TODO: have a separate ordered dict and a separate default dict.
|
|
||||||
Some(Object::Dict(vec![]))
|
Some(Object::Dict(vec![]))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -462,10 +455,7 @@ impl Stack {
|
|||||||
self.push(Object::Int(arg))
|
self.push(Object::Int(arg))
|
||||||
}
|
}
|
||||||
OpCode::BinFloat => {
|
OpCode::BinFloat => {
|
||||||
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
|
let arg = r.read_f64::<LittleEndian>()?;
|
||||||
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
|
|
||||||
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
|
|
||||||
let arg = r.read_f64::<byteorder::BigEndian>()?;
|
|
||||||
self.push(Object::Float(arg))
|
self.push(Object::Float(arg))
|
||||||
}
|
}
|
||||||
OpCode::BinUnicode => {
|
OpCode::BinUnicode => {
|
||||||
@ -637,16 +627,9 @@ pub struct TensorInfo {
|
|||||||
pub storage_size: usize,
|
pub storage_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read the tensor info from a .pth file.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `file` - The path to the .pth file.
|
|
||||||
/// * `verbose` - Whether to print debug information.
|
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
|
|
||||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||||
file: P,
|
file: P,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
key: Option<&str>,
|
|
||||||
) -> Result<Vec<TensorInfo>> {
|
) -> Result<Vec<TensorInfo>> {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let zip_reader = std::io::BufReader::new(file);
|
let zip_reader = std::io::BufReader::new(file);
|
||||||
@ -668,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
stack.read_loop(&mut reader)?;
|
stack.read_loop(&mut reader)?;
|
||||||
let obj = stack.finalize()?;
|
let obj = stack.finalize()?;
|
||||||
if VERBOSE || verbose {
|
if VERBOSE || verbose {
|
||||||
println!("{obj:#?}");
|
println!("{obj:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
let obj = match obj {
|
let obj = match obj {
|
||||||
Object::Build { callable, args } => match *callable {
|
Object::Build { callable, args } => match *callable {
|
||||||
Object::Reduce { callable, args: _ } => match *callable {
|
Object::Reduce { callable, args: _ } => match *callable {
|
||||||
@ -684,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
},
|
},
|
||||||
obj => obj,
|
obj => obj,
|
||||||
};
|
};
|
||||||
|
|
||||||
// If key is provided, then we need to extract the state_dict from the object.
|
|
||||||
let obj = if let Some(key) = key {
|
|
||||||
if let Object::Dict(key_values) = obj {
|
|
||||||
key_values
|
|
||||||
.into_iter()
|
|
||||||
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
|
|
||||||
.map(|(_, v)| v)
|
|
||||||
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
|
|
||||||
} else {
|
|
||||||
obj
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
obj
|
|
||||||
};
|
|
||||||
|
|
||||||
// If the object is a dict, then we can extract the tensor info from it.
|
|
||||||
// NOTE: We are assuming that the `obj` is state_dict by this stage.
|
|
||||||
if let Object::Dict(key_values) = obj {
|
if let Object::Dict(key_values) = obj {
|
||||||
for (name, value) in key_values.into_iter() {
|
for (name, value) in key_values.into_iter() {
|
||||||
match value.into_tensor_info(name, &dir_name) {
|
match value.into_tensor_info(name, &dir_name) {
|
||||||
@ -724,8 +688,8 @@ pub struct PthTensors {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PthTensors {
|
impl PthTensors {
|
||||||
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
|
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
|
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
||||||
let tensor_infos = tensor_infos
|
let tensor_infos = tensor_infos
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|ti| (ti.name.to_string(), ti))
|
.map(|ti| (ti.name.to_string(), ti))
|
||||||
@ -748,12 +712,10 @@ impl PthTensors {
|
|||||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||||
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
|
|
||||||
let rank = tensor_info.layout.shape().rank();
|
|
||||||
|
|
||||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||||
// case and when the tensor is fortran contiguous.
|
// case.
|
||||||
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
|
if !tensor_info.layout.is_contiguous() {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"cannot retrieve non-contiguous tensors {:?}",
|
"cannot retrieve non-contiguous tensors {:?}",
|
||||||
tensor_info.layout
|
tensor_info.layout
|
||||||
@ -771,33 +733,13 @@ impl PthTensors {
|
|||||||
tensor_info.dtype,
|
tensor_info.dtype,
|
||||||
&mut reader,
|
&mut reader,
|
||||||
)?;
|
)?;
|
||||||
|
Ok(Some(tensor))
|
||||||
if rank > 1 && is_fortran_contiguous {
|
|
||||||
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
|
||||||
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
|
||||||
let tensor = tensor.reshape(shape_reversed)?;
|
|
||||||
|
|
||||||
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
|
||||||
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
|
||||||
let tensor = tensor.permute(dim_indeces_reversed)?;
|
|
||||||
Ok(Some(tensor))
|
|
||||||
} else {
|
|
||||||
Ok(Some(tensor))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read all the tensors from a PyTorch pth file with a given key.
|
/// Read all the tensors from a PyTorch pth file.
|
||||||
///
|
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||||
/// # Arguments
|
let pth = PthTensors::new(path)?;
|
||||||
/// * `path` - Path to the pth file.
|
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
|
||||||
path: P,
|
|
||||||
key: Option<&str>,
|
|
||||||
) -> Result<Vec<(String, Tensor)>> {
|
|
||||||
let pth = PthTensors::new(path, key)?;
|
|
||||||
let tensor_names = pth.tensor_infos.keys();
|
let tensor_names = pth.tensor_infos.keys();
|
||||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||||
for name in tensor_names {
|
for name in tensor_names {
|
||||||
@ -807,11 +749,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
|||||||
}
|
}
|
||||||
Ok(tensors)
|
Ok(tensors)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read all the tensors from a PyTorch pth file.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `path` - Path to the pth file.
|
|
||||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
|
||||||
read_all_with_key(path, None)
|
|
||||||
}
|
|
||||||
|
@ -1,618 +0,0 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
|
||||||
use crate::quantized::k_quants::GgmlType;
|
|
||||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
|
||||||
use crate::{CudaDevice, CudaStorage, Result};
|
|
||||||
|
|
||||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct QCudaStorage {
|
|
||||||
data: CudaSlice<u8>,
|
|
||||||
dtype: GgmlDType,
|
|
||||||
device: CudaDevice,
|
|
||||||
}
|
|
||||||
|
|
||||||
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
|
|
||||||
|
|
||||||
pub fn set_force_dmmv(f: bool) {
|
|
||||||
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const WARP_SIZE: usize = 32;
|
|
||||||
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
|
|
||||||
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
|
|
||||||
pub const NWARPS_Q4_0_AMPERE: usize = 4;
|
|
||||||
pub const GGML_CUDA_MMV_X: usize = 32;
|
|
||||||
pub const GGML_CUDA_MMV_Y: usize = 1;
|
|
||||||
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
|
|
||||||
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
|
||||||
pub const MATRIX_ROW_PADDING: usize = 512;
|
|
||||||
|
|
||||||
fn ceil_div(p: usize, q: usize) -> usize {
|
|
||||||
(p + q - 1) / q
|
|
||||||
}
|
|
||||||
|
|
||||||
fn pad(p: usize, q: usize) -> usize {
|
|
||||||
ceil_div(p, q) * q
|
|
||||||
}
|
|
||||||
|
|
||||||
fn quantize_q8_1(
|
|
||||||
src: &CudaView<f32>,
|
|
||||||
dst: &mut CudaSlice<u8>,
|
|
||||||
elem_count: usize,
|
|
||||||
ky: usize,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<()> {
|
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let kx = elem_count;
|
|
||||||
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
|
||||||
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
|
||||||
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
|
||||||
grid_dim: (num_blocks as u32, ky as u32, 1),
|
|
||||||
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
|
||||||
shared_mem_bytes: 0,
|
|
||||||
};
|
|
||||||
let params = (src, dst, kx as i32, kx_padded as i32);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dequantize(
|
|
||||||
data: &CudaSlice<u8>,
|
|
||||||
dtype: GgmlDType,
|
|
||||||
elem_count: usize,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<CudaStorage> {
|
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let nb = (elem_count + 255) / 256;
|
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
|
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
|
|
||||||
GgmlDType::Q5_0 => (
|
|
||||||
"dequantize_block_q5_0",
|
|
||||||
false,
|
|
||||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
|
||||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
|
||||||
),
|
|
||||||
GgmlDType::Q5_1 => (
|
|
||||||
"dequantize_block_q5_1",
|
|
||||||
false,
|
|
||||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
|
||||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
|
||||||
),
|
|
||||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
|
|
||||||
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
|
|
||||||
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
|
|
||||||
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
|
|
||||||
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
|
|
||||||
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
|
|
||||||
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
|
|
||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
|
||||||
};
|
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
|
||||||
// See e.g.
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
|
||||||
grid_dim: (num_blocks as u32, 1, 1),
|
|
||||||
block_dim: (block_dim as u32, 1, 1),
|
|
||||||
shared_mem_bytes: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
if is_k {
|
|
||||||
let params = (data, &dst);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
} else {
|
|
||||||
let nb32 = match dtype {
|
|
||||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
|
||||||
_ => elem_count / 32,
|
|
||||||
};
|
|
||||||
let params = (data, &dst, nb32 as i32);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
}
|
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dequantize_mul_mat_vec(
|
|
||||||
data: &CudaSlice<u8>,
|
|
||||||
y: &CudaView<f32>,
|
|
||||||
dtype: GgmlDType,
|
|
||||||
ncols: usize,
|
|
||||||
nrows: usize,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<CudaStorage> {
|
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
|
||||||
if data_elems < ncols * nrows {
|
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
|
||||||
}
|
|
||||||
if y.len() != ncols {
|
|
||||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
|
||||||
}
|
|
||||||
let kernel_name = match dtype {
|
|
||||||
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
|
|
||||||
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
|
|
||||||
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
|
|
||||||
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
|
|
||||||
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
|
|
||||||
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
|
|
||||||
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
|
|
||||||
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
|
|
||||||
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
|
|
||||||
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
|
||||||
};
|
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
|
||||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
|
||||||
grid_dim: (block_num_y as u32, 1, 1),
|
|
||||||
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
|
|
||||||
shared_mem_bytes: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let params = (data, y, &dst, ncols as i32, nrows as i32);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mul_mat_vec_via_q8_1(
|
|
||||||
data: &CudaSlice<u8>,
|
|
||||||
y: &CudaView<f32>,
|
|
||||||
dtype: GgmlDType,
|
|
||||||
ncols: usize,
|
|
||||||
nrows: usize,
|
|
||||||
b_size: usize,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<CudaStorage> {
|
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
|
||||||
if data_elems < ncols * nrows {
|
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
|
||||||
}
|
|
||||||
if y.len() != ncols * b_size {
|
|
||||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
|
||||||
}
|
|
||||||
if b_size == 0 || b_size > 8 {
|
|
||||||
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
|
|
||||||
}
|
|
||||||
// Start by quantizing y
|
|
||||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
|
||||||
let y_size_in_bytes =
|
|
||||||
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
|
||||||
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
|
||||||
|
|
||||||
let kernel_name = match dtype {
|
|
||||||
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
|
|
||||||
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
|
|
||||||
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
|
|
||||||
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
|
|
||||||
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
|
|
||||||
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
|
|
||||||
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
|
|
||||||
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
|
|
||||||
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
|
|
||||||
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
|
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
|
||||||
};
|
|
||||||
let kernel_name = format!("{kernel_name}{b_size}");
|
|
||||||
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
|
||||||
let (nblocks, nwarps) = match b_size {
|
|
||||||
1 => (nrows as u32, 4),
|
|
||||||
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
|
||||||
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
|
||||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
|
||||||
};
|
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
|
||||||
grid_dim: (nblocks, 1, 1),
|
|
||||||
block_dim: (WARP_SIZE as u32, nwarps, 1),
|
|
||||||
shared_mem_bytes: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let params = (
|
|
||||||
data,
|
|
||||||
&y_q8_1,
|
|
||||||
&dst,
|
|
||||||
/* ncols_x */ ncols as i32,
|
|
||||||
/* nrows_x */ nrows as i32,
|
|
||||||
/* nrows_y */ ncols_padded as i32,
|
|
||||||
/* nrows_dst */ nrows as i32,
|
|
||||||
);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn mul_mat_via_q8_1(
|
|
||||||
data: &CudaSlice<u8>,
|
|
||||||
y: &CudaView<f32>,
|
|
||||||
dtype: GgmlDType,
|
|
||||||
x_rows: usize,
|
|
||||||
x_cols: usize,
|
|
||||||
y_rows: usize,
|
|
||||||
y_cols: usize,
|
|
||||||
dev: &CudaDevice,
|
|
||||||
) -> Result<CudaStorage> {
|
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
|
|
||||||
if data_elems < x_rows * x_cols {
|
|
||||||
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
|
||||||
}
|
|
||||||
if y.len() != y_rows * y_cols {
|
|
||||||
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
|
|
||||||
}
|
|
||||||
if x_cols != y_rows {
|
|
||||||
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
|
|
||||||
}
|
|
||||||
let k = x_cols;
|
|
||||||
// Start by quantizing y
|
|
||||||
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
|
||||||
let y_size_in_bytes =
|
|
||||||
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
|
||||||
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
|
||||||
|
|
||||||
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
|
||||||
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
|
|
||||||
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
|
|
||||||
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
|
|
||||||
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
|
|
||||||
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
|
|
||||||
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
|
|
||||||
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
|
|
||||||
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
|
|
||||||
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
|
|
||||||
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
|
||||||
};
|
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
|
||||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
|
||||||
grid_dim: (
|
|
||||||
ceil_div(x_rows, mmq_y) as u32,
|
|
||||||
ceil_div(y_cols, mmq_x) as u32,
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
block_dim: (WARP_SIZE as u32, 4, 1),
|
|
||||||
shared_mem_bytes: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let params = (
|
|
||||||
/* vx */ data,
|
|
||||||
/* vy */ &y_q8_1,
|
|
||||||
/* dst */ &dst,
|
|
||||||
/* ncols_x */ x_cols as i32,
|
|
||||||
/* nrows_x */ x_rows as i32,
|
|
||||||
/* ncols_y */ y_cols as i32,
|
|
||||||
/* nrows_y */ k_padded as i32,
|
|
||||||
/* nrows_dst */ x_rows as i32,
|
|
||||||
);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QCudaStorage {
|
|
||||||
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
|
|
||||||
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
|
||||||
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
|
|
||||||
Ok(QCudaStorage {
|
|
||||||
data,
|
|
||||||
device: device.clone(),
|
|
||||||
dtype,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
|
||||||
self.dtype
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn device(&self) -> &CudaDevice {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
|
|
||||||
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
|
|
||||||
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
|
|
||||||
let vec = slice.to_vec();
|
|
||||||
T::to_float(&vec, dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
let fast_kernel = matches!(
|
|
||||||
self.dtype,
|
|
||||||
GgmlDType::Q4_0
|
|
||||||
| GgmlDType::Q4_1
|
|
||||||
| GgmlDType::Q5_0
|
|
||||||
| GgmlDType::Q5_1
|
|
||||||
| GgmlDType::Q8_0
|
|
||||||
| GgmlDType::Q2K
|
|
||||||
| GgmlDType::Q3K
|
|
||||||
| GgmlDType::Q4K
|
|
||||||
| GgmlDType::Q5K
|
|
||||||
| GgmlDType::Q6K
|
|
||||||
| GgmlDType::Q8K
|
|
||||||
);
|
|
||||||
if fast_kernel {
|
|
||||||
return dequantize(&self.data, self.dtype, elem_count, self.device());
|
|
||||||
}
|
|
||||||
// Run the dequantization on cpu.
|
|
||||||
|
|
||||||
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
|
|
||||||
let mut out = vec![0.0; elem_count];
|
|
||||||
let block_len = elem_count / self.dtype.block_size();
|
|
||||||
match self.dtype {
|
|
||||||
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
|
|
||||||
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.device
|
|
||||||
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
|
||||||
// Run the quantization on cpu.
|
|
||||||
let src = match &src.slice {
|
|
||||||
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
|
||||||
self.device.dtoh_sync_copy(data).w()?
|
|
||||||
}
|
|
||||||
_ => crate::bail!("only f32 can be quantized"),
|
|
||||||
};
|
|
||||||
let src_len = src.len();
|
|
||||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
|
||||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
|
|
||||||
qcpu_storage.quantize(&src)?;
|
|
||||||
let data = qcpu_storage.data()?;
|
|
||||||
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
|
|
||||||
self.data = data;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
|
||||||
self.data.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fwd(
|
|
||||||
&self,
|
|
||||||
self_shape: &crate::Shape,
|
|
||||||
storage: &CudaStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(CudaStorage, crate::Shape)> {
|
|
||||||
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
8
|
|
||||||
};
|
|
||||||
let use_vec_kernel = match layout.shape().dims() {
|
|
||||||
[b, m, _k] => b * m <= max_bm,
|
|
||||||
[b, _k] => *b <= max_bm,
|
|
||||||
_ => false,
|
|
||||||
};
|
|
||||||
if use_vec_kernel {
|
|
||||||
self.dequantize_matmul_vec(self_shape, storage, layout)
|
|
||||||
} else {
|
|
||||||
self.dequantize_matmul(self_shape, storage, layout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QCudaStorage {
|
|
||||||
fn dequantize_matmul_vec(
|
|
||||||
&self,
|
|
||||||
self_shape: &crate::Shape,
|
|
||||||
rhs: &CudaStorage,
|
|
||||||
rhs_l: &crate::Layout,
|
|
||||||
) -> Result<(CudaStorage, crate::Shape)> {
|
|
||||||
let (nrows, ncols) = self_shape.dims2()?;
|
|
||||||
let rhs = rhs.as_cuda_slice::<f32>()?;
|
|
||||||
let rhs = match rhs_l.contiguous_offsets() {
|
|
||||||
Some((o1, o2)) => rhs.slice(o1..o2),
|
|
||||||
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
|
||||||
};
|
|
||||||
let (b_size, k) = match rhs_l.shape().dims() {
|
|
||||||
[b, m, k] => (b * m, *k),
|
|
||||||
[b, k] => (*b, *k),
|
|
||||||
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
|
||||||
};
|
|
||||||
if ncols != k {
|
|
||||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
|
||||||
}
|
|
||||||
|
|
||||||
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
|
||||||
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
|
|
||||||
} else {
|
|
||||||
mul_mat_vec_via_q8_1(
|
|
||||||
&self.data,
|
|
||||||
&rhs,
|
|
||||||
self.dtype,
|
|
||||||
ncols,
|
|
||||||
nrows,
|
|
||||||
b_size,
|
|
||||||
self.device(),
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
let mut out_shape = rhs_l.shape().dims().to_vec();
|
|
||||||
out_shape.pop();
|
|
||||||
out_shape.push(nrows);
|
|
||||||
Ok((out, out_shape.into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dequantize_matmul(
|
|
||||||
&self,
|
|
||||||
self_shape: &crate::Shape,
|
|
||||||
storage: &CudaStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(CudaStorage, crate::Shape)> {
|
|
||||||
use crate::backend::BackendStorage;
|
|
||||||
let (n, k) = self_shape.dims2()?;
|
|
||||||
let (b, m, k2) = match layout.shape().dims() {
|
|
||||||
&[b, m, k2] => (b, m, k2),
|
|
||||||
&[m, k2] => (1, m, k2),
|
|
||||||
s => crate::bail!("unexpected shape for input {s:?}"),
|
|
||||||
};
|
|
||||||
if k2 != k {
|
|
||||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
|
|
||||||
}
|
|
||||||
|
|
||||||
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
|
||||||
let data_f32 = self.dequantize(n * k)?;
|
|
||||||
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
|
|
||||||
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
|
|
||||||
} else {
|
|
||||||
let storage = storage.as_cuda_slice::<f32>()?;
|
|
||||||
let storage = match layout.contiguous_offsets() {
|
|
||||||
Some((o1, o2)) => storage.slice(o1..o2),
|
|
||||||
None => Err(crate::Error::RequiresContiguous {
|
|
||||||
op: "quantized-matmul",
|
|
||||||
}
|
|
||||||
.bt())?,
|
|
||||||
};
|
|
||||||
mul_mat_via_q8_1(
|
|
||||||
&self.data,
|
|
||||||
&storage,
|
|
||||||
self.dtype,
|
|
||||||
/* x_rows */ n,
|
|
||||||
/* x_cols */ k,
|
|
||||||
/* y_rows */ k,
|
|
||||||
/* y_cols */ b * m,
|
|
||||||
self.device(),
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
let mut out_shape = layout.shape().dims().to_vec();
|
|
||||||
out_shape.pop();
|
|
||||||
out_shape.push(n);
|
|
||||||
Ok((out, out_shape.into()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|
||||||
device: &CudaDevice,
|
|
||||||
data: &[T],
|
|
||||||
) -> Result<super::QStorage> {
|
|
||||||
let data = unsafe {
|
|
||||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
|
|
||||||
};
|
|
||||||
let data = device.htod_sync_copy(data).w()?;
|
|
||||||
Ok(QStorage::Cuda(QCudaStorage {
|
|
||||||
data,
|
|
||||||
device: device.clone(),
|
|
||||||
dtype: T::DTYPE,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: quantizing 256 f32 values to Q8_1 on device 0 succeeds.
    #[test]
    fn cuda_quantize_q8_1() -> Result<()> {
        let dev = CudaDevice::new(0)?;
        let el = 256;
        let el_padded = pad(el, MATRIX_ROW_PADDING);
        let q8_1 = GgmlDType::Q8_1;
        let scratch_bytes = el_padded * q8_1.type_size() / q8_1.block_size();
        let mut scratch = unsafe { dev.alloc::<u8>(scratch_bytes).w()? };
        let host: Vec<f32> = (0..el).map(|v| v as f32).collect();
        let input = dev.htod_sync_copy(&host).w()?;
        quantize_q8_1(&input.slice(..), &mut scratch, el, 1, &dev)?;
        Ok(())
    }

    /// Dot product of the ramp 0..256 with its own Q4_0 quantization, through
    /// both matrix-vector kernels.
    #[test]
    fn cuda_mmv_q8_1() -> Result<()> {
        let dev = CudaDevice::new(0)?;
        let ncols = 256;
        let host: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
        let y = dev.htod_sync_copy(&host).w()?;
        let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;

        let via_q8_1 = mul_mat_vec_via_q8_1(
            &xs.data,
            &y.slice(..),
            /* dtype */ GgmlDType::Q4_0,
            /* ncols */ ncols,
            /* nrows */ 1,
            /* b_size */ 1,
            &dev,
        )?;
        let got = via_q8_1.as_cuda_slice::<f32>()?;
        let got = dev.dtoh_sync_copy(&got.slice(..)).unwrap();
        assert_eq!(got.len(), 1);
        // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
        // Q8 means 1/256 precision.
        assert_eq!(got[0], 5561664.5);

        let via_dmmv = dequantize_mul_mat_vec(
            &xs.data,
            &y.slice(..),
            /* dtype */ GgmlDType::Q4_0,
            /* ncols */ ncols,
            /* nrows */ 1,
            &dev,
        )?;
        let got = via_dmmv.as_cuda_slice::<f32>()?;
        let got = dev.dtoh_sync_copy(&got.slice(..)).unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0], 5561851.0);
        Ok(())
    }

    /// 4x256 by 256x4 product of a scaled ramp with its own Q4_0 quantization.
    #[test]
    fn cuda_mm_q8_1() -> Result<()> {
        let dev = CudaDevice::new(0)?;
        let ncols = 256;
        let host: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
        let y = dev.htod_sync_copy(&host).w()?;
        let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;

        let product = mul_mat_via_q8_1(
            &xs.data,
            &y.slice(..),
            /* dtype */ GgmlDType::Q4_0,
            /* x_rows */ 4,
            /* x_cols */ ncols,
            /* y_rows */ ncols,
            /* y_cols */ 4,
            &dev,
        )?;
        let got = product.as_cuda_slice::<f32>()?;
        let got = dev.dtoh_sync_copy(&got.slice(..)).unwrap();

        /*
           Unquantized reference:
           x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
           x @ x.t() / 16
           tensor([[  347480.0000,   869720.0000,  1391960.0000,  1914200.0000],
                   [  869720.0000,  2440536.0000,  4011352.0000,  5582166.5000],
                   [ 1391960.0000,  4011352.0000,  6630742.0000,  9250132.0000],
                   [ 1914200.0000,  5582166.5000,  9250132.0000, 12918099.0000]])
        */
        assert_eq!(got.len(), 16);
        assert_eq!(got[0], 347604.0);
        assert_eq!(got[1], 888153.06);
        assert_eq!(got[4], 869780.7);
        assert_eq!(got[5], 2483145.0);
        assert_eq!(got[11], 9407368.0);
        assert_eq!(got[14], 9470856.0);
        assert_eq!(got[15], 13138824.0);
        Ok(())
    }
}
|
|
@ -1,50 +0,0 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use super::GgmlDType;
|
|
||||||
use crate::{CudaDevice, CudaStorage, Error, Result};
|
|
||||||
|
|
||||||
pub struct QCudaStorage {
|
|
||||||
dtype: GgmlDType,
|
|
||||||
device: CudaDevice,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QCudaStorage {
|
|
||||||
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
|
||||||
self.dtype
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn device(&self) -> &CudaDevice {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fwd(
|
|
||||||
&self,
|
|
||||||
_self_shape: &crate::Shape,
|
|
||||||
_storage: &CudaStorage,
|
|
||||||
_layout: &crate::Layout,
|
|
||||||
) -> Result<(CudaStorage, crate::Shape)> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|
||||||
_device: &CudaDevice,
|
|
||||||
_data: &[T],
|
|
||||||
) -> Result<super::QStorage> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
@ -1,50 +0,0 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use super::GgmlDType;
|
|
||||||
use crate::{Error, MetalDevice, MetalStorage, Result};
|
|
||||||
|
|
||||||
pub struct QMetalStorage {
|
|
||||||
dtype: GgmlDType,
|
|
||||||
device: MetalDevice,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl QMetalStorage {
|
|
||||||
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
|
||||||
self.dtype
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn device(&self) -> &MetalDevice {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fwd(
|
|
||||||
&self,
|
|
||||||
_self_shape: &crate::Shape,
|
|
||||||
_storage: &MetalStorage,
|
|
||||||
_layout: &crate::Layout,
|
|
||||||
) -> Result<(MetalStorage, crate::Shape)> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|
||||||
_device: &MetalDevice,
|
|
||||||
_data: &[T],
|
|
||||||
) -> Result<super::QStorage> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
@ -1,5 +1,7 @@
|
|||||||
//! Support for the GGML file format.
|
//! Support for the GGML file format.
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
use super::metal::load_quantized_metal;
|
||||||
use super::{k_quants, GgmlDType, QStorage};
|
use super::{k_quants, GgmlDType, QStorage};
|
||||||
use crate::{Device, Result};
|
use crate::{Device, Result};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||||
let data: QStorage = match device {
|
let data: QStorage = match device {
|
||||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||||
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
|
#[cfg(feature = "metal")]
|
||||||
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
|
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
||||||
|
#[cfg(not(feature = "metal"))]
|
||||||
|
Device::Metal(_metal) => {
|
||||||
|
crate::bail!("Metal backend requires `metal` feature")
|
||||||
|
}
|
||||||
|
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
||||||
};
|
};
|
||||||
super::QTensor::new(data, dims)
|
super::QTensor::new(data, dims)
|
||||||
}
|
}
|
||||||
@ -226,7 +233,6 @@ pub struct Content {
|
|||||||
pub hparams: HParams,
|
pub hparams: HParams,
|
||||||
pub vocab: Vocab,
|
pub vocab: Vocab,
|
||||||
pub tensors: HashMap<String, super::QTensor>,
|
pub tensors: HashMap<String, super::QTensor>,
|
||||||
pub device: Device,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Content {
|
impl Content {
|
||||||
@ -246,13 +252,11 @@ impl Content {
|
|||||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||||
tensors.insert(name, tensor);
|
tensors.insert(name, tensor);
|
||||||
}
|
}
|
||||||
let device = device.clone();
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
magic,
|
magic,
|
||||||
hparams,
|
hparams,
|
||||||
vocab,
|
vocab,
|
||||||
tensors,
|
tensors,
|
||||||
device,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
use super::{GgmlDType, QStorage};
|
||||||
use crate::backend::BackendStorage;
|
use crate::{DType, MetalDevice, MetalStorage, Result};
|
||||||
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
|
||||||
use metal::Buffer;
|
use metal::Buffer;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -11,31 +10,23 @@ pub struct QMetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl QMetalStorage {
|
impl QMetalStorage {
|
||||||
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
|
||||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
|
||||||
let buffer = device.allocate_zeros(size)?;
|
|
||||||
Ok(Self {
|
|
||||||
buffer,
|
|
||||||
device: device.clone(),
|
|
||||||
dtype,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dtype(&self) -> GgmlDType {
|
pub fn dtype(&self) -> GgmlDType {
|
||||||
self.dtype
|
self.dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> &MetalDevice {
|
|
||||||
&self.device
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn buffer(&self) -> &Buffer {
|
pub fn buffer(&self) -> &Buffer {
|
||||||
&self.buffer
|
&self.buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
||||||
use crate::quantized::k_quants::GgmlType;
|
Self {
|
||||||
|
device,
|
||||||
|
buffer,
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
command_buffer.set_label("to_cpu");
|
command_buffer.set_label("to_cpu");
|
||||||
@ -45,73 +36,87 @@ impl QMetalStorage {
|
|||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
self.device.wait_until_completed()?;
|
self.device.wait_until_completed()?;
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
let block_len = elem_count / self.dtype.block_size();
|
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
GgmlDType::F32 => {
|
GgmlDType::F32 => {
|
||||||
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
|
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
f32::to_float(&vec, &mut out)?;
|
f32::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::F16 => {
|
GgmlDType::F16 => {
|
||||||
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
|
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
half::f16::to_float(&vec, &mut out)?;
|
half::f16::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q4_0 => {
|
GgmlDType::Q4_0 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q4_1 => {
|
GgmlDType::Q4_1 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q5_0 => {
|
GgmlDType::Q5_0 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q5_1 => {
|
GgmlDType::Q5_1 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q8_0 => {
|
GgmlDType::Q8_0 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q8_1 => {
|
GgmlDType::Q8_1 => {
|
||||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q2K => {
|
GgmlDType::Q2K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ2K> =
|
||||||
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q3K => {
|
GgmlDType::Q3K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ3K> =
|
||||||
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q4K => {
|
GgmlDType::Q4K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ4K> =
|
||||||
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q5K => {
|
GgmlDType::Q5K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ5K> =
|
||||||
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q6K => {
|
GgmlDType::Q6K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ6K> =
|
||||||
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
GgmlDType::Q8K => {
|
GgmlDType::Q8K => {
|
||||||
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
|
let vec: Vec<crate::quantized::BlockQ8K> =
|
||||||
|
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||||
|
use crate::quantized::k_quants::GgmlType;
|
||||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||||
Ok(MetalStorage::new(
|
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
||||||
buffer,
|
|
||||||
self.device.clone(),
|
|
||||||
elem_count,
|
|
||||||
DType::F32,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||||
@ -125,65 +130,9 @@ impl QMetalStorage {
|
|||||||
self.buffer = buffer;
|
self.buffer = buffer;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn storage_size_in_bytes(&self) -> usize {
|
|
||||||
self.buffer.length() as usize
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fwd(
|
|
||||||
&self,
|
|
||||||
self_shape: &Shape,
|
|
||||||
storage: &MetalStorage,
|
|
||||||
layout: &crate::Layout,
|
|
||||||
) -> Result<(MetalStorage, Shape)> {
|
|
||||||
use crate::MetalError;
|
|
||||||
|
|
||||||
if !layout.is_contiguous() {
|
|
||||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
|
||||||
}
|
|
||||||
let src_shape = layout.shape();
|
|
||||||
// self is transposed so n is first then k.
|
|
||||||
if src_shape.rank() < 2 {
|
|
||||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
|
||||||
}
|
|
||||||
let (n, k) = self_shape.dims2()?;
|
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
|
||||||
|
|
||||||
// We always use a single batch dimension and stack all the tensors in the batch on the
|
|
||||||
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
|
||||||
// properly.
|
|
||||||
let (b, m) = match dst_shape.len() {
|
|
||||||
3 => (1, dst_shape[0] * dst_shape[1]),
|
|
||||||
2 => (1, dst_shape[0]),
|
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
|
||||||
};
|
|
||||||
let last_k = dst_shape.pop().unwrap();
|
|
||||||
if last_k != k {
|
|
||||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
|
||||||
}
|
|
||||||
dst_shape.push(n);
|
|
||||||
let dst_shape = Shape::from(dst_shape);
|
|
||||||
let device = storage.device().clone();
|
|
||||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
|
||||||
let command_buffer = device.command_buffer()?;
|
|
||||||
candle_metal_kernels::call_quantized_matmul_t(
|
|
||||||
device.device(),
|
|
||||||
&command_buffer,
|
|
||||||
device.kernels(),
|
|
||||||
self.dtype.into(),
|
|
||||||
(b, m, n, k),
|
|
||||||
storage.buffer(),
|
|
||||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
|
||||||
&self.buffer,
|
|
||||||
&dst,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
|
||||||
Ok((dst_storage, dst_shape))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||||
device: &MetalDevice,
|
device: &MetalDevice,
|
||||||
data: &[T],
|
data: &[T],
|
||||||
) -> Result<QStorage> {
|
) -> Result<QStorage> {
|
||||||
@ -202,24 +151,3 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|||||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||||
slice.to_vec()
|
slice.to_vec()
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
|
||||||
fn from(value: GgmlDType) -> Self {
|
|
||||||
match value {
|
|
||||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
|
||||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
|
||||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
|
||||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
|
||||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
|
||||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
|
||||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
|
||||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
|
||||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
|
||||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
|
||||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
|
||||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
|
||||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
|
||||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,27 +1,16 @@
|
|||||||
|
#[cfg(feature = "metal")]
|
||||||
|
use crate::{backend::BackendStorage, DType};
|
||||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||||
use k_quants::*;
|
use k_quants::*;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
|
||||||
#[cfg(target_feature = "avx")]
|
#[cfg(target_feature = "avx")]
|
||||||
pub mod avx;
|
pub mod avx;
|
||||||
mod dummy_cuda;
|
|
||||||
mod dummy_metal;
|
|
||||||
pub mod ggml_file;
|
pub mod ggml_file;
|
||||||
pub mod gguf_file;
|
pub mod gguf_file;
|
||||||
pub mod k_quants;
|
pub mod k_quants;
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
pub mod metal;
|
pub mod metal;
|
||||||
#[cfg(not(feature = "metal"))]
|
|
||||||
mod metal {
|
|
||||||
pub use super::dummy_metal::*;
|
|
||||||
}
|
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
pub mod cuda;
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
mod cuda {
|
|
||||||
pub use super::dummy_cuda::*;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(target_feature = "neon")]
|
#[cfg(target_feature = "neon")]
|
||||||
pub mod neon;
|
pub mod neon;
|
||||||
#[cfg(target_feature = "simd128")]
|
#[cfg(target_feature = "simd128")]
|
||||||
@ -43,13 +32,22 @@ impl Device {
|
|||||||
let storage = dtype.cpu_zeros(elem_count);
|
let storage = dtype.cpu_zeros(elem_count);
|
||||||
Ok(QStorage::Cpu(storage))
|
Ok(QStorage::Cpu(storage))
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
Device::Metal(metal) => {
|
Device::Metal(metal) => {
|
||||||
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||||
Ok(QStorage::Metal(storage))
|
let buffer = metal.allocate_zeros(size)?;
|
||||||
|
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
||||||
|
buffer,
|
||||||
|
metal.clone(),
|
||||||
|
dtype,
|
||||||
|
)))
|
||||||
}
|
}
|
||||||
Device::Cuda(cuda) => {
|
#[cfg(not(feature = "metal"))]
|
||||||
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
|
Device::Metal(_metal) => {
|
||||||
Ok(QStorage::Cuda(storage))
|
crate::bail!("Metal feature not activated");
|
||||||
|
}
|
||||||
|
Device::Cuda(_cuda) => {
|
||||||
|
crate::bail!("Cuda ggml quantization not supported");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -57,40 +55,32 @@ impl Device {
|
|||||||
|
|
||||||
pub enum QStorage {
|
pub enum QStorage {
|
||||||
Cpu(Box<dyn QuantizedType>),
|
Cpu(Box<dyn QuantizedType>),
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
Metal(metal::QMetalStorage),
|
Metal(metal::QMetalStorage),
|
||||||
Cuda(cuda::QCudaStorage),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QStorage {
|
impl QStorage {
|
||||||
fn block_size(&self) -> usize {
|
fn block_size(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.block_size(),
|
QStorage::Cpu(storage) => storage.block_size(),
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||||
QStorage::Cuda(storage) => storage.dtype().block_size(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dtype(&self) -> GgmlDType {
|
fn dtype(&self) -> GgmlDType {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.dtype(),
|
QStorage::Cpu(storage) => storage.dtype(),
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
QStorage::Metal(storage) => storage.dtype(),
|
QStorage::Metal(storage) => storage.dtype(),
|
||||||
QStorage::Cuda(storage) => storage.dtype(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn device(&self) -> Device {
|
|
||||||
match self {
|
|
||||||
QStorage::Cpu(_storage) => Device::Cpu,
|
|
||||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
|
||||||
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn size_in_bytes(&self) -> usize {
|
fn size_in_bytes(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||||
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
#[cfg(feature = "metal")]
|
||||||
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
|
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,8 +89,8 @@ impl QStorage {
|
|||||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||||
storage.from_float(src.as_slice::<f32>()?)?;
|
storage.from_float(src.as_slice::<f32>()?)?;
|
||||||
}
|
}
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||||
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
|
|
||||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -109,8 +99,8 @@ impl QStorage {
|
|||||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||||
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +112,8 @@ impl QStorage {
|
|||||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||||
Ok(Cow::from(data))
|
Ok(Cow::from(data))
|
||||||
}
|
}
|
||||||
QStorage::Metal(_) | QStorage::Cuda(_) => {
|
#[cfg(feature = "metal")]
|
||||||
|
QStorage::Metal(_storage) => {
|
||||||
crate::bail!("not implemented");
|
crate::bail!("not implemented");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -345,10 +336,6 @@ impl QTensor {
|
|||||||
self.storage.dtype()
|
self.storage.dtype()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn device(&self) -> Device {
|
|
||||||
self.storage.device()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
self.shape.rank()
|
self.shape.rank()
|
||||||
}
|
}
|
||||||
@ -398,7 +385,7 @@ impl QMatMul {
|
|||||||
_ => DEQUANTIZE_ALL.with(|b| *b),
|
_ => DEQUANTIZE_ALL.with(|b| *b),
|
||||||
};
|
};
|
||||||
let t = if dequantize {
|
let t = if dequantize {
|
||||||
let tensor = qtensor.dequantize(&qtensor.device())?;
|
let tensor = qtensor.dequantize(&Device::Cpu)?;
|
||||||
Self::Tensor(tensor)
|
Self::Tensor(tensor)
|
||||||
} else {
|
} else {
|
||||||
Self::QTensor(qtensor)
|
Self::QTensor(qtensor)
|
||||||
@ -440,7 +427,8 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
#[allow(clippy::infallible_destructuring_match)]
|
#[allow(clippy::infallible_destructuring_match)]
|
||||||
let self_storage = match &self.storage {
|
let self_storage = match &self.storage {
|
||||||
QStorage::Cpu(storage) => storage,
|
QStorage::Cpu(storage) => storage,
|
||||||
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
|
#[cfg(feature = "metal")]
|
||||||
|
_ => crate::bail!("Invalid storage"),
|
||||||
};
|
};
|
||||||
let slice = storage.as_slice::<f32>()?;
|
let slice = storage.as_slice::<f32>()?;
|
||||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||||
@ -449,28 +437,79 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
fn metal_fwd(
|
fn metal_fwd(
|
||||||
&self,
|
&self,
|
||||||
storage: &crate::MetalStorage,
|
storage: &crate::MetalStorage,
|
||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
) -> Result<(crate::MetalStorage, Shape)> {
|
) -> Result<(crate::MetalStorage, Shape)> {
|
||||||
let self_storage = match &self.storage {
|
use crate::MetalError;
|
||||||
QStorage::Metal(metal) => metal,
|
|
||||||
|
if !layout.is_contiguous() {
|
||||||
|
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||||
|
}
|
||||||
|
let src_shape = layout.shape();
|
||||||
|
// self is transposed so n is first then k.
|
||||||
|
if src_shape.rank() < 2 {
|
||||||
|
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||||
|
}
|
||||||
|
let (n, k) = self.shape.dims2()?;
|
||||||
|
let mut dst_shape = src_shape.dims().to_vec();
|
||||||
|
|
||||||
|
let (b, m) = match dst_shape.len() {
|
||||||
|
3 => (dst_shape[0], dst_shape[1]),
|
||||||
|
2 => (1, dst_shape[0]),
|
||||||
|
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||||
|
};
|
||||||
|
let last_k = dst_shape.pop().unwrap();
|
||||||
|
if last_k != k {
|
||||||
|
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||||
|
}
|
||||||
|
dst_shape.push(n);
|
||||||
|
let dst_shape = Shape::from(dst_shape);
|
||||||
|
let device = storage.device().clone();
|
||||||
|
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||||
|
let (buffer, dtype) = match &self.storage {
|
||||||
|
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
||||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||||
};
|
};
|
||||||
self_storage.fwd(&self.shape, storage, layout)
|
let command_buffer = device.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_quantized_matmul_t(
|
||||||
|
device.device(),
|
||||||
|
&command_buffer,
|
||||||
|
device.kernels(),
|
||||||
|
dtype.into(),
|
||||||
|
(b, m, n, k),
|
||||||
|
storage.buffer(),
|
||||||
|
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||||
|
buffer,
|
||||||
|
&dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||||
|
Ok((dst_storage, dst_shape))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn cuda_fwd(
|
#[cfg(feature = "metal")]
|
||||||
&self,
|
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||||
storage: &crate::CudaStorage,
|
fn from(value: GgmlDType) -> Self {
|
||||||
layout: &crate::Layout,
|
match value {
|
||||||
) -> Result<(crate::CudaStorage, Shape)> {
|
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||||
let self_storage = match &self.storage {
|
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||||
QStorage::Cuda(cuda) => cuda,
|
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||||
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
|
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||||
};
|
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||||
self_storage.fwd(&self.shape, storage, layout)
|
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||||
|
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||||
|
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||||
|
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||||
|
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||||
|
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||||
|
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||||
|
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||||
|
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,7 +171,7 @@ impl Shape {
|
|||||||
}
|
}
|
||||||
let mut acc = 1;
|
let mut acc = 1;
|
||||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||||
if dim > 1 && stride != acc {
|
if stride != acc {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dim;
|
acc *= dim;
|
||||||
@ -186,7 +186,7 @@ impl Shape {
|
|||||||
}
|
}
|
||||||
let mut acc = 1;
|
let mut acc = 1;
|
||||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
||||||
if dim > 1 && stride != acc {
|
if stride != acc {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
acc *= dim;
|
acc *= dim;
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::op::{self, CmpOp, ReduceOp};
|
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||||
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
|
||||||
|
|
||||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||||
// out of memory. Instead try_clone should be used.
|
// out of memory. Instead try_clone should be used.
|
||||||
@ -44,19 +43,9 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||||
let lhs_device = self.device();
|
let lhs = self.device().location();
|
||||||
let rhs_device = rhs.device();
|
let rhs = rhs.device().location();
|
||||||
let lhs = lhs_device.location();
|
if lhs != rhs {
|
||||||
let rhs = rhs_device.location();
|
|
||||||
let same_device = if self.device().is_metal() {
|
|
||||||
// On metal, we require the device to be exactly the same rather than
|
|
||||||
// having the same location. In cuda this is not necessary as all CudaDevice on the
|
|
||||||
// same GPU will use the same cuda stream.
|
|
||||||
lhs_device.same_device(&rhs_device)
|
|
||||||
} else {
|
|
||||||
lhs == rhs
|
|
||||||
};
|
|
||||||
if !same_device {
|
|
||||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -263,51 +252,6 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
|
|
||||||
match self {
|
|
||||||
Self::Cpu(storage) => c.cpu_fwd(storage, l),
|
|
||||||
Self::Cuda(storage) => c.cuda_fwd(storage, l),
|
|
||||||
Self::Metal(storage) => c.metal_fwd(storage, l),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn inplace_op2(
|
|
||||||
&mut self,
|
|
||||||
l1: &Layout,
|
|
||||||
t2: &Self,
|
|
||||||
l2: &Layout,
|
|
||||||
c: &dyn InplaceOp2,
|
|
||||||
) -> Result<()> {
|
|
||||||
self.same_device(t2, c.name())?;
|
|
||||||
match (self, t2) {
|
|
||||||
(Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
|
|
||||||
(Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
|
|
||||||
(Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn inplace_op3(
|
|
||||||
&mut self,
|
|
||||||
l1: &Layout,
|
|
||||||
t2: &Self,
|
|
||||||
l2: &Layout,
|
|
||||||
t3: &Self,
|
|
||||||
l3: &Layout,
|
|
||||||
c: &dyn InplaceOp3,
|
|
||||||
) -> Result<()> {
|
|
||||||
self.same_device(t2, c.name())?;
|
|
||||||
self.same_device(t3, c.name())?;
|
|
||||||
match (self, t2, t3) {
|
|
||||||
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
|
|
||||||
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
|
|
||||||
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
|
|
||||||
c.metal_fwd(s1, l1, s2, l2, s3, l3)
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
@ -408,10 +352,6 @@ impl Storage {
|
|||||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||||
Ok(Self::Cuda(s))
|
Ok(Self::Cuda(s))
|
||||||
}
|
}
|
||||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
|
||||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
|
||||||
Ok(Self::Metal(s))
|
|
||||||
}
|
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
lhs: lhs.device().location(),
|
lhs: lhs.device().location(),
|
||||||
rhs: rhs.device().location(),
|
rhs: rhs.device().location(),
|
||||||
@ -757,32 +697,4 @@ impl Storage {
|
|||||||
.bt()),
|
.bt()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub(crate) fn copy2d(
|
|
||||||
&self,
|
|
||||||
dst: &mut Self,
|
|
||||||
d1: usize,
|
|
||||||
d2: usize,
|
|
||||||
src_s: usize,
|
|
||||||
dst_s: usize,
|
|
||||||
src_o: usize,
|
|
||||||
dst_o: usize,
|
|
||||||
) -> Result<()> {
|
|
||||||
match (self, dst) {
|
|
||||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
|
||||||
(Self::Cuda(src), Self::Cuda(dst)) => {
|
|
||||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
|
||||||
}
|
|
||||||
(Self::Metal(src), Self::Metal(dst)) => {
|
|
||||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
|
||||||
}
|
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
|
||||||
lhs: lhs.device().location(),
|
|
||||||
rhs: rhs.device().location(),
|
|
||||||
op: "copy2d",
|
|
||||||
}
|
|
||||||
.bt()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
||||||
#![allow(clippy::redundant_closure_call)]
|
#![allow(clippy::redundant_closure_call)]
|
||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
use crate::op::{
|
||||||
|
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||||
|
};
|
||||||
use crate::scalar::TensorOrScalar;
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims};
|
||||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
@ -79,9 +81,6 @@ macro_rules! unary_op {
|
|||||||
($fn_name:ident, $op_name:ident) => {
|
($fn_name:ident, $op_name:ident) => {
|
||||||
pub fn $fn_name(&self) -> Result<Self> {
|
pub fn $fn_name(&self) -> Result<Self> {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
if shape.elem_count() == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||||
@ -95,9 +94,6 @@ macro_rules! binary_op {
|
|||||||
($fn_name:ident, $op_name:ident) => {
|
($fn_name:ident, $op_name:ident) => {
|
||||||
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
|
||||||
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
|
||||||
if shape.elem_count() == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||||
&*rhs.storage(),
|
&*rhs.storage(),
|
||||||
self.layout(),
|
self.layout(),
|
||||||
@ -120,9 +116,6 @@ macro_rules! binary_op_scalar {
|
|||||||
.broadcast_as(self.shape())?,
|
.broadcast_as(self.shape())?,
|
||||||
};
|
};
|
||||||
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
|
||||||
if self.elem_count() == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
let storage = self.storage().binary_impl::<crate::op::$op_name>(
|
||||||
&*rhs.storage(),
|
&*rhs.storage(),
|
||||||
self.layout(),
|
self.layout(),
|
||||||
@ -515,11 +508,9 @@ impl Tensor {
|
|||||||
unary_op!(gelu_erf, GeluErf);
|
unary_op!(gelu_erf, GeluErf);
|
||||||
unary_op!(erf, Erf);
|
unary_op!(erf, Erf);
|
||||||
unary_op!(relu, Relu);
|
unary_op!(relu, Relu);
|
||||||
unary_op!(silu, Silu);
|
|
||||||
unary_op!(ceil, Ceil);
|
unary_op!(ceil, Ceil);
|
||||||
unary_op!(floor, Floor);
|
unary_op!(floor, Floor);
|
||||||
unary_op!(round, Round);
|
unary_op!(round, Round);
|
||||||
unary_op!(sign, Sign);
|
|
||||||
|
|
||||||
/// Round element of the input tensor to the nearest integer.
|
/// Round element of the input tensor to the nearest integer.
|
||||||
///
|
///
|
||||||
@ -655,9 +646,6 @@ impl Tensor {
|
|||||||
/// # Ok::<(), candle_core::Error>(())
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||||
if self.elem_count() == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let storage = self.storage().affine(self.layout(), mul, add)?;
|
let storage = self.storage().affine(self.layout(), mul, add)?;
|
||||||
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
@ -665,9 +653,6 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||||
if self.elem_count() == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let storage = self.storage().elu(self.layout(), alpha)?;
|
let storage = self.storage().elu(self.layout(), alpha)?;
|
||||||
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
@ -675,15 +660,12 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Raise the tensor to some float exponent `e`.
|
/// Raise the tensor to some float exponent `e`.
|
||||||
pub fn powf(&self, e: f64) -> Result<Self> {
|
pub fn powf(&self, e: f64) -> Result<Self> {
|
||||||
if self.elem_count() == 0 {
|
|
||||||
return Ok(self.clone());
|
|
||||||
}
|
|
||||||
let storage = self.storage().powf(self.layout(), e)?;
|
let storage = self.storage().powf(self.layout(), e)?;
|
||||||
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||||
if dim >= self.dims().len() {
|
if dim >= self.dims().len() {
|
||||||
Err(Error::DimOutOfRange {
|
Err(Error::DimOutOfRange {
|
||||||
shape: self.shape().clone(),
|
shape: self.shape().clone(),
|
||||||
@ -822,35 +804,6 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Roll the tensor input along the given dimension.
|
|
||||||
/// Elements that are shifted beyond the last position are re-introduced at the first position.
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use candle_core::{Tensor, Device};
|
|
||||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
|
||||||
/// let tensor = tensor.roll(1, 0)?;
|
|
||||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
|
|
||||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
|
||||||
/// let tensor = tensor.roll(-1, 0)?;
|
|
||||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
|
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
|
||||||
/// ```
|
|
||||||
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
|
|
||||||
where
|
|
||||||
D: Dim + Clone,
|
|
||||||
{
|
|
||||||
let dim = dim.to_index(self.shape(), "roll")?;
|
|
||||||
let dim_size = self.dim(dim)?;
|
|
||||||
let shift = shift.rem_euclid(dim_size as i32) as usize;
|
|
||||||
if shift == 0 {
|
|
||||||
Ok(self.clone())
|
|
||||||
} else {
|
|
||||||
let a = self.narrow(dim, 0, dim_size - shift)?;
|
|
||||||
let b = self.narrow(dim, dim_size - shift, shift)?;
|
|
||||||
Tensor::cat(&[&b, &a], dim)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||||
/// input dimensions.
|
/// input dimensions.
|
||||||
///
|
///
|
||||||
@ -1032,7 +985,7 @@ impl Tensor {
|
|||||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||||
let (n, c, _l) = self.dims3()?;
|
let (n, c, _l) = self.dims3()?;
|
||||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
|
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.upsample_nearest1d(self.layout(), target_size)?;
|
.upsample_nearest1d(self.layout(), target_size)?;
|
||||||
@ -1172,9 +1125,6 @@ impl Tensor {
|
|||||||
let n = b_dims[dim - 1];
|
let n = b_dims[dim - 1];
|
||||||
|
|
||||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||||
if c_shape.elem_count() == 0 || k == 0 {
|
|
||||||
return Tensor::zeros(c_shape, self.dtype(), self.device());
|
|
||||||
}
|
|
||||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||||
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
||||||
if k != k2 || batching != batching_b {
|
if k != k2 || batching != batching_b {
|
||||||
@ -1371,7 +1321,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
.bt())?
|
.bt())?
|
||||||
}
|
}
|
||||||
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
|
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
let offset = start * src.dims()[1..].iter().product::<usize>();
|
let offset = start * src.dims()[1..].iter().product::<usize>();
|
||||||
@ -1903,9 +1853,9 @@ impl Tensor {
|
|||||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||||
///
|
///
|
||||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||||
pub fn detach(&self) -> Tensor {
|
pub fn detach(&self) -> Result<Tensor> {
|
||||||
if self.op.is_none() && !self.is_variable {
|
if self.op.is_none() && !self.is_variable {
|
||||||
self.clone()
|
Ok(self.clone())
|
||||||
} else {
|
} else {
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
@ -1916,7 +1866,7 @@ impl Tensor {
|
|||||||
dtype: self.dtype,
|
dtype: self.dtype,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
};
|
};
|
||||||
Tensor(Arc::new(tensor_))
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2021,7 +1971,7 @@ impl Tensor {
|
|||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
} else {
|
} else {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
let op = BackpropOp::new1(self, Op::Copy);
|
let op = BackpropOp::new1(self, Op::Copy);
|
||||||
@ -2029,21 +1979,11 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a tensor that is in row major order. This always makes a copy.
|
|
||||||
pub fn force_contiguous(&self) -> Result<Tensor> {
|
|
||||||
let shape = self.shape();
|
|
||||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
|
||||||
self.storage()
|
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
|
||||||
let op = BackpropOp::new1(self, Op::Copy);
|
|
||||||
Ok(from_storage(storage, shape.clone(), op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a variable based on the values currently stored in a tensor. The storage is always
|
/// Create a variable based on the values currently stored in a tensor. The storage is always
|
||||||
/// copied.
|
/// copied.
|
||||||
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
pub(crate) fn make_var(&self) -> Result<Tensor> {
|
||||||
let shape = self.shape().clone();
|
let shape = self.shape().clone();
|
||||||
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
Ok(from_storage(storage, shape, BackpropOp::none(), true))
|
||||||
@ -2096,7 +2036,7 @@ impl Tensor {
|
|||||||
};
|
};
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
} else {
|
} else {
|
||||||
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
|
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||||
self.storage()
|
self.storage()
|
||||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
Ok(from_storage(storage, shape, op, false))
|
Ok(from_storage(storage, shape, op, false))
|
||||||
@ -2123,19 +2063,8 @@ impl Tensor {
|
|||||||
let dim = dim.to_index(self.shape(), "squeeze")?;
|
let dim = dim.to_index(self.shape(), "squeeze")?;
|
||||||
if dims[dim] == 1 {
|
if dims[dim] == 1 {
|
||||||
let mut dims = dims.to_vec();
|
let mut dims = dims.to_vec();
|
||||||
let mut strides = self.stride().to_vec();
|
|
||||||
dims.remove(dim);
|
dims.remove(dim);
|
||||||
strides.remove(dim);
|
self.reshape(dims)
|
||||||
let tensor_ = Tensor_ {
|
|
||||||
id: TensorId::new(),
|
|
||||||
storage: self.storage.clone(),
|
|
||||||
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
|
||||||
op: BackpropOp::new1(self, Op::Reshape),
|
|
||||||
is_variable: false,
|
|
||||||
dtype: self.dtype,
|
|
||||||
device: self.device.clone(),
|
|
||||||
};
|
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
|
||||||
} else {
|
} else {
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
}
|
}
|
||||||
@ -2156,24 +2085,10 @@ impl Tensor {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||||
let mut dims = self.dims().to_vec();
|
let mut dims = self.dims().to_vec();
|
||||||
let mut strides = self.stride().to_vec();
|
|
||||||
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
||||||
// Cannot panic because to_index_plus_one already checks dimensions
|
// Cannot panic because to_index_plus_one already checks dimensions
|
||||||
dims.insert(dim, 1);
|
dims.insert(dim, 1);
|
||||||
// Any stride would work here, but we pick one so as to maximize the probability to remain
|
self.reshape(dims)
|
||||||
// C contiguous.
|
|
||||||
let stride = if dim < strides.len() { strides[dim] } else { 1 };
|
|
||||||
strides.insert(dim, stride);
|
|
||||||
let tensor_ = Tensor_ {
|
|
||||||
id: TensorId::new(),
|
|
||||||
storage: self.storage.clone(),
|
|
||||||
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
|
||||||
op: BackpropOp::new1(self, Op::Reshape),
|
|
||||||
is_variable: false,
|
|
||||||
dtype: self.dtype,
|
|
||||||
device: self.device.clone(),
|
|
||||||
};
|
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stacks two or more tensors along a particular dimension.
|
/// Stacks two or more tensors along a particular dimension.
|
||||||
@ -2204,6 +2119,152 @@ impl Tensor {
|
|||||||
Self::cat(&args, dim)
|
Self::cat(&args, dim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Concatenates two or more tensors along a particular dimension.
|
||||||
|
///
|
||||||
|
/// All tensors must of the same rank, and the output will have
|
||||||
|
/// the same rank
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle_core::{Tensor, DType, Device};
|
||||||
|
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||||
|
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||||
|
///
|
||||||
|
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||||
|
///
|
||||||
|
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||||
|
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||||
|
}
|
||||||
|
let arg0 = args[0].as_ref();
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(arg0.clone());
|
||||||
|
}
|
||||||
|
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||||
|
for arg in args {
|
||||||
|
arg.as_ref().check_dim(dim, "cat")?;
|
||||||
|
}
|
||||||
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
if arg0.rank() != arg.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: arg0.rank(),
|
||||||
|
got: arg.rank(),
|
||||||
|
shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in arg0
|
||||||
|
.shape()
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.zip(arg.shape().dims().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if dim_idx != dim && v1 != v2 {
|
||||||
|
Err(Error::ShapeMismatchCat {
|
||||||
|
dim: dim_idx,
|
||||||
|
first_shape: arg0.shape().clone(),
|
||||||
|
n: arg_idx + 1,
|
||||||
|
nth_shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dim == 0 {
|
||||||
|
Self::cat0(args)
|
||||||
|
} else {
|
||||||
|
// TODO: Avoid these transpositions and have an implementation that works
|
||||||
|
// for dim != 0...
|
||||||
|
let args: Vec<Tensor> = args
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.as_ref().transpose(0, dim))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let cat = Self::cat0(&args)?;
|
||||||
|
cat.transpose(0, dim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||||
|
}
|
||||||
|
let arg0 = args[0].as_ref();
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(arg0.clone());
|
||||||
|
}
|
||||||
|
let rank = arg0.rank();
|
||||||
|
let device = arg0.device();
|
||||||
|
let dtype = arg0.dtype();
|
||||||
|
let first_dims = arg0.shape().dims();
|
||||||
|
let mut cat_dims = first_dims.to_vec();
|
||||||
|
cat_dims[0] = 0;
|
||||||
|
let mut offsets = vec![0usize];
|
||||||
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
if arg.dtype() != dtype {
|
||||||
|
Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: dtype,
|
||||||
|
rhs: arg.dtype(),
|
||||||
|
op: "cat",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if arg.device().location() != device.location() {
|
||||||
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: device.location(),
|
||||||
|
rhs: arg.device().location(),
|
||||||
|
op: "cat",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if rank != arg.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: rank,
|
||||||
|
got: arg.rank(),
|
||||||
|
shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in arg0
|
||||||
|
.shape()
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.zip(arg.shape().dims().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if dim_idx == 0 {
|
||||||
|
cat_dims[0] += v2;
|
||||||
|
}
|
||||||
|
if dim_idx != 0 && v1 != v2 {
|
||||||
|
Err(Error::ShapeMismatchCat {
|
||||||
|
dim: dim_idx,
|
||||||
|
first_shape: arg0.shape().clone(),
|
||||||
|
n: arg_idx + 1,
|
||||||
|
nth_shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
||||||
|
offsets.push(next_offset);
|
||||||
|
}
|
||||||
|
let shape = Shape::from(cat_dims);
|
||||||
|
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
|
||||||
|
let mut storage = device.zeros(&shape, dtype)?;
|
||||||
|
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
arg.storage()
|
||||||
|
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||||
|
}
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
||||||
/// input tensor values and `right` elements after.
|
/// input tensor values and `right` elements after.
|
||||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||||
@ -2286,10 +2347,6 @@ impl Tensor {
|
|||||||
self.storage.read().unwrap()
|
self.storage.read().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
|
|
||||||
self.storage.write().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we extend the visibility of this function to be usable outside of this crate, we should
|
// If we extend the visibility of this function to be usable outside of this crate, we should
|
||||||
// make it unsafe.
|
// make it unsafe.
|
||||||
pub(crate) fn storage_mut_and_layout(
|
pub(crate) fn storage_mut_and_layout(
|
||||||
@ -2311,6 +2368,96 @@ impl Tensor {
|
|||||||
std::ptr::eq(lhs, rhs)
|
std::ptr::eq(lhs, rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies a unary custom op without backward support
|
||||||
|
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
||||||
|
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a binary custom op without backward support
|
||||||
|
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
||||||
|
let (storage, shape) =
|
||||||
|
self.storage()
|
||||||
|
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
||||||
|
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a ternary custom op without backward support
|
||||||
|
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op3(
|
||||||
|
self.layout(),
|
||||||
|
&t2.storage(),
|
||||||
|
t2.layout(),
|
||||||
|
&t3.storage(),
|
||||||
|
t3.layout(),
|
||||||
|
c,
|
||||||
|
)?;
|
||||||
|
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a unary custom op.
|
||||||
|
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
||||||
|
let (storage, shape) = self
|
||||||
|
.storage()
|
||||||
|
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
||||||
|
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
||||||
|
self.apply_op1_arc(Arc::new(Box::new(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a binary custom op.
|
||||||
|
pub fn apply_op2_arc(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op2(
|
||||||
|
self.layout(),
|
||||||
|
&rhs.storage(),
|
||||||
|
rhs.layout(),
|
||||||
|
c.as_ref().as_ref(),
|
||||||
|
)?;
|
||||||
|
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
||||||
|
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a ternary custom op.
|
||||||
|
pub fn apply_op3_arc(
|
||||||
|
&self,
|
||||||
|
t2: &Self,
|
||||||
|
t3: &Self,
|
||||||
|
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let (storage, shape) = self.storage().apply_op3(
|
||||||
|
self.layout(),
|
||||||
|
&t2.storage(),
|
||||||
|
t2.layout(),
|
||||||
|
&t3.storage(),
|
||||||
|
t3.layout(),
|
||||||
|
c.as_ref().as_ref(),
|
||||||
|
)?;
|
||||||
|
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
||||||
|
Op::CustomOp3(t1, t2, t3, c.clone())
|
||||||
|
});
|
||||||
|
Ok(from_storage(storage, shape, op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
||||||
|
&self,
|
||||||
|
t2: &Self,
|
||||||
|
t3: &Self,
|
||||||
|
c: C,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||||
|
}
|
||||||
|
|
||||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||||
/// values means counting the dimensions from the back.
|
/// values means counting the dimensions from the back.
|
||||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||||
|
@ -1,238 +0,0 @@
|
|||||||
use crate::{shape::Dim, Error, Result, Shape, Tensor};
|
|
||||||
|
|
||||||
impl Tensor {
|
|
||||||
/// Concatenates two or more tensors along a particular dimension.
|
|
||||||
///
|
|
||||||
/// All tensors must of the same rank, and the output will have
|
|
||||||
/// the same rank
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// # use candle_core::{Tensor, DType, Device};
|
|
||||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
|
||||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
|
||||||
///
|
|
||||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
|
||||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
|
||||||
///
|
|
||||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
|
||||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
|
||||||
/// ```
|
|
||||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
|
||||||
if args.is_empty() {
|
|
||||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
|
||||||
}
|
|
||||||
let arg0 = args[0].as_ref();
|
|
||||||
if args.len() == 1 {
|
|
||||||
return Ok(arg0.clone());
|
|
||||||
}
|
|
||||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
|
||||||
for arg in args {
|
|
||||||
arg.as_ref().check_dim(dim, "cat")?;
|
|
||||||
}
|
|
||||||
for (arg_idx, arg) in args.iter().enumerate() {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
if arg0.rank() != arg.rank() {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: arg0.rank(),
|
|
||||||
got: arg.rank(),
|
|
||||||
shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
for (dim_idx, (v1, v2)) in arg0
|
|
||||||
.shape()
|
|
||||||
.dims()
|
|
||||||
.iter()
|
|
||||||
.zip(arg.shape().dims().iter())
|
|
||||||
.enumerate()
|
|
||||||
{
|
|
||||||
if dim_idx != dim && v1 != v2 {
|
|
||||||
Err(Error::ShapeMismatchCat {
|
|
||||||
dim: dim_idx,
|
|
||||||
first_shape: arg0.shape().clone(),
|
|
||||||
n: arg_idx + 1,
|
|
||||||
nth_shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
|
|
||||||
if all_contiguous {
|
|
||||||
Self::cat_contiguous(args, dim)
|
|
||||||
} else if dim == 0 {
|
|
||||||
Self::cat0(args)
|
|
||||||
} else {
|
|
||||||
let args: Vec<Tensor> = args
|
|
||||||
.iter()
|
|
||||||
.map(|a| a.as_ref().transpose(0, dim))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
let cat = Self::cat0(&args)?;
|
|
||||||
cat.transpose(0, dim)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
|
||||||
if args.is_empty() {
|
|
||||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
|
||||||
}
|
|
||||||
let arg0 = args[0].as_ref();
|
|
||||||
if args.len() == 1 {
|
|
||||||
return Ok(arg0.clone());
|
|
||||||
}
|
|
||||||
let rank = arg0.rank();
|
|
||||||
let device = arg0.device();
|
|
||||||
let dtype = arg0.dtype();
|
|
||||||
let first_dims = arg0.shape().dims();
|
|
||||||
let mut cat_dims = first_dims.to_vec();
|
|
||||||
cat_dims[0] = 0;
|
|
||||||
let mut offsets = vec![0usize];
|
|
||||||
for (arg_idx, arg) in args.iter().enumerate() {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
if arg.dtype() != dtype {
|
|
||||||
Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: dtype,
|
|
||||||
rhs: arg.dtype(),
|
|
||||||
op: "cat",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
if arg.device().location() != device.location() {
|
|
||||||
Err(Error::DeviceMismatchBinaryOp {
|
|
||||||
lhs: device.location(),
|
|
||||||
rhs: arg.device().location(),
|
|
||||||
op: "cat",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
if rank != arg.rank() {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: rank,
|
|
||||||
got: arg.rank(),
|
|
||||||
shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
for (dim_idx, (v1, v2)) in arg0
|
|
||||||
.shape()
|
|
||||||
.dims()
|
|
||||||
.iter()
|
|
||||||
.zip(arg.shape().dims().iter())
|
|
||||||
.enumerate()
|
|
||||||
{
|
|
||||||
if dim_idx == 0 {
|
|
||||||
cat_dims[0] += v2;
|
|
||||||
}
|
|
||||||
if dim_idx != 0 && v1 != v2 {
|
|
||||||
Err(Error::ShapeMismatchCat {
|
|
||||||
dim: dim_idx,
|
|
||||||
first_shape: arg0.shape().clone(),
|
|
||||||
n: arg_idx + 1,
|
|
||||||
nth_shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
|
||||||
offsets.push(next_offset);
|
|
||||||
}
|
|
||||||
let shape = Shape::from(cat_dims);
|
|
||||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
|
|
||||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
|
||||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
arg.storage()
|
|
||||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
|
||||||
}
|
|
||||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
|
||||||
if args.is_empty() {
|
|
||||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
|
||||||
}
|
|
||||||
let arg0 = args[0].as_ref();
|
|
||||||
if args.len() == 1 {
|
|
||||||
return Ok(arg0.clone());
|
|
||||||
}
|
|
||||||
let rank = arg0.rank();
|
|
||||||
let device = arg0.device();
|
|
||||||
let dtype = arg0.dtype();
|
|
||||||
let first_dims = arg0.shape().dims();
|
|
||||||
let mut cat_dims = first_dims.to_vec();
|
|
||||||
cat_dims[dim] = 0;
|
|
||||||
for (arg_idx, arg) in args.iter().enumerate() {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
if arg.dtype() != dtype {
|
|
||||||
Err(Error::DTypeMismatchBinaryOp {
|
|
||||||
lhs: dtype,
|
|
||||||
rhs: arg.dtype(),
|
|
||||||
op: "cat",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
if arg.device().location() != device.location() {
|
|
||||||
Err(Error::DeviceMismatchBinaryOp {
|
|
||||||
lhs: device.location(),
|
|
||||||
rhs: arg.device().location(),
|
|
||||||
op: "cat",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
if rank != arg.rank() {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: rank,
|
|
||||||
got: arg.rank(),
|
|
||||||
shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
for (dim_idx, (v1, v2)) in arg0
|
|
||||||
.shape()
|
|
||||||
.dims()
|
|
||||||
.iter()
|
|
||||||
.zip(arg.shape().dims().iter())
|
|
||||||
.enumerate()
|
|
||||||
{
|
|
||||||
if dim_idx == dim {
|
|
||||||
cat_dims[dim] += v2;
|
|
||||||
}
|
|
||||||
if dim_idx != dim && v1 != v2 {
|
|
||||||
Err(Error::ShapeMismatchCat {
|
|
||||||
dim: dim_idx,
|
|
||||||
first_shape: arg0.shape().clone(),
|
|
||||||
n: arg_idx + 1,
|
|
||||||
nth_shape: arg.shape().clone(),
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let cat_target_dim_len = cat_dims[dim];
|
|
||||||
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
|
|
||||||
let shape = Shape::from(cat_dims);
|
|
||||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
|
|
||||||
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
|
||||||
let mut dst_o = 0;
|
|
||||||
for arg in args.iter() {
|
|
||||||
let arg = arg.as_ref();
|
|
||||||
let arg_dims = arg.shape().dims();
|
|
||||||
let d1: usize = arg_dims.iter().take(dim).product();
|
|
||||||
let d2 = block_size * arg_dims[dim];
|
|
||||||
let dst_s = block_size * cat_target_dim_len;
|
|
||||||
let src_o = arg.layout().start_offset();
|
|
||||||
arg.storage().copy2d(
|
|
||||||
&mut storage,
|
|
||||||
d1,
|
|
||||||
d2,
|
|
||||||
/* src_s */ d2,
|
|
||||||
dst_s,
|
|
||||||
src_o,
|
|
||||||
dst_o,
|
|
||||||
)?;
|
|
||||||
dst_o += d2;
|
|
||||||
}
|
|
||||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
|
||||||
}
|
|
||||||
}
|
|
@ -107,10 +107,6 @@ impl Var {
|
|||||||
Ok(Self(inner))
|
Ok(Self(inner))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn as_detached_tensor(&self) -> Tensor {
|
|
||||||
self.0.detach()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn as_tensor(&self) -> &Tensor {
|
pub fn as_tensor(&self) -> &Tensor {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
|
@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
|
|||||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||||
print(res.shape)
|
print(res.shape)
|
||||||
print(res)
|
print(res)
|
||||||
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
|
|
||||||
print(res.shape)
|
|
||||||
print(res)
|
|
||||||
*/
|
*/
|
||||||
fn conv1d(dev: &Device) -> Result<()> {
|
fn conv1d(dev: &Device) -> Result<()> {
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
@ -53,11 +50,8 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||||
);
|
);
|
||||||
|
if dev.is_cpu() {
|
||||||
let w = w.transpose(0, 1)?;
|
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
|
||||||
for w in [w.clone(), w.contiguous()?] {
|
|
||||||
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
|
|
||||||
assert_eq!(res.dims(), [1, 2, 7]);
|
assert_eq!(res.dims(), [1, 2, 7]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
@ -66,17 +60,6 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
4.7076, -5.9745, -0.8276, 1.621
|
4.7076, -5.9745, -0.8276, 1.621
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
|
|
||||||
assert_eq!(res.dims(), [1, 4, 7]);
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
|
|
||||||
[
|
|
||||||
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
|
|
||||||
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
|
|
||||||
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
|
|
||||||
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -135,7 +118,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||||
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||||
],
|
],
|
||||||
dev,
|
dev,
|
||||||
)?;
|
)?;
|
||||||
@ -163,9 +146,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
|
|
||||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||||
@ -190,7 +171,6 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Dilations.
|
// Dilations.
|
||||||
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
||||||
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
||||||
@ -229,7 +209,6 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -276,13 +255,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[
|
[
|
||||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
||||||
-0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
|
||||||
3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
|
||||||
0.0, 0.0, 0.0, 0.0
|
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
||||||
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -384,7 +363,6 @@ print(w.grad.shape)
|
|||||||
print(w.grad[0])
|
print(w.grad[0])
|
||||||
*/
|
*/
|
||||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||||
// conv-transposes are not implemented for metal
|
|
||||||
use candle_core::Var;
|
use candle_core::Var;
|
||||||
let t = Var::from_slice(
|
let t = Var::from_slice(
|
||||||
&[
|
&[
|
||||||
@ -397,7 +375,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
|||||||
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||||
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||||
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||||
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||||
],
|
],
|
||||||
(1, 4, 5, 5),
|
(1, 4, 5, 5),
|
||||||
dev,
|
dev,
|
||||||
@ -582,154 +560,6 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Conv Transpose 2d Test
|
|
||||||
//tested against following python
|
|
||||||
|
|
||||||
// import torch
|
|
||||||
// torch.manual_seed(4242)
|
|
||||||
// padding = 4
|
|
||||||
// outpadding = 2
|
|
||||||
// dilation = 3
|
|
||||||
// stride = 3
|
|
||||||
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
|
|
||||||
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
|
|
||||||
// print("input", input.flatten())
|
|
||||||
// print("kernel", kernel.flatten())
|
|
||||||
// res = torch.nn.functional.conv_transpose2d(
|
|
||||||
// input,
|
|
||||||
// kernel,
|
|
||||||
// stride=stride,
|
|
||||||
// padding=padding,
|
|
||||||
// dilation=dilation,
|
|
||||||
// output_padding=outpadding,
|
|
||||||
// )
|
|
||||||
// res.retain_grad()
|
|
||||||
// print(res.shape)
|
|
||||||
// loss = (res**2).sum()
|
|
||||||
// print(loss)
|
|
||||||
// loss.backward()
|
|
||||||
// print(input.grad.shape)
|
|
||||||
// print("input grad", torch.round(input.grad, decimals=1))
|
|
||||||
// print(kernel.grad.shape)
|
|
||||||
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
|
|
||||||
|
|
||||||
let padding = 4;
|
|
||||||
let outpadding = 2;
|
|
||||||
let dilation = 3;
|
|
||||||
let stride = 3;
|
|
||||||
|
|
||||||
let t = Var::from_slice(
|
|
||||||
&[
|
|
||||||
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
|
|
||||||
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
|
|
||||||
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
|
|
||||||
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
|
|
||||||
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
|
|
||||||
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
|
|
||||||
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
|
|
||||||
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
|
|
||||||
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
|
|
||||||
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
|
|
||||||
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
|
|
||||||
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
|
|
||||||
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
|
|
||||||
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
|
|
||||||
1.6475, 0.2219,
|
|
||||||
],
|
|
||||||
(1, 4, 7, 5),
|
|
||||||
dev,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
#[rustfmt::skip]
|
|
||||||
let w = Var::from_slice(
|
|
||||||
&[
|
|
||||||
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
|
|
||||||
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
|
|
||||||
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
|
|
||||||
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
|
|
||||||
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
|
|
||||||
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
|
|
||||||
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
|
|
||||||
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
|
|
||||||
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
|
|
||||||
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
|
|
||||||
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
|
|
||||||
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
|
|
||||||
],
|
|
||||||
(4, 2, 3, 5),
|
|
||||||
dev,
|
|
||||||
)?;
|
|
||||||
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
|
|
||||||
let loss = res.sqr()?.sum_all()?;
|
|
||||||
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
|
|
||||||
let grad_t = grads.get(&t).unwrap();
|
|
||||||
let grad_w = grads.get(&w).unwrap();
|
|
||||||
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
|
|
||||||
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
|
|
||||||
[
|
|
||||||
// torch gets 89.1
|
|
||||||
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
|
|
||||||
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
|
|
||||||
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
|
|
||||||
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
|
|
||||||
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
|
|
||||||
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
|
|
||||||
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
|
|
||||||
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
|
|
||||||
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
|
|
||||||
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
|
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[32.3, -41.6, -24.0, 14.1, 17.6],
|
|
||||||
[-11.8, 72.5, 87.6, 46.4, 61.5],
|
|
||||||
[115.0, 108.5, -48.6, -63.4, -50.0],
|
|
||||||
[51.3, 5.4, 31.3, 91.1, -30.9],
|
|
||||||
[52.7, 92.8, -68.0, -47.0, 83.0],
|
|
||||||
// pytorch gets -107.1
|
|
||||||
[-10.2, -107.0, -5.4, 213.1, -31.4],
|
|
||||||
[-2.4, 65.1, 9.2, -146.2, -24.2]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-72.6, -63.9, -61.9, 45.3, 33.0],
|
|
||||||
[79.3, -0.5, -26.2, 78.2, 42.7],
|
|
||||||
[90.9, 141.6, 40.1, -62.7, 37.0],
|
|
||||||
[32.8, 198.2, -0.8, -31.1, 27.3],
|
|
||||||
// torch gets 48.0
|
|
||||||
[34.5, 34.9, -47.9, 127.6, -12.3],
|
|
||||||
[-61.4, -3.2, -2.9, -10.9, -16.6],
|
|
||||||
[74.6, 60.1, -68.9, 34.5, -50.4]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[37.5, -56.9, -43.6, -13.5, -9.9],
|
|
||||||
[40.0, 97.3, 28.6, 14.2, -30.1],
|
|
||||||
[-22.3, -126.3, -68.8, -8.2, 26.1],
|
|
||||||
[-32.9, 37.3, 108.5, -54.8, 29.6],
|
|
||||||
[34.9, -176.9, -125.0, -28.3, -13.9],
|
|
||||||
[-54.9, 142.6, 62.1, -80.4, -65.6],
|
|
||||||
[7.4, -91.1, -67.6, 35.0, 39.7]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-57.2, -40.9, -10.1, 32.6, 29.4],
|
|
||||||
[18.7, -18.0, 29.5, -1.2, 59.2],
|
|
||||||
[-14.0, -74.4, 19.8, -117.0, 58.2],
|
|
||||||
[-21.8, 163.5, -71.1, -99.0, 80.9],
|
|
||||||
[-58.9, -10.9, 93.8, -139.6, 98.0],
|
|
||||||
// torch gets 54.5
|
|
||||||
[-54.4, 135.3, 6.0, -79.1, 134.6],
|
|
||||||
[27.5, -76.0, 43.4, -2.8, -7.8]
|
|
||||||
]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,34 +112,3 @@ fn custom_op1_with_backward() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
impl candle_core::InplaceOp1 for Elu {
|
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
"elu"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
|
|
||||||
let alpha = self.alpha;
|
|
||||||
match s {
|
|
||||||
CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
|
||||||
CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
|
||||||
CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
|
||||||
CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
|
|
||||||
_ => candle_core::bail!("unsupported dtype for inplace elu"),
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn inplace_op1() -> Result<()> {
|
|
||||||
let cpu = &Device::Cpu;
|
|
||||||
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
|
|
||||||
let t = (t - 5.)?;
|
|
||||||
t.inplace_op1(&Elu { alpha: 1. })?;
|
|
||||||
assert_eq!(
|
|
||||||
to_vec1_round(&t, 4)?,
|
|
||||||
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
Binary file not shown.
@ -1,4 +1,3 @@
|
|||||||
#![allow(clippy::approx_constant)]
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||||
|
|
||||||
@ -97,24 +96,24 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(x).context("no grad for x")?;
|
let grad_x = grads.get(x).context("no grad for x")?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&y, 4)?,
|
y.to_vec1::<f32>()?,
|
||||||
[20.0855, 2.7183, 54.5982, 1.1618]
|
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(grad_x, 4)?,
|
grad_x.to_vec1::<f32>()?,
|
||||||
[20.0855, 2.7183, 54.5982, 1.1618]
|
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
||||||
);
|
);
|
||||||
let y = x.exp()?.sqr()?;
|
let y = x.exp()?.sqr()?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(x).context("no grad for x")?;
|
let grad_x = grads.get(x).context("no grad for x")?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&y, 3)?,
|
y.to_vec1::<f32>()?,
|
||||||
[403.429, 7.389, 2980.958, 1.35]
|
[403.4288, 7.3890557, 2980.9578, 1.3498588]
|
||||||
);
|
);
|
||||||
// exp(x)^2 = exp(2*x)
|
// exp(x)^2 = exp(2*x)
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(grad_x, 2)?,
|
grad_x.to_vec1::<f32>()?,
|
||||||
[806.86, 14.78, 5961.92, 2.7]
|
[806.8576, 14.778111, 5961.9155, 2.6997175]
|
||||||
);
|
);
|
||||||
let y = x.sin()?;
|
let y = x.sin()?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
@ -262,7 +261,6 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let y = elu_x.elu(2.)?;
|
let y = elu_x.elu(2.)?;
|
||||||
let grads = y.backward()?;
|
let grads = y.backward()?;
|
||||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec1_round(&y, 4)?,
|
test_utils::to_vec1_round(&y, 4)?,
|
||||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||||
@ -272,51 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||||
);
|
);
|
||||||
|
|
||||||
// testing compared to pytorch nn.Silu()
|
|
||||||
let y = x.silu()?;
|
|
||||||
let grads = y.backward()?;
|
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&y, 4)?,
|
|
||||||
[2.8577, 0.7311, 3.9281, 0.0806]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(grad_x, 4)?,
|
|
||||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
|
||||||
);
|
|
||||||
|
|
||||||
if device.is_cpu() {
|
|
||||||
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
|
||||||
let y = x.interpolate1d(12)?.reshape(36)?;
|
|
||||||
|
|
||||||
let z = Tensor::new(
|
|
||||||
&[
|
|
||||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
|
|
||||||
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
|
|
||||||
33., 34., 35., 36.,
|
|
||||||
],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
|
||||||
let grads = loss.backward()?;
|
|
||||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec3_round(grad_x, 4)?,
|
|
||||||
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// manually checked: see comments
|
// manually checked: see comments
|
||||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
let z = Tensor::new(
|
let z = Tensor::new(
|
||||||
&[
|
&[
|
||||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
1_f32, 02., 03., 04., 05., 06.,
|
||||||
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
07., 08., 09., 10., 11., 12.,
|
||||||
35., 36.,
|
13., 14., 15., 16., 17., 18.,
|
||||||
|
19., 20., 21., 22., 23., 24.,
|
||||||
|
25., 26., 27., 28., 29., 30.,
|
||||||
|
31., 32., 33., 34., 35., 36.,
|
||||||
],
|
],
|
||||||
device,
|
device,
|
||||||
)?;
|
)?;
|
||||||
@ -347,11 +313,15 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
|
||||||
|
#[rustfmt::skip]
|
||||||
let z = Tensor::new(
|
let z = Tensor::new(
|
||||||
&[
|
&[
|
||||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
1_f32, 02., 03., 04., 05., 06.,
|
||||||
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
07., 08., 09., 10., 11., 12.,
|
||||||
35., 36.,
|
13., 14., 15., 16., 17., 18.,
|
||||||
|
19., 20., 21., 22., 23., 24.,
|
||||||
|
25., 26., 27., 28., 29., 30.,
|
||||||
|
31., 32., 33., 34., 35., 36.,
|
||||||
],
|
],
|
||||||
device,
|
device,
|
||||||
)?;
|
)?;
|
||||||
|
@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
let tensor = tensor.i((.., 1))?.contiguous()?;
|
let tensor = tensor.i((.., 1))?;
|
||||||
match tensor.strided_blocks() {
|
match tensor.strided_blocks() {
|
||||||
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
candle::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
assert_eq!(start_offset, 0);
|
assert_eq!(start_offset, 0);
|
||||||
@ -100,20 +100,6 @@ fn strided_blocks() -> Result<()> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
||||||
let tensor = tensor.i((.., 1))?;
|
|
||||||
match tensor.strided_blocks() {
|
|
||||||
candle::StridedBlocks::SingleBlock { .. } => {
|
|
||||||
panic!("unexpected block structure")
|
|
||||||
}
|
|
||||||
candle::StridedBlocks::MultipleBlocks {
|
|
||||||
block_len,
|
|
||||||
block_start_index,
|
|
||||||
} => {
|
|
||||||
assert_eq!(block_len, 4);
|
|
||||||
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
|
|
||||||
match tensor.t()?.strided_blocks() {
|
match tensor.t()?.strided_blocks() {
|
||||||
candle::StridedBlocks::SingleBlock { .. } => {
|
candle::StridedBlocks::SingleBlock { .. } => {
|
||||||
panic!("unexpected block structure")
|
panic!("unexpected block structure")
|
||||||
|
@ -1,106 +0,0 @@
|
|||||||
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
|
|
||||||
|
|
||||||
fn matmul(device: &Device) -> Result<()> {
|
|
||||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
|
||||||
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
|
||||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
|
||||||
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
|
||||||
|
|
||||||
let data = vec![1.0f32, 2.0];
|
|
||||||
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
|
||||||
let data = vec![3.0f32, 4.0];
|
|
||||||
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
|
||||||
|
|
||||||
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
|
||||||
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
|
||||||
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
|
||||||
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
|
||||||
|
|
||||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
|
||||||
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
|
||||||
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
|
||||||
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
|
||||||
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
|
||||||
|
|
||||||
// Also perform the matmul on contiguous transposed versions.
|
|
||||||
let a_tt = a.t()?.contiguous()?.t()?;
|
|
||||||
assert!(!a_tt.is_contiguous());
|
|
||||||
assert_eq!(a.dims(), a_tt.dims());
|
|
||||||
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
|
||||||
|
|
||||||
let b_tt = b.t()?.contiguous()?.t()?;
|
|
||||||
assert!(!b_tt.is_contiguous());
|
|
||||||
assert_eq!(b.dims(), b_tt.dims());
|
|
||||||
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
|
||||||
|
|
||||||
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
|
||||||
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
|
||||||
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn broadcast_matmul(device: &Device) -> Result<()> {
|
|
||||||
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
|
||||||
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
|
||||||
let out = lhs.broadcast_matmul(&rhs)?;
|
|
||||||
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
|
||||||
for idx1 in 0..3 {
|
|
||||||
for idx2 in 0..6 {
|
|
||||||
let out = out.i((idx1, idx2))?;
|
|
||||||
let lhs = lhs.i((idx1, 0))?;
|
|
||||||
let rhs = rhs.i(idx2)?;
|
|
||||||
let out2 = lhs.matmul(&rhs);
|
|
||||||
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
|
||||||
// With cuda, we see errors of up to ~1e-12.
|
|
||||||
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/candle/issues/1948
|
|
||||||
fn squeeze_mm(device: &Device) -> Result<()> {
|
|
||||||
let seq_len = 8_usize;
|
|
||||||
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
|
|
||||||
let x = a.i((.., seq_len - 1, ..))?;
|
|
||||||
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
|
|
||||||
let x = x.matmul(&w)?;
|
|
||||||
assert_eq!(x.dims(), &[1, 32]);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/candle/issues/1992
|
|
||||||
fn mm_layout(device: &Device) -> Result<()> {
|
|
||||||
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
|
|
||||||
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
|
|
||||||
let mm1 = a.matmul(&b)?;
|
|
||||||
// Forces the layout to be:
|
|
||||||
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
|
|
||||||
// This is still a contiguous matrix but matmul checks are only the two last dimensions have
|
|
||||||
// non 1 sizes but matmul check may be reluctant to handle it.
|
|
||||||
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
|
|
||||||
let mm2 = a.matmul(&b)?;
|
|
||||||
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff, 0.);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
|
||||||
test_device!(
|
|
||||||
broadcast_matmul,
|
|
||||||
broadcast_matmul_cpu,
|
|
||||||
broadcast_matmul_gpu,
|
|
||||||
broadcast_matmul_metal
|
|
||||||
);
|
|
||||||
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
|
|
||||||
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);
|
|
@ -43,9 +43,6 @@ res = torch.nn.functional.avg_pool2d(t, 2)
|
|||||||
print(res)
|
print(res)
|
||||||
*/
|
*/
|
||||||
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||||
if dev.is_metal() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
&[
|
&[
|
||||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||||
|
@ -1,37 +0,0 @@
|
|||||||
import torch
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
# Write a trivial tensor to a pt file
|
|
||||||
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
|
|
||||||
o = OrderedDict()
|
|
||||||
o["test"] = a
|
|
||||||
|
|
||||||
# Write a trivial tensor to a pt file
|
|
||||||
torch.save(o, "test.pt")
|
|
||||||
|
|
||||||
############################################################################################################
|
|
||||||
# Write a trivial tensor to a pt file with a key
|
|
||||||
torch.save({"model_state_dict": o}, "test_with_key.pt")
|
|
||||||
|
|
||||||
############################################################################################################
|
|
||||||
# Create a tensor with fortran contiguous memory layout
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
|
|
||||||
# For example, creating a 2x3x4 array
|
|
||||||
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
|
|
||||||
|
|
||||||
# Verify the memory order
|
|
||||||
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
|
|
||||||
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
|
|
||||||
|
|
||||||
# Step 2: Convert the NumPy array to a PyTorch tensor
|
|
||||||
tensor_fortran = torch.from_numpy(array_fortran)
|
|
||||||
|
|
||||||
# Verify the tensor layout
|
|
||||||
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
|
|
||||||
|
|
||||||
# Step 3: Save the PyTorch tensor to a .pth file
|
|
||||||
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
|
|
||||||
|
|
||||||
print("3D Tensor saved with Fortran layout.")
|
|
@ -1,31 +0,0 @@
|
|||||||
/// Regression test for pth files not loading on Windows.
|
|
||||||
#[test]
|
|
||||||
fn test_pth() {
|
|
||||||
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
|
|
||||||
tensors.get("test").unwrap().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_pth_with_key() {
|
|
||||||
let tensors =
|
|
||||||
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
|
|
||||||
.unwrap();
|
|
||||||
tensors.get("test").unwrap().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_pth_fortran_congiguous() {
|
|
||||||
let tensors =
|
|
||||||
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
|
|
||||||
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
|
|
||||||
|
|
||||||
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
tensor.to_vec3::<i64>().unwrap(),
|
|
||||||
[
|
|
||||||
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
|
||||||
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
}
|
|
@ -3,7 +3,7 @@ use candle_core::{
|
|||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_device,
|
test_device,
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, IndexOp, Module, Result, Tensor,
|
Device, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
@ -47,14 +47,18 @@ fn test_matmul(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantized_matmul(device: &Device) -> Result<()> {
|
fn quantized_matmul(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let (m, k, n) = (3, 64, 4);
|
let (m, k, n) = (3, 64, 4);
|
||||||
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
|
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||||
let mut dst = vec![42.; 3 * 4];
|
let mut dst = vec![42.; 3 * 4];
|
||||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
|
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
||||||
&[
|
&[
|
||||||
@ -63,7 +67,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||||
let mm = lhs.matmul(&tensor_rhs)?;
|
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
mm.to_vec2::<f32>()?,
|
mm.to_vec2::<f32>()?,
|
||||||
&[
|
&[
|
||||||
@ -75,7 +79,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
match device {
|
match device {
|
||||||
Device::Metal(_) => assert_eq!(
|
Device::Metal(_) => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
@ -85,15 +89,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
Device::Cuda(_) => assert_eq!(
|
_ => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
|
||||||
&[
|
|
||||||
[84866.0, 214045.0, 344676.0, 473707.0],
|
|
||||||
[213425.0, 604313.0, 1000431.0, 1387960.0],
|
|
||||||
[342030.0, 994630.0, 1656248.0, 2302250.0]
|
|
||||||
]
|
|
||||||
),
|
|
||||||
Device::Cpu => assert_eq!(
|
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
&[
|
&[
|
||||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||||
@ -102,16 +98,22 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let (m, k, n) = (3, 64, 4);
|
let (m, k, n) = (3, 64, 4);
|
||||||
let lhs_s = (0..(m * k))
|
let lhs = (0..(m * k))
|
||||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
|
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||||
let mut dst = vec![42.; 3 * 4];
|
let mut dst = vec![42.; 3 * 4];
|
||||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
let rhs = (0..k * n)
|
let rhs = (0..k * n)
|
||||||
@ -119,7 +121,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
|
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
|
||||||
&[
|
&[
|
||||||
@ -127,7 +129,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
-196472.0, 63012.0, 324585.0, 587902.0
|
-196472.0, 63012.0, 324585.0, 587902.0
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let mm = lhs.matmul(&tensor_rhs)?;
|
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
to_vec2_round(&mm, 0)?,
|
to_vec2_round(&mm, 0)?,
|
||||||
&[
|
&[
|
||||||
@ -139,7 +141,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
let res = matmul.forward(&lhs)?;
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
match device {
|
match device {
|
||||||
Device::Metal(_) => assert_eq!(
|
Device::Metal(_) => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
@ -149,15 +151,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
[-196102.0, 63022.0, 324233.0, 587191.0]
|
[-196102.0, 63022.0, 324233.0, 587191.0]
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
Device::Cuda(_) => assert_eq!(
|
_ => assert_eq!(
|
||||||
to_vec2_round(&res, 0)?,
|
|
||||||
&[
|
|
||||||
[243740.0, -19762.0, -285476.0, -550498.0],
|
|
||||||
[23774.0, 21645.0, 19395.0, 18364.0],
|
|
||||||
[-196045.0, 63030.0, 324120.0, 587079.0]
|
|
||||||
]
|
|
||||||
),
|
|
||||||
Device::Cpu => assert_eq!(
|
|
||||||
to_vec2_round(&res, 0)?,
|
to_vec2_round(&res, 0)?,
|
||||||
&[
|
&[
|
||||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||||
@ -166,60 +160,28 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
|
|
||||||
let res2 = matmul.forward(&lhs2)?;
|
|
||||||
let res2 = res2.i(1)?;
|
|
||||||
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
if device.is_cuda() {
|
|
||||||
assert!(diff < 0.1);
|
|
||||||
} else {
|
|
||||||
assert_eq!(diff, 0.);
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn qmm_batch(dev: &Device) -> Result<()> {
|
test_device!(
|
||||||
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
|
quantized_matmul,
|
||||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
quantized_matmul_cpu,
|
||||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
quantized_matmul_cuda,
|
||||||
let mm = rhs.forward(&lhs)?;
|
quantized_matmul_metal
|
||||||
assert_eq!(mm.shape().dims(), [2, 6]);
|
);
|
||||||
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
|
test_device!(
|
||||||
let mm2 = rhs.forward(&lhs2)?;
|
quantized_matmul_neg,
|
||||||
assert_eq!(mm2.shape().dims(), [4, 6]);
|
quantized_matmul_neg_cpu,
|
||||||
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
quantized_matmul_neg_cuda,
|
||||||
assert_eq!(diff2, 0.0);
|
quantized_matmul_neg_metal
|
||||||
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
|
);
|
||||||
let mm3 = rhs.forward(&lhs3)?;
|
|
||||||
assert_eq!(mm3.shape().dims(), [6, 6]);
|
|
||||||
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff3, 0.0);
|
|
||||||
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff3, 0.0);
|
|
||||||
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
|
|
||||||
let mm4 = rhs.forward(&lhs4)?;
|
|
||||||
assert_eq!(mm4.shape().dims(), [12, 6]);
|
|
||||||
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
|
||||||
if dev.is_cuda() {
|
|
||||||
// We use a different kernel for sizes from 1 to 8 on cuda which explains
|
|
||||||
// the difference here.
|
|
||||||
assert!(0. < diff4 && diff4 < 1e-4)
|
|
||||||
} else {
|
|
||||||
assert_eq!(diff4, 0.0)
|
|
||||||
};
|
|
||||||
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
|
|
||||||
.abs()?
|
|
||||||
.sum_all()?
|
|
||||||
.to_vec0::<f32>()?;
|
|
||||||
assert_eq!(diff4, 0.0);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
|
|
||||||
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
|
|
||||||
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
|
|
||||||
|
|
||||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
|
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
@ -247,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||||
@ -273,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||||
@ -299,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||||
@ -399,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dtype = GgmlDType::Q2K;
|
let dtype = GgmlDType::Q2K;
|
||||||
|
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
@ -433,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dtype = GgmlDType::Q3K;
|
let dtype = GgmlDType::Q3K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
@ -466,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dtype = GgmlDType::Q4K;
|
let dtype = GgmlDType::Q4K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
@ -499,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dtype = GgmlDType::Q5K;
|
let dtype = GgmlDType::Q5K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
@ -532,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dtype = GgmlDType::Q6K;
|
let dtype = GgmlDType::Q6K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
@ -565,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||||
|
// TODO Enable this later when we enable cuda.
|
||||||
|
if device.is_cuda() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
let dtype = GgmlDType::Q8K;
|
let dtype = GgmlDType::Q8K;
|
||||||
let src = get_test_vector2(0.5, 1024, device)?;
|
let src = get_test_vector2(0.5, 1024, device)?;
|
||||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||||
@ -780,6 +778,10 @@ macro_rules! quantized_matmul {
|
|||||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||||
fn $fn_name(device: &Device) -> Result<()> {
|
fn $fn_name(device: &Device) -> Result<()> {
|
||||||
|
if device.is_cuda() {
|
||||||
|
// TODO Enable Cuda GGML sometime maybe.
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -106,9 +106,6 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
|
|
||||||
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
|
|
||||||
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
|
||||||
[
|
[
|
||||||
@ -123,13 +120,6 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
|
||||||
[
|
|
||||||
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
|
||||||
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||||
@ -151,14 +141,6 @@ fn unary_op(device: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
|
||||||
[3000.0, 300.]
|
[3000.0, 300.]
|
||||||
);
|
);
|
||||||
let tensor = Tensor::new(
|
|
||||||
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
|
|
||||||
device,
|
|
||||||
)?;
|
|
||||||
assert_eq!(
|
|
||||||
tensor.sign()?.to_vec1::<f32>()?,
|
|
||||||
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
|
|
||||||
);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -683,31 +665,6 @@ fn cat(device: &Device) -> Result<()> {
|
|||||||
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
// 3D
|
|
||||||
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
|
|
||||||
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
|
|
||||||
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
|
|
||||||
|
|
||||||
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
|
||||||
|
|
||||||
let t1 = t1.t()?.contiguous()?.t()?;
|
|
||||||
let t2 = t2.t()?.contiguous()?.t()?;
|
|
||||||
let t3 = t3.t()?.contiguous()?.t()?;
|
|
||||||
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
|
||||||
|
|
||||||
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
|
|
||||||
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
|
|
||||||
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
|
|
||||||
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
|
|
||||||
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
|
|
||||||
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
|
|
||||||
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
|
|
||||||
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
|
|
||||||
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
|
|
||||||
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
|
|
||||||
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
|
|
||||||
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -718,8 +675,6 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids, 0)?;
|
let hs = t.index_select(&ids, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -747,47 +702,44 @@ fn index_select(device: &Device) -> Result<()> {
|
|||||||
[9.0, 10.0, 11.0]
|
[9.0, 10.0, 11.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
for dtype in [DType::U8, DType::U32, DType::I64] {
|
let hs = t.index_select(&ids, 1)?;
|
||||||
let ids = ids.to_dtype(dtype)?;
|
assert_eq!(
|
||||||
let hs = t.index_select(&ids, 1)?;
|
hs.to_vec2::<f32>()?,
|
||||||
assert_eq!(
|
&[
|
||||||
hs.to_vec2::<f32>()?,
|
[0.0, 2.0, 1.0],
|
||||||
&[
|
[3.0, 5.0, 4.0],
|
||||||
[0.0, 2.0, 1.0],
|
[6.0, 8.0, 7.0],
|
||||||
[3.0, 5.0, 4.0],
|
[9.0, 11.0, 10.0]
|
||||||
[6.0, 8.0, 7.0],
|
]
|
||||||
[9.0, 11.0, 10.0]
|
);
|
||||||
]
|
let hs = t.index_select(&ids, 0)?;
|
||||||
);
|
assert_eq!(
|
||||||
let hs = t.index_select(&ids, 0)?;
|
hs.to_vec2::<f32>()?,
|
||||||
assert_eq!(
|
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||||
hs.to_vec2::<f32>()?,
|
);
|
||||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||||
);
|
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
let hs = t.index_select(&ids, 0)?;
|
||||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
assert_eq!(
|
||||||
let hs = t.index_select(&ids, 0)?;
|
hs.to_vec2::<f32>()?,
|
||||||
assert_eq!(
|
&[
|
||||||
hs.to_vec2::<f32>()?,
|
[0.0, 1.0, 2.0],
|
||||||
&[
|
[6.0, 7.0, 8.0],
|
||||||
[0.0, 1.0, 2.0],
|
[3.0, 4.0, 5.0],
|
||||||
[6.0, 7.0, 8.0],
|
[0.0, 1.0, 2.0],
|
||||||
[3.0, 4.0, 5.0],
|
[6.0, 7.0, 8.0],
|
||||||
[0.0, 1.0, 2.0],
|
[3.0, 4.0, 5.0],
|
||||||
[6.0, 7.0, 8.0],
|
]
|
||||||
[3.0, 4.0, 5.0],
|
);
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
// Test when selecting dim > 0 with ids size different from elem count of
|
// Test when selecting dim > 0 with ids size different from elem count of
|
||||||
// target dim in source/input.
|
// target dim in source/input.
|
||||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||||
let hs = t.index_select(&ids, 1)?;
|
let hs = t.index_select(&ids, 1)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -949,6 +901,74 @@ fn gather(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matmul(device: &Device) -> Result<()> {
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||||
|
|
||||||
|
let data = vec![1.0f32, 2.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||||
|
let data = vec![3.0f32, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
||||||
|
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
||||||
|
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
||||||
|
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
||||||
|
|
||||||
|
// Also perform the matmul on contiguous transposed versions.
|
||||||
|
let a_tt = a.t()?.contiguous()?.t()?;
|
||||||
|
assert!(!a_tt.is_contiguous());
|
||||||
|
assert_eq!(a.dims(), a_tt.dims());
|
||||||
|
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
||||||
|
|
||||||
|
let b_tt = b.t()?.contiguous()?.t()?;
|
||||||
|
assert!(!b_tt.is_contiguous());
|
||||||
|
assert_eq!(b.dims(), b_tt.dims());
|
||||||
|
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
||||||
|
|
||||||
|
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||||
|
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn broadcast_matmul(device: &Device) -> Result<()> {
|
||||||
|
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
|
||||||
|
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
|
||||||
|
let out = lhs.broadcast_matmul(&rhs)?;
|
||||||
|
assert_eq!(out.dims(), &[3, 6, 4, 2]);
|
||||||
|
for idx1 in 0..3 {
|
||||||
|
for idx2 in 0..6 {
|
||||||
|
let out = out.i((idx1, idx2))?;
|
||||||
|
let lhs = lhs.i((idx1, 0))?;
|
||||||
|
let rhs = rhs.i(idx2)?;
|
||||||
|
let out2 = lhs.matmul(&rhs);
|
||||||
|
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
|
||||||
|
// With cuda, we see errors of up to ~1e-12.
|
||||||
|
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn broadcasting(device: &Device) -> Result<()> {
|
fn broadcasting(device: &Device) -> Result<()> {
|
||||||
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
|
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
|
||||||
let t2 = Tensor::new(&[100f32, 200f32], device)?;
|
let t2 = Tensor::new(&[100f32, 200f32], device)?;
|
||||||
@ -1053,54 +1073,8 @@ fn broadcasting(device: &Device) -> Result<()> {
|
|||||||
fn randn(device: &Device) -> Result<()> {
|
fn randn(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||||
assert_eq!(tensor.dims(), [5, 3]);
|
assert_eq!(tensor.dims(), [5, 3]);
|
||||||
// Check that the seed gets updated by checking that
|
|
||||||
// a new series of numbers is generated each time
|
|
||||||
let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
|
||||||
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
|
||||||
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||||
assert_eq!(tensor.dims(), [5, 3]);
|
assert_eq!(tensor.dims(), [5, 3]);
|
||||||
// Check that the seed gets updated by checking that
|
|
||||||
// a new series of numbers is generated each time
|
|
||||||
let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
|
||||||
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
|
||||||
// We do not expect deterministic elements at any index.
|
|
||||||
// There once was a bug that had a deterministic zero element in evenly sized tensors.
|
|
||||||
const N: usize = 2;
|
|
||||||
let v = (0..100)
|
|
||||||
.map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
assert!(
|
|
||||||
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
|
||||||
"There are deterministic values in the randn tensors"
|
|
||||||
);
|
|
||||||
let v = (0..100)
|
|
||||||
.map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
assert!(
|
|
||||||
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
|
||||||
"There are deterministic values in the rand tensors"
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn zero_dim(device: &Device) -> Result<()> {
|
|
||||||
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
|
|
||||||
assert_eq!(t.dims3()?, (4, 0, 1));
|
|
||||||
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
|
|
||||||
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
|
|
||||||
assert_eq!(t_cat.dims3()?, (4, 3, 1));
|
|
||||||
let t_cat = Tensor::cat(&[&t, &t], 1)?;
|
|
||||||
assert_eq!(t_cat.dims3()?, (4, 0, 1));
|
|
||||||
let t_unary = t.sqrt()?;
|
|
||||||
assert_eq!(t_unary.dims3()?, (4, 0, 1));
|
|
||||||
let t_plus = (&t + 1.)?;
|
|
||||||
assert_eq!(t_plus.dims3()?, (4, 0, 1));
|
|
||||||
let t_mm = t2.matmul(&t.t()?)?;
|
|
||||||
assert_eq!(t_mm.dims3()?, (4, 3, 0));
|
|
||||||
let t_mm = t.matmul(&t2.t()?)?;
|
|
||||||
assert_eq!(t_mm.dims3()?, (4, 0, 3));
|
|
||||||
let t_mm = t.t()?.matmul(&t)?;
|
|
||||||
assert_eq!(t_mm.dims3()?, (4, 1, 1));
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1123,6 +1097,13 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
|
|||||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
|
||||||
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
|
||||||
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
|
||||||
|
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
|
||||||
|
test_device!(
|
||||||
|
broadcast_matmul,
|
||||||
|
broadcast_matmul_cpu,
|
||||||
|
broadcast_matmul_gpu,
|
||||||
|
broadcast_matmul_metal
|
||||||
|
);
|
||||||
test_device!(
|
test_device!(
|
||||||
broadcasting,
|
broadcasting,
|
||||||
broadcasting_cpu,
|
broadcasting_cpu,
|
||||||
@ -1152,7 +1133,6 @@ test_device!(
|
|||||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||||
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
|
||||||
|
|
||||||
// There was originally a bug on the CPU implementation for randn
|
// There was originally a bug on the CPU implementation for randn
|
||||||
// https://github.com/huggingface/candle/issues/381
|
// https://github.com/huggingface/candle/issues/381
|
||||||
@ -1280,8 +1260,8 @@ fn pow() -> Result<()> {
|
|||||||
let rhs = (&lhs - 2.)?;
|
let rhs = (&lhs - 2.)?;
|
||||||
let res = lhs.pow(&rhs)?;
|
let res = lhs.pow(&rhs)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
test_utils::to_vec2_round(&res, 3)?,
|
test_utils::to_vec2_round(&res, 4)?,
|
||||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
|
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Binary file not shown.
Binary file not shown.
@ -12,7 +12,7 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { workspace = true }
|
candle = { workspace = true }
|
||||||
candle-datasets = { workspace = true, optional = true }
|
candle-datasets = { workspace = true }
|
||||||
candle-nn = { workspace = true }
|
candle-nn = { workspace = true }
|
||||||
candle-transformers = { workspace = true }
|
candle-transformers = { workspace = true }
|
||||||
candle-flash-attn = { workspace = true, optional = true }
|
candle-flash-attn = { workspace = true, optional = true }
|
||||||
@ -21,19 +21,16 @@ candle-onnx = { workspace = true, optional = true }
|
|||||||
csv = "1.3.0"
|
csv = "1.3.0"
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
hf-hub = { workspace = true, features = ["tokio"] }
|
hf-hub = { workspace = true, features=["tokio"]}
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
rubato = { version = "0.15.0", optional = true }
|
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
|
||||||
tokenizers = { workspace = true, features = ["onig"] }
|
tokenizers = { workspace = true, features = ["onig"] }
|
||||||
cpal= { version = "0.15.2", optional = true }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@ -42,10 +39,11 @@ clap = { workspace = true }
|
|||||||
imageproc = { workspace = true }
|
imageproc = { workspace = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
ab_glyph = { workspace = true }
|
rusttype = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-chrome = { workspace = true }
|
tracing-chrome = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
|
wav = { workspace = true }
|
||||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||||
tokio = "1.29.1"
|
tokio = "1.29.1"
|
||||||
|
|
||||||
@ -63,8 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
|
|||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
onnx = ["candle-onnx"]
|
onnx = ["candle-onnx"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
microphone = ["cpal"]
|
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
@ -81,23 +77,3 @@ required-features = ["onnx"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "onnx_basics"
|
name = "onnx_basics"
|
||||||
required-features = ["onnx"]
|
required-features = ["onnx"]
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "whisper"
|
|
||||||
required-features = ["symphonia"]
|
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "whisper-microphone"
|
|
||||||
required-features = ["microphone"]
|
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "mnist-training"
|
|
||||||
required-features = ["candle-datasets"]
|
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "llama2-c"
|
|
||||||
required-features = ["candle-datasets"]
|
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "encodec"
|
|
||||||
required-features = ["encodec"]
|
|
||||||
|
@ -1,237 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle_transformers::models::chatglm::{Config, Model};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
verbose_prompt: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
verbose_prompt: bool,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
verbose_prompt,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
println!("starting the inference loop");
|
|
||||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
|
||||||
if tokens.is_empty() {
|
|
||||||
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
|
|
||||||
}
|
|
||||||
if self.verbose_prompt {
|
|
||||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
|
||||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
|
||||||
println!("{id:7} -> '{token}'");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mut tokens = tokens.get_ids().to_vec();
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
|
|
||||||
Some(token) => *token,
|
|
||||||
None => anyhow::bail!("cannot find the endoftext token"),
|
|
||||||
};
|
|
||||||
print!("{prompt}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input)?;
|
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
|
||||||
print!("{token}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
/// Display the token for the specified prompt.
|
|
||||||
#[arg(long)]
|
|
||||||
verbose_prompt: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id.to_string(),
|
|
||||||
None => "THUDM/chatglm3-6b".to_string(),
|
|
||||||
};
|
|
||||||
let revision = match args.revision {
|
|
||||||
Some(rev) => rev.to_string(),
|
|
||||||
None => "main".to_string(),
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
|
||||||
let tokenizer_filename = match args.tokenizer {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => api
|
|
||||||
.model("lmz/candle-chatglm".to_string())
|
|
||||||
.get("chatglm-tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_file {
|
|
||||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
|
||||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config = Config::glm3_6b();
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
args.verbose_prompt,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,46 +0,0 @@
|
|||||||
Contrastive Language-Image Pre-Training
|
|
||||||
|
|
||||||
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
|
||||||
pairs of images with related texts.
|
|
||||||
|
|
||||||
https://github.com/openai/CLIP
|
|
||||||
|
|
||||||
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
|
|
||||||
|
|
||||||
## Running an example on CPU
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
|
||||||
|
|
||||||
|
|
||||||
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
|
||||||
|
|
||||||
INFO clip: Probability: 0.0000% Text: a cycling race
|
|
||||||
INFO clip: Probability: 0.0000% Text: a photo of two cats
|
|
||||||
INFO clip: Probability: 100.0000% Text: a robot holding a candle
|
|
||||||
|
|
||||||
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
|
||||||
|
|
||||||
INFO clip: Probability: 99.9999% Text: a cycling race
|
|
||||||
INFO clip: Probability: 0.0001% Text: a photo of two cats
|
|
||||||
INFO clip: Probability: 0.0000% Text: a robot holding a candle
|
|
||||||
```
|
|
||||||
|
|
||||||
## Running an example with the metal feature (macOS)
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
|
||||||
|
|
||||||
|
|
||||||
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
|
||||||
|
|
||||||
INFO clip: Probability: 0.0000% Text: a cycling race
|
|
||||||
INFO clip: Probability: 0.0000% Text: a photo of two cats
|
|
||||||
INFO clip: Probability: 100.0000% Text: a robot holding a candle
|
|
||||||
|
|
||||||
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
|
||||||
|
|
||||||
INFO clip: Probability: 99.9999% Text: a cycling race
|
|
||||||
INFO clip: Probability: 0.0001% Text: a photo of two cats
|
|
||||||
INFO clip: Probability: 0.0000% Text: a robot holding a candle
|
|
||||||
```
|
|
@ -1,202 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::Error as E;
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_nn::{ops::softmax, VarBuilder};
|
|
||||||
use candle_transformers::models::clip;
|
|
||||||
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, use_value_delimiter = true)]
|
|
||||||
images: Option<Vec<String>>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(long, use_value_delimiter = true)]
|
|
||||||
sequences: Option<Vec<String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
|
|
||||||
let img = image::io::Reader::open(path)?.decode()?;
|
|
||||||
let (height, width) = (image_size, image_size);
|
|
||||||
let img = img.resize_to_fill(
|
|
||||||
width as u32,
|
|
||||||
height as u32,
|
|
||||||
image::imageops::FilterType::Triangle,
|
|
||||||
);
|
|
||||||
|
|
||||||
let img = img.to_rgb8();
|
|
||||||
|
|
||||||
let img = img.into_raw();
|
|
||||||
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
|
||||||
.permute((2, 0, 1))?
|
|
||||||
.to_dtype(DType::F32)?
|
|
||||||
.affine(2. / 255., -1.)?;
|
|
||||||
// .unsqueeze(0)?;
|
|
||||||
Ok(img)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_images<T: AsRef<std::path::Path>>(
|
|
||||||
paths: &Vec<T>,
|
|
||||||
image_size: usize,
|
|
||||||
) -> anyhow::Result<Tensor> {
|
|
||||||
let mut images = vec![];
|
|
||||||
|
|
||||||
for path in paths {
|
|
||||||
let tensor = load_image(path, image_size)?;
|
|
||||||
images.push(tensor);
|
|
||||||
}
|
|
||||||
|
|
||||||
let images = Tensor::stack(&images, 0)?;
|
|
||||||
|
|
||||||
Ok(images)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
// std::env::set_var("RUST_BACKTRACE", "full");
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
|
|
||||||
let api = api.repo(hf_hub::Repo::with_revision(
|
|
||||||
"openai/clip-vit-base-patch32".to_string(),
|
|
||||||
hf_hub::RepoType::Model,
|
|
||||||
"refs/pr/15".to_string(),
|
|
||||||
));
|
|
||||||
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let tokenizer = get_tokenizer(args.tokenizer)?;
|
|
||||||
|
|
||||||
let config = clip::ClipConfig::vit_base_patch32();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let vec_imgs = match args.images {
|
|
||||||
Some(imgs) => imgs,
|
|
||||||
None => vec![
|
|
||||||
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
|
|
||||||
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
|
||||||
],
|
|
||||||
};
|
|
||||||
|
|
||||||
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
|
|
||||||
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
|
||||||
|
|
||||||
let vb =
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
|
|
||||||
|
|
||||||
let model = clip::ClipModel::new(vb, &config)?;
|
|
||||||
|
|
||||||
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
|
||||||
|
|
||||||
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
|
||||||
|
|
||||||
let softmax_image = softmax(&logits_per_image, 1)?;
|
|
||||||
|
|
||||||
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
|
||||||
|
|
||||||
info!("softmax_image_vec: {:?}", softmax_image_vec);
|
|
||||||
|
|
||||||
let probability_vec = softmax_image_vec
|
|
||||||
.iter()
|
|
||||||
.map(|v| v * 100.0)
|
|
||||||
.collect::<Vec<f32>>();
|
|
||||||
|
|
||||||
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
|
||||||
|
|
||||||
for (i, img) in vec_imgs.iter().enumerate() {
|
|
||||||
let start = i * probability_per_image;
|
|
||||||
let end = start + probability_per_image;
|
|
||||||
let prob = &probability_vec[start..end];
|
|
||||||
info!("\n\nResults for image: {}\n", img);
|
|
||||||
|
|
||||||
for (i, p) in prob.iter().enumerate() {
|
|
||||||
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
|
|
||||||
let tokenizer = match tokenizer {
|
|
||||||
None => {
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.repo(hf_hub::Repo::with_revision(
|
|
||||||
"openai/clip-vit-base-patch32".to_string(),
|
|
||||||
hf_hub::RepoType::Model,
|
|
||||||
"refs/pr/15".to_string(),
|
|
||||||
));
|
|
||||||
api.get("tokenizer.json")?
|
|
||||||
}
|
|
||||||
Some(file) => file.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Tokenizer::from_file(tokenizer).map_err(E::msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tokenize_sequences(
|
|
||||||
sequences: Option<Vec<String>>,
|
|
||||||
tokenizer: &Tokenizer,
|
|
||||||
device: &Device,
|
|
||||||
) -> anyhow::Result<(Tensor, Vec<String>)> {
|
|
||||||
let pad_id = *tokenizer
|
|
||||||
.get_vocab(true)
|
|
||||||
.get("<|endoftext|>")
|
|
||||||
.ok_or(E::msg("No pad token"))?;
|
|
||||||
|
|
||||||
let vec_seq = match sequences {
|
|
||||||
Some(seq) => seq,
|
|
||||||
None => vec![
|
|
||||||
"a cycling race".to_string(),
|
|
||||||
"a photo of two cats".to_string(),
|
|
||||||
"a robot holding a candle".to_string(),
|
|
||||||
],
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut tokens = vec![];
|
|
||||||
|
|
||||||
for seq in vec_seq.clone() {
|
|
||||||
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
|
||||||
tokens.push(encoding.get_ids().to_vec());
|
|
||||||
}
|
|
||||||
|
|
||||||
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
|
|
||||||
|
|
||||||
// Pad the sequences to have the same length
|
|
||||||
for token_vec in tokens.iter_mut() {
|
|
||||||
let len_diff = max_len - token_vec.len();
|
|
||||||
if len_diff > 0 {
|
|
||||||
token_vec.extend(vec![pad_id; len_diff]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let input_ids = Tensor::new(tokens, device)?;
|
|
||||||
|
|
||||||
Ok((input_ids, vec_seq))
|
|
||||||
}
|
|
@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
# candle-convnext
|
|
||||||
|
|
||||||
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
|
|
||||||
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
|
|
||||||
|
|
||||||
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
|
||||||
classification head has been trained on the ImageNet dataset and returns the
|
|
||||||
probabilities for the top-5 classes.
|
|
||||||
|
|
||||||
## Running an example
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
|
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
|
||||||
model built
|
|
||||||
mountain bike, all-terrain bike, off-roader: 84.09%
|
|
||||||
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
|
|
||||||
maillot : 0.74%
|
|
||||||
crash helmet : 0.54%
|
|
||||||
unicycle, monocycle : 0.44%
|
|
||||||
|
|
||||||
```
|
|
@ -1,126 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, D};
|
|
||||||
use candle_nn::{Module, VarBuilder};
|
|
||||||
use candle_transformers::models::convnext;
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
Atto,
|
|
||||||
Femto,
|
|
||||||
Pico,
|
|
||||||
Nano,
|
|
||||||
Tiny,
|
|
||||||
Small,
|
|
||||||
Base,
|
|
||||||
Large,
|
|
||||||
AttoV2,
|
|
||||||
FemtoV2,
|
|
||||||
PicoV2,
|
|
||||||
NanoV2,
|
|
||||||
TinyV2,
|
|
||||||
BaseV2,
|
|
||||||
LargeV2,
|
|
||||||
XLarge,
|
|
||||||
Huge,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_filename(&self) -> String {
|
|
||||||
let name = match self {
|
|
||||||
Self::Atto => "convnext_atto.d2_in1k",
|
|
||||||
Self::Femto => "convnext_femto.d1_in1k",
|
|
||||||
Self::Pico => "convnext_pico.d1_in1k",
|
|
||||||
Self::Nano => "convnext_nano.d1h_in1k",
|
|
||||||
Self::Tiny => "convnext_tiny.fb_in1k",
|
|
||||||
Self::Small => "convnext_small.fb_in1k",
|
|
||||||
Self::Base => "convnext_base.fb_in1k",
|
|
||||||
Self::Large => "convnext_large.fb_in1k",
|
|
||||||
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
|
|
||||||
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
|
|
||||||
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
|
|
||||||
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
|
|
||||||
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
|
|
||||||
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
|
|
||||||
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
|
|
||||||
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
|
|
||||||
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
|
|
||||||
};
|
|
||||||
|
|
||||||
format!("timm/{name}")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> convnext::Config {
|
|
||||||
match self {
|
|
||||||
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
|
|
||||||
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
|
|
||||||
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
|
|
||||||
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
|
|
||||||
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
|
|
||||||
Self::Small => convnext::Config::small(),
|
|
||||||
Self::Base | Self::BaseV2 => convnext::Config::base(),
|
|
||||||
Self::Large | Self::LargeV2 => convnext::Config::large(),
|
|
||||||
Self::XLarge => convnext::Config::xlarge(),
|
|
||||||
Self::Huge => convnext::Config::huge(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
image: String,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(value_enum, long, default_value_t=Which::Tiny)]
|
|
||||||
which: Which,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
|
||||||
println!("loaded image {image:?}");
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let model_name = args.which.model_filename();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.model(model_name);
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
|
|
||||||
println!("model built");
|
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
|
||||||
.i(0)?
|
|
||||||
.to_vec1::<f32>()?;
|
|
||||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
|
||||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
|
||||||
for &(category_idx, pr) in prs.iter().take(5) {
|
|
||||||
println!(
|
|
||||||
"{:24}: {:.2}%",
|
|
||||||
candle_examples::imagenet::CLASSES[category_idx],
|
|
||||||
100. * pr
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1 +0,0 @@
|
|||||||
/// Contents of `$OUT_DIR/layernorm_kernels.ptx`, embedded as a string at compile time.
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
|
|
||||||
|
@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model_file = match args.model {
|
let model_file = match args.model {
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
# candle-efficientvit
|
|
||||||
|
|
||||||
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).
|
|
||||||
|
|
||||||
This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
|
|
||||||
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
|
|
||||||
|
|
||||||
## Running an example
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
|
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
|
||||||
model built
|
|
||||||
mountain bike, all-terrain bike, off-roader: 69.80%
|
|
||||||
unicycle, monocycle : 13.03%
|
|
||||||
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
|
|
||||||
crash helmet : 2.25%
|
|
||||||
alp : 0.46%
|
|
||||||
```
|
|
@ -1,99 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, D};
|
|
||||||
use candle_nn::{Module, VarBuilder};
|
|
||||||
use candle_transformers::models::efficientvit;
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
M0,
|
|
||||||
M1,
|
|
||||||
M2,
|
|
||||||
M3,
|
|
||||||
M4,
|
|
||||||
M5,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_filename(&self) -> String {
|
|
||||||
let name = match self {
|
|
||||||
Self::M0 => "m0",
|
|
||||||
Self::M1 => "m1",
|
|
||||||
Self::M2 => "m2",
|
|
||||||
Self::M3 => "m3",
|
|
||||||
Self::M4 => "m4",
|
|
||||||
Self::M5 => "m5",
|
|
||||||
};
|
|
||||||
format!("timm/efficientvit_{}.r224_in1k", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> efficientvit::Config {
|
|
||||||
match self {
|
|
||||||
Self::M0 => efficientvit::Config::m0(),
|
|
||||||
Self::M1 => efficientvit::Config::m1(),
|
|
||||||
Self::M2 => efficientvit::Config::m2(),
|
|
||||||
Self::M3 => efficientvit::Config::m3(),
|
|
||||||
Self::M4 => efficientvit::Config::m4(),
|
|
||||||
Self::M5 => efficientvit::Config::m5(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
image: String,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(value_enum, long, default_value_t=Which::M0)]
|
|
||||||
which: Which,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
|
||||||
println!("loaded image {image:?}");
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let model_name = args.which.model_filename();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.model(model_name);
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
|
|
||||||
println!("model built");
|
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
|
||||||
.i(0)?
|
|
||||||
.to_vec1::<f32>()?;
|
|
||||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
|
||||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
|
||||||
for &(category_idx, pr) in prs.iter().take(5) {
|
|
||||||
println!(
|
|
||||||
"{:24}: {:.2}%",
|
|
||||||
candle_examples::imagenet::CLASSES[category_idx],
|
|
||||||
100. * pr
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,25 +0,0 @@
|
|||||||
# candle-encodec
|
|
||||||
|
|
||||||
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
|
|
||||||
compression model using an encoder/decoder architecture with residual vector
|
|
||||||
quantization.
|
|
||||||
|
|
||||||
## Running one example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example encodec --features symphonia --release -- code-to-audio \
|
|
||||||
candle-examples/examples/encodec/jfk-codes.safetensors \
|
|
||||||
jfk.wav
|
|
||||||
```
|
|
||||||
|
|
||||||
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
|
|
||||||
an output wav file containing the audio data.
|
|
||||||
|
|
||||||
Instead of `code-to-audio` one can use:
|
|
||||||
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
|
|
||||||
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
|
|
||||||
containing EnCodec tokens for the input audio file.
|
|
||||||
|
|
||||||
If the audio output file name is set to `-`, the audio content directly gets
|
|
||||||
played on default audio output device. If the audio input file is set to `-`, the audio
|
|
||||||
gets recorded from the default audio input.
|
|
@ -1,275 +0,0 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
pub const SAMPLE_RATE: usize = 24_000;
|
|
||||||
|
|
||||||
pub(crate) struct AudioOutputData_ {
|
|
||||||
resampled_data: std::collections::VecDeque<f32>,
|
|
||||||
resampler: rubato::FastFixedIn<f32>,
|
|
||||||
output_buffer: Vec<f32>,
|
|
||||||
input_buffer: Vec<f32>,
|
|
||||||
input_len: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AudioOutputData_ {
|
|
||||||
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
|
||||||
use rubato::Resampler;
|
|
||||||
|
|
||||||
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
|
||||||
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
|
||||||
let resampler = rubato::FastFixedIn::new(
|
|
||||||
resample_ratio,
|
|
||||||
f64::max(resample_ratio, 1.0),
|
|
||||||
rubato::PolynomialDegree::Septic,
|
|
||||||
1024,
|
|
||||||
1,
|
|
||||||
)?;
|
|
||||||
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
|
||||||
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
|
||||||
Ok(Self {
|
|
||||||
resampled_data,
|
|
||||||
resampler,
|
|
||||||
input_buffer,
|
|
||||||
output_buffer,
|
|
||||||
input_len: 0,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reset(&mut self) {
|
|
||||||
use rubato::Resampler;
|
|
||||||
self.output_buffer.fill(0.);
|
|
||||||
self.input_buffer.fill(0.);
|
|
||||||
self.resampler.reset();
|
|
||||||
self.resampled_data.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
|
||||||
let mut data = Vec::with_capacity(self.resampled_data.len());
|
|
||||||
while let Some(elem) = self.resampled_data.pop_back() {
|
|
||||||
data.push(elem);
|
|
||||||
}
|
|
||||||
data
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn is_empty(&self) -> bool {
|
|
||||||
self.resampled_data.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assumes that the input buffer is large enough.
|
|
||||||
fn push_input_buffer(&mut self, samples: &[f32]) {
|
|
||||||
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
|
||||||
self.input_len += samples.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
|
||||||
use rubato::Resampler;
|
|
||||||
|
|
||||||
let mut pos_in = 0;
|
|
||||||
loop {
|
|
||||||
let rem = self.input_buffer.len() - self.input_len;
|
|
||||||
let pos_end = usize::min(pos_in + rem, samples.len());
|
|
||||||
self.push_input_buffer(&samples[pos_in..pos_end]);
|
|
||||||
pos_in = pos_end;
|
|
||||||
if self.input_len < self.input_buffer.len() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let (_, out_len) = self.resampler.process_into_buffer(
|
|
||||||
&[&self.input_buffer],
|
|
||||||
&mut [&mut self.output_buffer],
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
for &elem in self.output_buffer[..out_len].iter() {
|
|
||||||
self.resampled_data.push_front(elem)
|
|
||||||
}
|
|
||||||
self.input_len = 0;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Shared, mutex-protected handle to the output resampling state.
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
|
||||||
|
|
||||||
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
||||||
|
|
||||||
println!("Setup audio output stream!");
|
|
||||||
let host = cpal::default_host();
|
|
||||||
let device = host
|
|
||||||
.default_output_device()
|
|
||||||
.context("no output device available")?;
|
|
||||||
let mut supported_configs_range = device.supported_output_configs()?;
|
|
||||||
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
|
||||||
// On macOS, it's commonly the case that there are only stereo outputs.
|
|
||||||
None => device
|
|
||||||
.supported_output_configs()?
|
|
||||||
.next()
|
|
||||||
.context("no audio output available")?,
|
|
||||||
Some(config_range) => config_range,
|
|
||||||
};
|
|
||||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
|
||||||
config_range.min_sample_rate(),
|
|
||||||
config_range.max_sample_rate(),
|
|
||||||
);
|
|
||||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
|
||||||
let channels = config.channels as usize;
|
|
||||||
println!(
|
|
||||||
"cpal device: {} {} {config:?}",
|
|
||||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
|
||||||
config.sample_rate.0
|
|
||||||
);
|
|
||||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
|
||||||
SAMPLE_RATE,
|
|
||||||
config.sample_rate.0 as usize,
|
|
||||||
)?));
|
|
||||||
let ad = audio_data.clone();
|
|
||||||
let stream = device.build_output_stream(
|
|
||||||
&config,
|
|
||||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
|
||||||
data.fill(0.);
|
|
||||||
let mut ad = ad.lock().unwrap();
|
|
||||||
let mut last_elem = 0f32;
|
|
||||||
for (idx, elem) in data.iter_mut().enumerate() {
|
|
||||||
if idx % channels == 0 {
|
|
||||||
match ad.resampled_data.pop_back() {
|
|
||||||
None => break,
|
|
||||||
Some(v) => {
|
|
||||||
last_elem = v;
|
|
||||||
*elem = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
*elem = last_elem
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
move |err| eprintln!("cpal error: {err}"),
|
|
||||||
None, // None=blocking, Some(Duration)=timeout
|
|
||||||
)?;
|
|
||||||
stream.play()?;
|
|
||||||
Ok((stream, audio_data))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
||||||
|
|
||||||
println!("Setup audio input stream!");
|
|
||||||
let host = cpal::default_host();
|
|
||||||
let device = host
|
|
||||||
.default_input_device()
|
|
||||||
.context("no input device available")?;
|
|
||||||
let mut supported_configs_range = device.supported_input_configs()?;
|
|
||||||
let config_range = supported_configs_range
|
|
||||||
.find(|c| c.channels() == 1)
|
|
||||||
.context("no audio input available")?;
|
|
||||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
|
||||||
config_range.min_sample_rate(),
|
|
||||||
config_range.max_sample_rate(),
|
|
||||||
);
|
|
||||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
|
||||||
println!(
|
|
||||||
"cpal device: {} {} {config:?}",
|
|
||||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
|
||||||
config.sample_rate.0
|
|
||||||
);
|
|
||||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
|
||||||
config.sample_rate.0 as usize,
|
|
||||||
SAMPLE_RATE,
|
|
||||||
)?));
|
|
||||||
let ad = audio_data.clone();
|
|
||||||
let stream = device.build_input_stream(
|
|
||||||
&config,
|
|
||||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
||||||
let mut ad = ad.lock().unwrap();
|
|
||||||
if let Err(err) = ad.push_samples(data) {
|
|
||||||
eprintln!("error processing audio input {err:?}")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
move |err| eprintln!("cpal error: {err}"),
|
|
||||||
None, // None=blocking, Some(Duration)=timeout
|
|
||||||
)?;
|
|
||||||
stream.play()?;
|
|
||||||
Ok((stream, audio_data))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
|
||||||
where
|
|
||||||
T: symphonia::core::sample::Sample,
|
|
||||||
f32: symphonia::core::conv::FromSample<T>,
|
|
||||||
{
|
|
||||||
use symphonia::core::audio::Signal;
|
|
||||||
use symphonia::core::conv::FromSample;
|
|
||||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
|
||||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
|
||||||
|
|
||||||
let src = std::fs::File::open(path)?;
|
|
||||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
|
||||||
let hint = symphonia::core::probe::Hint::new();
|
|
||||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
|
||||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
|
||||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
|
||||||
let mut format = probed.format;
|
|
||||||
let track = format
|
|
||||||
.tracks()
|
|
||||||
.iter()
|
|
||||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
|
||||||
.expect("no supported audio tracks");
|
|
||||||
let mut decoder = symphonia::default::get_codecs()
|
|
||||||
.make(&track.codec_params, &Default::default())
|
|
||||||
.expect("unsupported codec");
|
|
||||||
let track_id = track.id;
|
|
||||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
|
||||||
let mut pcm_data = Vec::new();
|
|
||||||
while let Ok(packet) = format.next_packet() {
|
|
||||||
while !format.metadata().is_latest() {
|
|
||||||
format.metadata().pop();
|
|
||||||
}
|
|
||||||
if packet.track_id() != track_id {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
match decoder.decode(&packet)? {
|
|
||||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
|
||||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok((pcm_data, sample_rate))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
|
|
||||||
use rubato::Resampler;
|
|
||||||
|
|
||||||
let mut pcm_out =
|
|
||||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
|
||||||
|
|
||||||
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
|
|
||||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
|
||||||
let mut pos_in = 0;
|
|
||||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
|
||||||
let (in_len, out_len) =
|
|
||||||
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
|
||||||
pos_in += in_len;
|
|
||||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if pos_in < pcm_in.len() {
|
|
||||||
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
|
||||||
Some(&[&pcm_in[pos_in..]]),
|
|
||||||
&mut output_buffer,
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(pcm_out)
|
|
||||||
}
|
|
Binary file not shown.
@ -1,131 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use candle::{DType, IndexOp, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::models::encodec::{Config, Model};
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
use hf_hub::api::sync::Api;
|
|
||||||
|
|
||||||
mod audio_io;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
|
||||||
enum Action {
|
|
||||||
AudioToAudio,
|
|
||||||
AudioToCode,
|
|
||||||
CodeToAudio,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// The action to be performed, specifies the format for the input and output data.
|
|
||||||
action: Action,
|
|
||||||
|
|
||||||
/// The input file, either an audio file or some encodec tokens stored as safetensors.
|
|
||||||
in_file: String,
|
|
||||||
|
|
||||||
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
|
|
||||||
out_file: String,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// The model weight file, in safetensor format.
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let model = match args.model {
|
|
||||||
Some(model) => std::path::PathBuf::from(model),
|
|
||||||
None => Api::new()?
|
|
||||||
.model("facebook/encodec_24khz".to_string())
|
|
||||||
.get("model.safetensors")?,
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
|
||||||
let config = Config::default();
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
|
|
||||||
let codes = match args.action {
|
|
||||||
Action::CodeToAudio => {
|
|
||||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
|
||||||
codes.get("codes").expect("no codes in input file").clone()
|
|
||||||
}
|
|
||||||
Action::AudioToCode | Action::AudioToAudio => {
|
|
||||||
let pcm = if args.in_file == "-" {
|
|
||||||
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
|
||||||
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
|
||||||
let mut pcms = vec![];
|
|
||||||
let stdin = std::thread::spawn(|| {
|
|
||||||
let mut s = String::new();
|
|
||||||
std::io::stdin().read_line(&mut s)
|
|
||||||
});
|
|
||||||
while !stdin.is_finished() {
|
|
||||||
let input = input_audio.lock().unwrap().take_all();
|
|
||||||
if input.is_empty() {
|
|
||||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
pcms.push(input)
|
|
||||||
}
|
|
||||||
drop(stream);
|
|
||||||
pcms.concat()
|
|
||||||
} else {
|
|
||||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
|
||||||
if sample_rate != 24_000 {
|
|
||||||
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}, resampling...");
|
|
||||||
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
|
|
||||||
} else {
|
|
||||||
pcm
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let pcm_len = pcm.len();
|
|
||||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
|
||||||
println!("input pcm shape: {:?}", pcm.shape());
|
|
||||||
model.encode(&pcm)?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
println!("codes shape: {:?}", codes.shape());
|
|
||||||
|
|
||||||
match args.action {
|
|
||||||
Action::AudioToCode => {
|
|
||||||
codes.save_safetensors("codes", &args.out_file)?;
|
|
||||||
}
|
|
||||||
Action::AudioToAudio | Action::CodeToAudio => {
|
|
||||||
let pcm = model.decode(&codes)?;
|
|
||||||
println!("output pcm shape: {:?}", pcm.shape());
|
|
||||||
let pcm = pcm.i(0)?.i(0)?;
|
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
|
||||||
if args.out_file == "-" {
|
|
||||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
|
||||||
{
|
|
||||||
let mut ad = ad.lock().unwrap();
|
|
||||||
ad.push_samples(&pcm)?;
|
|
||||||
}
|
|
||||||
loop {
|
|
||||||
let ad = ad.lock().unwrap();
|
|
||||||
if ad.is_empty() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// That's very weird, calling thread::sleep here triggers the stream to stop
|
|
||||||
// playing (the callback doesn't seem to be called anymore).
|
|
||||||
// std::thread::sleep(std::time::Duration::from_millis(100));
|
|
||||||
}
|
|
||||||
drop(stream)
|
|
||||||
} else {
|
|
||||||
let mut output = std::fs::File::create(&args.out_file)?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,27 +0,0 @@
|
|||||||
# candle-gemma: 2b and 7b LLMs from Google DeepMind
|
|
||||||
|
|
||||||
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
|
|
||||||
models published by Google DeepMind, available in 2b and 7b variants.
|
|
||||||
|
|
||||||
In order to use the example below, you have to accept the license on the
|
|
||||||
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
|
|
||||||
your access token via the [HuggingFace cli login
|
|
||||||
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
|
|
||||||
|
|
||||||
## Running the example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
|
|
||||||
fn count_primes(max_n: usize) -> usize {
|
|
||||||
let mut primes = vec![true; max_n];
|
|
||||||
for i in 2..=max_n {
|
|
||||||
if primes[i] {
|
|
||||||
for j in i * i..max_n {
|
|
||||||
primes[j] = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
primes.len()
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
@ -1,289 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle_transformers::models::gemma::{Config, Model};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
#[value(name = "2b")]
|
|
||||||
Base2B,
|
|
||||||
#[value(name = "7b")]
|
|
||||||
Base7B,
|
|
||||||
#[value(name = "2b-it")]
|
|
||||||
Instruct2B,
|
|
||||||
#[value(name = "7b-it")]
|
|
||||||
Instruct7B,
|
|
||||||
#[value(name = "1.1-2b-it")]
|
|
||||||
InstructV1_1_2B,
|
|
||||||
#[value(name = "1.1-7b-it")]
|
|
||||||
InstructV1_1_7B,
|
|
||||||
#[value(name = "code-2b")]
|
|
||||||
CodeBase2B,
|
|
||||||
#[value(name = "code-7b")]
|
|
||||||
CodeBase7B,
|
|
||||||
#[value(name = "code-2b-it")]
|
|
||||||
CodeInstruct2B,
|
|
||||||
#[value(name = "code-7b-it")]
|
|
||||||
CodeInstruct7B,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the <eos> token"),
|
|
||||||
};
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, start_pos)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
|
||||||
revision: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
|
|
||||||
/// The model to use.
|
|
||||||
#[arg(long, default_value = "2b")]
|
|
||||||
which: Which,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match &args.model_id {
|
|
||||||
Some(model_id) => model_id.to_string(),
|
|
||||||
None => match args.which {
|
|
||||||
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
|
||||||
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
|
||||||
Which::Base2B => "google/gemma-2b".to_string(),
|
|
||||||
Which::Base7B => "google/gemma-7b".to_string(),
|
|
||||||
Which::Instruct2B => "google/gemma-2b-it".to_string(),
|
|
||||||
Which::Instruct7B => "google/gemma-7b-it".to_string(),
|
|
||||||
Which::CodeBase2B => "google/codegemma-2b".to_string(),
|
|
||||||
Which::CodeBase7B => "google/codegemma-7b".to_string(),
|
|
||||||
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
|
||||||
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
model_id,
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision,
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let config_filename = match args.config_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("config.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let dtype = if device.is_cuda() {
|
|
||||||
DType::BF16
|
|
||||||
} else {
|
|
||||||
DType::F32
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -31,8 +31,6 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
|||||||
enum Which {
|
enum Which {
|
||||||
V1,
|
V1,
|
||||||
V2,
|
V2,
|
||||||
V3,
|
|
||||||
V3Instruct,
|
|
||||||
#[value(name = "solar-10.7b")]
|
#[value(name = "solar-10.7b")]
|
||||||
Solar10_7B,
|
Solar10_7B,
|
||||||
#[value(name = "tiny-llama-1.1b-chat")]
|
#[value(name = "tiny-llama-1.1b-chat")]
|
||||||
@ -47,8 +45,8 @@ struct Args {
|
|||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long, default_value_t = 0.8)]
|
#[arg(long)]
|
||||||
temperature: f64,
|
temperature: Option<f64>,
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
/// Nucleus sampling probability cutoff.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -59,7 +57,7 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 10000)]
|
#[arg(long, default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// Disable the key-value cache.
|
/// Disable the key-value cache.
|
||||||
@ -92,11 +90,11 @@ struct Args {
|
|||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
#[arg(long, default_value_t = 1.1)]
|
#[arg(long, default_value_t = 1.0)]
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 128)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,18 +118,13 @@ fn main() -> Result<()> {
|
|||||||
Some("bf16") => DType::BF16,
|
Some("bf16") => DType::BF16,
|
||||||
Some("f32") => DType::F32,
|
Some("f32") => DType::F32,
|
||||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||||
None => match args.which {
|
None => DType::F16,
|
||||||
Which::V3 | Which::V3Instruct => DType::BF16,
|
|
||||||
Which::V1 | Which::V2 | Which::Solar10_7B | Which::TinyLlama1_1BChat => DType::F16,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
let (llama, tokenizer_filename, mut cache, config) = {
|
let (llama, tokenizer_filename, cache) = {
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||||
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
|
|
||||||
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
|
|
||||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||||
});
|
});
|
||||||
@ -145,31 +138,29 @@ fn main() -> Result<()> {
|
|||||||
let config = config.into_config(args.use_flash_attn);
|
let config = config.into_config(args.use_flash_attn);
|
||||||
|
|
||||||
let filenames = match args.which {
|
let filenames = match args.which {
|
||||||
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
|
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
||||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||||
}
|
}
|
||||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||||
};
|
};
|
||||||
|
println!("building the model");
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
|
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||||
};
|
};
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let eos_token_id = config
|
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||||
.eos_token_id
|
|
||||||
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
|
|
||||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(prompt, true)
|
.encode(prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
|
||||||
|
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
print!("{prompt}");
|
print!("{prompt}");
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
|
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
let mut token_generated = 0;
|
let mut token_generated = 0;
|
||||||
@ -181,7 +172,7 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let logits = llama.forward(&input, context_index, &mut cache)?;
|
let logits = llama.forward(&input, context_index)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let logits = if args.repeat_penalty == 1. {
|
let logits = if args.repeat_penalty == 1. {
|
||||||
logits
|
logits
|
||||||
@ -199,16 +190,18 @@ fn main() -> Result<()> {
|
|||||||
token_generated += 1;
|
token_generated += 1;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
|
|
||||||
|
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||||
|
// heuristics as it seems to work well enough for this example. See the following for more
|
||||||
|
// details:
|
||||||
|
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||||
|
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||||
|
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
|
print!("{text}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
if Some(next_token) == eos_token_id {
|
if Some(next_token) == eos_token_id {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
}
|
||||||
let dt = start_gen.elapsed();
|
let dt = start_gen.elapsed();
|
||||||
println!(
|
println!(
|
||||||
|
@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
use model::{Cache, Config, Llama};
|
use model::{Config, Llama};
|
||||||
use qmodel::QLlama;
|
use qmodel::QLlama;
|
||||||
use weights::TransformerWeights;
|
use weights::TransformerWeights;
|
||||||
|
|
||||||
@ -160,10 +160,10 @@ enum Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
|
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
|
Self::Llama(l) => Ok(l.forward(xs, pos)?),
|
||||||
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
|
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
|||||||
let config = Config::from_reader(&mut file)?;
|
let config = Config::from_reader(&mut file)?;
|
||||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||||
let vb = weights.var_builder(&config, &device)?;
|
let vb = weights.var_builder(&config, &device)?;
|
||||||
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
|
||||||
let model = Llama::load(vb, config)?;
|
let model = Llama::load(vb, &cache, config)?;
|
||||||
|
|
||||||
let tokens = match &args.pretokenized_dir {
|
let tokens = match &args.pretokenized_dir {
|
||||||
None => {
|
None => {
|
||||||
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
|||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
for inp_tgt in batch_iter {
|
for inp_tgt in batch_iter {
|
||||||
let (inp, tgt) = inp_tgt?;
|
let (inp, tgt) = inp_tgt?;
|
||||||
let logits = model.forward(&inp, 0, &mut cache)?;
|
let logits = model.forward(&inp, 0)?;
|
||||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||||
println!("{}", loss.to_vec0::<f32>()?);
|
println!("{}", loss.to_vec0::<f32>()?);
|
||||||
}
|
}
|
||||||
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let is_safetensors = config_path
|
let is_safetensors = config_path
|
||||||
.extension()
|
.extension()
|
||||||
.map_or(false, |v| v == "safetensors");
|
.map_or(false, |v| v == "safetensors");
|
||||||
let (model, config, mut cache) = if is_gguf {
|
let (model, config) = if is_gguf {
|
||||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||||
let (_vocab_size, dim) = vb
|
let (_vocab_size, dim) = vb
|
||||||
.get_no_shape("model.embed_tokens.weight")?
|
.get_no_shape("model.embed_tokens.weight")?
|
||||||
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||||
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
|
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||||
(model, config, cache)
|
(model, config)
|
||||||
} else if is_safetensors {
|
} else if is_safetensors {
|
||||||
let config = Config::tiny_15m();
|
let config = Config::tiny_15m();
|
||||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||||
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||||
(model, config, cache)
|
(model, config)
|
||||||
} else {
|
} else {
|
||||||
let mut file = std::fs::File::open(config_path)?;
|
let mut file = std::fs::File::open(config_path)?;
|
||||||
let config = Config::from_reader(&mut file)?;
|
let config = Config::from_reader(&mut file)?;
|
||||||
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||||
let vb = weights.var_builder(&config, &device)?;
|
let vb = weights.var_builder(&config, &device)?;
|
||||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||||
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||||
(model, config, cache)
|
(model, config)
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0.. {
|
for index in 0.. {
|
||||||
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, index_pos, &mut cache)?;
|
let logits = model.forward(&input, index_pos)?;
|
||||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||||
logits
|
logits
|
||||||
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
|
|
||||||
let next_token = logits_processor.sample(&logits)?;
|
let next_token = logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||||
print!("{t}");
|
// heuristics as it seems to work well enough for this example. See the following for more
|
||||||
|
// details:
|
||||||
|
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
|
||||||
|
if let Some(text) = tokenizer.id_to_token(next_token) {
|
||||||
|
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
|
print!("{text}");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
let dt = start_gen.elapsed();
|
||||||
println!(
|
println!(
|
||||||
"\n{} tokens generated ({:.2} token/s)\n",
|
"\n{} tokens generated ({:.2} token/s)\n",
|
||||||
|
@ -8,7 +8,6 @@ fn valid_loss(
|
|||||||
model: &Llama,
|
model: &Llama,
|
||||||
args: &crate::TrainingCmd,
|
args: &crate::TrainingCmd,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
cache: &mut Cache,
|
|
||||||
) -> Result<f64> {
|
) -> Result<f64> {
|
||||||
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
@ -16,7 +15,7 @@ fn valid_loss(
|
|||||||
let mut cnt = 0usize;
|
let mut cnt = 0usize;
|
||||||
for inp_tgt in batch_iter.take(50) {
|
for inp_tgt in batch_iter.take(50) {
|
||||||
let (inp, tgt) = inp_tgt?;
|
let (inp, tgt) = inp_tgt?;
|
||||||
let logits = model.forward(&inp, 0, cache)?;
|
let logits = model.forward(&inp, 0)?;
|
||||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||||
sum_ce += loss.to_vec0::<f32>()? as f64;
|
sum_ce += loss.to_vec0::<f32>()? as f64;
|
||||||
cnt += 1;
|
cnt += 1;
|
||||||
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
|||||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||||
|
|
||||||
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||||
let model = Llama::load(vb, config)?;
|
let model = Llama::load(vb, &cache, config)?;
|
||||||
let params = candle_nn::ParamsAdamW {
|
let params = candle_nn::ParamsAdamW {
|
||||||
lr: args.learning_rate,
|
lr: args.learning_rate,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
|||||||
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
||||||
for (batch_index, batch) in batch_iter.enumerate() {
|
for (batch_index, batch) in batch_iter.enumerate() {
|
||||||
let (inp, tgt) = batch?;
|
let (inp, tgt) = batch?;
|
||||||
let logits = model.forward(&inp, 0, &mut cache)?;
|
let logits = model.forward(&inp, 0)?;
|
||||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||||
opt.backward_step(&loss)?;
|
opt.backward_step(&loss)?;
|
||||||
|
|
||||||
if batch_index > 0 && batch_index % 100 == 0 {
|
if batch_index > 0 && batch_index % 100 == 0 {
|
||||||
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
||||||
// validation loss.
|
// validation loss.
|
||||||
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
|
let loss = valid_loss(&dataset, &model, args, &device)?;
|
||||||
println!("{batch_index} {loss}");
|
println!("{batch_index} {loss}");
|
||||||
}
|
}
|
||||||
if batch_index > 0 && batch_index % 1000 == 0 {
|
if batch_index > 0 && batch_index % 1000 == 0 {
|
||||||
|
@ -2,9 +2,6 @@
|
|||||||
|
|
||||||
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
|
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
|
||||||
|
|
||||||
Compared to the mamba example, this version can handle training but is much
|
|
||||||
slower.
|
|
||||||
|
|
||||||
## Running the example
|
## Running the example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
# candle-mamba: Mamba implementation
|
|
||||||
|
|
||||||
Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
|
|
||||||
the transformer architecture. It leverages State Space Models (SSMs) with the
|
|
||||||
goal of being computationally efficient on long sequences. The implementation is
|
|
||||||
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
|
|
||||||
|
|
||||||
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
|
|
||||||
|
|
||||||
Compared to the mamba-minimal example, this version is far more efficient but
|
|
||||||
would only work for inference.
|
|
||||||
## Running the example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
|
|
||||||
```
|
|
||||||
|
|
@ -1,305 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle_transformers::models::mamba::{Config, Model, State};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
config: Config,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
config: Config,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
config,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let dtype = self.model.dtype();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the </s> token"),
|
|
||||||
};
|
|
||||||
let mut state = State::new(1, &self.config, dtype, &self.device)?;
|
|
||||||
let mut next_logits = None;
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
let input = Tensor::new(&[t], &self.device)?;
|
|
||||||
let logits = self.model.forward(&input, &mut state)?;
|
|
||||||
next_logits = Some(logits);
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for _ in 0..sample_len {
|
|
||||||
let logits = match next_logits.as_ref() {
|
|
||||||
Some(logits) => logits,
|
|
||||||
None => anyhow::bail!("cannot work on an empty prompt"),
|
|
||||||
};
|
|
||||||
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let input = Tensor::new(&[next_token], &self.device)?;
|
|
||||||
next_logits = Some(self.model.forward(&input, &mut state)?)
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
|
||||||
enum Which {
|
|
||||||
Mamba130m,
|
|
||||||
Mamba370m,
|
|
||||||
Mamba790m,
|
|
||||||
Mamba1_4b,
|
|
||||||
Mamba2_8b,
|
|
||||||
Mamba2_8bSlimPj,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for Which {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{:?}", self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_id(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
|
||||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
|
||||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
|
||||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
|
||||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
|
||||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn revision(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Mamba130m
|
|
||||||
| Self::Mamba370m
|
|
||||||
| Self::Mamba790m
|
|
||||||
| Self::Mamba1_4b
|
|
||||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
|
||||||
Self::Mamba2_8b => "refs/pr/4",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "mamba130m")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "f32")]
|
|
||||||
dtype: String,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use std::str::FromStr;
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
args.model_id
|
|
||||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision
|
|
||||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => api
|
|
||||||
.model("EleutherAI/gpt-neox-20b".to_string())
|
|
||||||
.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let config_filename = match args.config_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("config.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => {
|
|
||||||
vec![repo.get("model.safetensors")?]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let dtype = DType::from_str(&args.dtype)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
|
||||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
config,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,18 +0,0 @@
|
|||||||
# candle-metavoice
|
|
||||||
|
|
||||||
MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
|
|
||||||
details on the [model
|
|
||||||
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).
|
|
||||||
|
|
||||||
Note that the current candle implementation suffers from some limitations as of
|
|
||||||
2024-03-02:
|
|
||||||
- The speaker embeddings are hardcoded.
|
|
||||||
- The generated audio file quality is weaker than the Python implementation,
|
|
||||||
probably because of some implementation discrepancies.
|
|
||||||
|
|
||||||
## Run an example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example metavoice --release -- \\
|
|
||||||
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
|
|
||||||
```
|
|
@ -1,277 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use clap::Parser;
|
|
||||||
use std::io::Write;
|
|
||||||
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use candle_transformers::models::encodec;
|
|
||||||
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
|
|
||||||
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use hf_hub::api::sync::Api;
|
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
|
||||||
|
|
||||||
pub const ENCODEC_NTOKENS: u32 = 1024;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum ArgDType {
|
|
||||||
F32,
|
|
||||||
F16,
|
|
||||||
Bf16,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum Transformer {
|
|
||||||
Normal(transformer::Model),
|
|
||||||
Quantized(qtransformer::Model),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// Use the quantized version of the model.
|
|
||||||
#[arg(long)]
|
|
||||||
quantized: bool,
|
|
||||||
|
|
||||||
/// The guidance scale.
|
|
||||||
#[arg(long, default_value_t = 3.0)]
|
|
||||||
guidance_scale: f64,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long, default_value_t = 1.0)]
|
|
||||||
temperature: f64,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The maximum number of tokens to generate for the first stage.
|
|
||||||
#[arg(long, default_value_t = 2000)]
|
|
||||||
max_tokens: u64,
|
|
||||||
|
|
||||||
/// The output file using the wav format.
|
|
||||||
#[arg(long, default_value = "out.wav")]
|
|
||||||
out_file: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
first_stage_meta: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
first_stage_weights: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
second_stage_weights: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
encodec_weights: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
spk_emb: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "f32")]
|
|
||||||
dtype: ArgDType,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = api.model("lmz/candle-metavoice".to_string());
|
|
||||||
let first_stage_meta = match &args.first_stage_meta {
|
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
|
||||||
None => repo.get("first_stage.meta.json")?,
|
|
||||||
};
|
|
||||||
let first_stage_meta: serde_json::Value =
|
|
||||||
serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
|
|
||||||
let first_stage_tokenizer = match first_stage_meta.as_object() {
|
|
||||||
None => anyhow::bail!("not a json object"),
|
|
||||||
Some(j) => match j.get("tokenizer") {
|
|
||||||
None => anyhow::bail!("no tokenizer key"),
|
|
||||||
Some(j) => j,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
|
|
||||||
|
|
||||||
let second_stage_weights = match &args.second_stage_weights {
|
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
|
||||||
None => repo.get("second_stage.safetensors")?,
|
|
||||||
};
|
|
||||||
let encodec_weights = match args.encodec_weights {
|
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
|
||||||
None => Api::new()?
|
|
||||||
.model("facebook/encodec_24khz".to_string())
|
|
||||||
.get("model.safetensors")?,
|
|
||||||
};
|
|
||||||
let dtype = match args.dtype {
|
|
||||||
ArgDType::F32 => DType::F32,
|
|
||||||
ArgDType::F16 => DType::F16,
|
|
||||||
ArgDType::Bf16 => DType::BF16,
|
|
||||||
};
|
|
||||||
|
|
||||||
let first_stage_config = transformer::Config::cfg1b_v0_1();
|
|
||||||
let mut first_stage_model = if args.quantized {
|
|
||||||
let filename = match &args.first_stage_weights {
|
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
|
||||||
None => repo.get("first_stage_q4k.gguf")?,
|
|
||||||
};
|
|
||||||
let vb =
|
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
|
||||||
let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
|
|
||||||
Transformer::Quantized(first_stage_model)
|
|
||||||
} else {
|
|
||||||
let first_stage_weights = match &args.first_stage_weights {
|
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
|
||||||
None => repo.get("first_stage.safetensors")?,
|
|
||||||
};
|
|
||||||
let first_stage_vb =
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
|
|
||||||
let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
|
|
||||||
Transformer::Normal(first_stage_model)
|
|
||||||
};
|
|
||||||
|
|
||||||
let second_stage_vb =
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
|
|
||||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
|
||||||
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
|
|
||||||
|
|
||||||
let encodec_device = if device.is_metal() {
|
|
||||||
&candle::Device::Cpu
|
|
||||||
} else {
|
|
||||||
&device
|
|
||||||
};
|
|
||||||
let encodec_vb =
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
|
|
||||||
let encodec_config = encodec::Config::default();
|
|
||||||
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
|
||||||
|
|
||||||
println!("prompt: '{}'", args.prompt);
|
|
||||||
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
|
|
||||||
let mut tokens = prompt_tokens.clone();
|
|
||||||
println!("{tokens:?}");
|
|
||||||
let spk_emb_file = match &args.spk_emb {
|
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
|
||||||
None => repo.get("spk_emb.safetensors")?,
|
|
||||||
};
|
|
||||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
|
||||||
let spk_emb = match spk_emb.get("spk_emb") {
|
|
||||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
|
||||||
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
|
|
||||||
};
|
|
||||||
let spk_emb = spk_emb.to_device(&device)?;
|
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
|
||||||
|
|
||||||
// First stage generation.
|
|
||||||
for index in 0..args.max_tokens {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &device)?;
|
|
||||||
let input = Tensor::stack(&[&input, &input], 0)?;
|
|
||||||
let logits = match &mut first_stage_model {
|
|
||||||
Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
|
|
||||||
Transformer::Quantized(m) => {
|
|
||||||
m.forward(&input, &spk_emb, tokens.len() - context_size)?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let logits0 = logits.i((0, 0))?;
|
|
||||||
let logits1 = logits.i((1, 0))?;
|
|
||||||
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
|
|
||||||
let logits = logits.to_dtype(DType::F32)?;
|
|
||||||
let next_token = logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
print!(".");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
if next_token == 2048 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
println!();
|
|
||||||
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
|
|
||||||
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
|
|
||||||
println!("text ids len: {}", text_ids.len());
|
|
||||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
|
|
||||||
// TODO: Use the config rather than hardcoding the offset here.
|
|
||||||
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
|
|
||||||
let mut hierarchies_in1 =
|
|
||||||
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
|
||||||
let mut hierarchies_in2 = [
|
|
||||||
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
|
|
||||||
ids2.as_slice(),
|
|
||||||
&[ENCODEC_NTOKENS],
|
|
||||||
]
|
|
||||||
.concat();
|
|
||||||
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
|
||||||
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
|
||||||
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
|
|
||||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
|
||||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
|
||||||
let logits = second_stage_model.forward(&in_x)?;
|
|
||||||
println!("sampling from logits...");
|
|
||||||
let mut codes = vec![];
|
|
||||||
for logits in logits.iter() {
|
|
||||||
let logits = logits.squeeze(0)?;
|
|
||||||
let (seq_len, _) = logits.dims2()?;
|
|
||||||
let mut codes_ = Vec::with_capacity(seq_len);
|
|
||||||
for step in 0..seq_len {
|
|
||||||
let logits = logits.i(step)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = &(&logits / 1.0)?;
|
|
||||||
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
|
|
||||||
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
|
|
||||||
let sample = distr.sample(&mut rng) as u32;
|
|
||||||
codes_.push(sample)
|
|
||||||
}
|
|
||||||
codes.push(codes_)
|
|
||||||
}
|
|
||||||
|
|
||||||
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
|
|
||||||
let codes = Tensor::cat(&[in_x, codes], 1)?;
|
|
||||||
println!("codes: {codes}");
|
|
||||||
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
|
|
||||||
let codes = codes.i(0)?.to_vec2::<u32>()?;
|
|
||||||
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
|
|
||||||
println!("text_ids len: {:?}", text_ids.len());
|
|
||||||
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
|
|
||||||
println!("audio_ids shape: {:?}", audio_ids.shape());
|
|
||||||
let pcm = encodec_model.decode(&audio_ids)?;
|
|
||||||
println!("output pcm shape: {:?}", pcm.shape());
|
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
|
||||||
let mut output = std::fs::File::create(&args.out_file)?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
|
|||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
@ -39,26 +39,11 @@ impl TextGeneration {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
temp: Option<f64>,
|
temp: Option<f64>,
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
top_k: Option<usize>,
|
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let logits_processor = {
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
let temperature = temp.unwrap_or(0.);
|
|
||||||
let sampling = if temperature <= 0. {
|
|
||||||
Sampling::ArgMax
|
|
||||||
} else {
|
|
||||||
match (top_k, top_p) {
|
|
||||||
(None, None) => Sampling::All { temperature },
|
|
||||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
|
||||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
|
||||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
|
||||||
}
|
|
||||||
};
|
|
||||||
LogitsProcessor::from_sampling(seed, sampling)
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
@ -137,18 +122,6 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
#[value(name = "7b-v0.1")]
|
|
||||||
Mistral7bV01,
|
|
||||||
#[value(name = "7b-v0.2")]
|
|
||||||
Mistral7bV02,
|
|
||||||
#[value(name = "7b-instruct-v0.1")]
|
|
||||||
Mistral7bInstructV01,
|
|
||||||
#[value(name = "7b-instruct-v0.2")]
|
|
||||||
Mistral7bInstructV02,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -174,22 +147,14 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
|
|
||||||
/// Only sample among the top K samples.
|
|
||||||
#[arg(long)]
|
|
||||||
top_k: Option<usize>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
#[arg(long, short = 'n', default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// The model size to use.
|
|
||||||
#[arg(long, default_value = "7b-v0.1")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
@ -199,9 +164,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer_file: Option<String>,
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
weight_files: Option<String>,
|
weight_files: Option<String>,
|
||||||
|
|
||||||
@ -215,10 +177,6 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
/// Use the slower dmmv cuda kernel.
|
|
||||||
#[arg(long)]
|
|
||||||
force_dmmv: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -226,9 +184,6 @@ fn main() -> Result<()> {
|
|||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
|
||||||
|
|
||||||
let _guard = if args.tracing {
|
let _guard = if args.tracing {
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
@ -256,17 +211,9 @@ fn main() -> Result<()> {
|
|||||||
Some(model_id) => model_id,
|
Some(model_id) => model_id,
|
||||||
None => {
|
None => {
|
||||||
if args.quantized {
|
if args.quantized {
|
||||||
if args.which != Which::Mistral7bV01 {
|
|
||||||
anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
|
|
||||||
}
|
|
||||||
"lmz/candle-mistral".to_string()
|
"lmz/candle-mistral".to_string()
|
||||||
} else {
|
} else {
|
||||||
match args.which {
|
"mistralai/Mistral-7B-v0.1".to_string()
|
||||||
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
|
|
||||||
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
|
|
||||||
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
|
|
||||||
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -296,17 +243,7 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = match args.config_file {
|
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
||||||
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
|
||||||
None => {
|
|
||||||
if args.quantized {
|
|
||||||
Config::config_7b_v0_1(args.use_flash_attn)
|
|
||||||
} else {
|
|
||||||
let config_file = repo.get("config.json")?;
|
|
||||||
serde_json::from_slice(&std::fs::read(config_file)?)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (model, device) = if args.quantized {
|
let (model, device) = if args.quantized {
|
||||||
let filename = &filenames[0];
|
let filename = &filenames[0];
|
||||||
@ -333,7 +270,6 @@ fn main() -> Result<()> {
|
|||||||
args.seed,
|
args.seed,
|
||||||
args.temperature,
|
args.temperature,
|
||||||
args.top_p,
|
args.top_p,
|
||||||
args.top_k,
|
|
||||||
args.repeat_penalty,
|
args.repeat_penalty,
|
||||||
args.repeat_last_n,
|
args.repeat_last_n,
|
||||||
&device,
|
&device,
|
||||||
|
@ -143,7 +143,7 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
#[arg(long, short = 'n', default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
|
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
# candle-mobileone
|
|
||||||
|
|
||||||
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
|
|
||||||
|
|
||||||
This candle implementation uses a pre-trained MobileOne network for inference. The
|
|
||||||
classification head has been trained on the ImageNet dataset and returns the
|
|
||||||
probabilities for the top-5 classes.
|
|
||||||
|
|
||||||
## Running an example
|
|
||||||
|
|
||||||
```
|
|
||||||
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
|
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
|
||||||
model built
|
|
||||||
mountain bike, all-terrain bike, off-roader: 79.33%
|
|
||||||
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
|
|
||||||
crash helmet : 2.58%
|
|
||||||
unicycle, monocycle : 1.70%
|
|
||||||
alp : 0.21%
|
|
||||||
|
|
||||||
```
|
|
@ -1,96 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, D};
|
|
||||||
use candle_nn::{Module, VarBuilder};
|
|
||||||
use candle_transformers::models::mobileone;
|
|
||||||
|
|
||||||
// MobileOne variant selector, parsed from the `--which` command-line flag.
// Each variant maps to a hub repository name (`model_filename`) and to a
// network configuration (`config`) in the `impl Which` block.
// NOTE: plain `//` comments are used on purpose; clap's derive macros turn
// `///` doc comments into user-visible help text.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    S0,
    S1,
    S2,
    S3,
    S4,
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_filename(&self) -> String {
|
|
||||||
let name = match self {
|
|
||||||
Self::S0 => "s0",
|
|
||||||
Self::S1 => "s1",
|
|
||||||
Self::S2 => "s2",
|
|
||||||
Self::S3 => "s3",
|
|
||||||
Self::S4 => "s4",
|
|
||||||
};
|
|
||||||
format!("timm/mobileone_{}.apple_in1k", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> mobileone::Config {
|
|
||||||
match self {
|
|
||||||
Self::S0 => mobileone::Config::s0(),
|
|
||||||
Self::S1 => mobileone::Config::s1(),
|
|
||||||
Self::S2 => mobileone::Config::s2(),
|
|
||||||
Self::S3 => mobileone::Config::s3(),
|
|
||||||
Self::S4 => mobileone::Config::s4(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Command-line arguments of the MobileOne classification example.
// Plain `//` comments are used for the additions below so that the
// clap-generated `--help` output stays unchanged.
#[derive(Parser)]
struct Args {
    // Optional path to a local weights file; when absent, `main` fetches
    // `model.safetensors` from the hub repository selected by `which`.
    #[arg(long)]
    model: Option<String>,

    // Path of the image to classify.
    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    // MobileOne variant to run (defaults to S0).
    #[arg(value_enum, long, default_value_t=Which::S0)]
    which: Which,
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
|
||||||
println!("loaded image {image:?}");
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let model_name = args.which.model_filename();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.model(model_name);
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
|
|
||||||
println!("model built");
|
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
|
||||||
.i(0)?
|
|
||||||
.to_vec1::<f32>()?;
|
|
||||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
|
||||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
|
||||||
for &(category_idx, pr) in prs.iter().take(5) {
|
|
||||||
println!(
|
|
||||||
"{:24}: {:.2}%",
|
|
||||||
candle_examples::imagenet::CLASSES[category_idx],
|
|
||||||
100. * pr
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,26 +0,0 @@
|
|||||||
# candle-moondream
|
|
||||||
|
|
||||||
[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model that can answer real-world questions about images. It's tiny by today's standards, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.
|
|
||||||
|
|
||||||
## Running some examples
|
|
||||||
First download an example image
|
|
||||||
```bash
|
|
||||||
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
|
|
||||||
```
|
|
||||||
|
|
||||||
<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
|
|
||||||
|
|
||||||
Now you can run Moondream from the `candle-examples` crate:
|
|
||||||
```bash
|
|
||||||
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
|
|
||||||
|
|
||||||
avx: false, neon: true, simd128: false, f16c: false
|
|
||||||
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
|
|
||||||
retrieved the files in 3.395583ms
|
|
||||||
Running on CPU, to run on GPU(metal), build this example with `--features metal`
|
|
||||||
loaded the model in 5.485493792s
|
|
||||||
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
|
|
||||||
starting the inference loop
|
|
||||||
The girl is eating a hamburger.<
|
|
||||||
9 tokens generated (0.68 token/s)
|
|
||||||
```
|
|
@ -1,343 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::{
|
|
||||||
generation::LogitsProcessor,
|
|
||||||
models::{moondream, quantized_moondream},
|
|
||||||
};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
/// The two Moondream back-ends supported by this example.
enum Model {
    /// Regular model, built in `main` from safetensors weights.
    Moondream(moondream::Model),
    /// Quantized model, built in `main` from a gguf file.
    Quantized(quantized_moondream::Model),
}
|
|
||||||
|
|
||||||
/// State needed to run the image-conditioned text generation loop.
struct TextGeneration {
    /// Model used for the forward passes (regular or quantized).
    model: Model,
    /// Device on which the input token tensors are created.
    device: Device,
    /// Tokenizer used to encode the prompt and decode generated tokens.
    tokenizer: Tokenizer,
    /// Sampler turning logits into the next token id.
    logits_processor: LogitsProcessor,
    /// Repeat penalty; the value `1.` disables the penalty step in `run`.
    repeat_penalty: f32,
    /// Number of trailing tokens considered when applying the repeat penalty.
    repeat_last_n: usize,
    /// When set, `run` prints every prompt token with its id.
    verbose_prompt: bool,
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
verbose_prompt: bool,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
verbose_prompt,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
println!("starting the inference loop");
|
|
||||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
|
||||||
if tokens.is_empty() {
|
|
||||||
anyhow::bail!("Empty prompts are not supported in the Moondream model.")
|
|
||||||
}
|
|
||||||
if self.verbose_prompt {
|
|
||||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
|
||||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
|
||||||
println!("{id:7} -> '{token}'");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut tokens = tokens.get_ids().to_vec();
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
|
|
||||||
// Moondream tokenizer bos_token and eos_token is "<|endoftext|>"
|
|
||||||
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
|
|
||||||
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
|
||||||
Some(token) => *token,
|
|
||||||
None => anyhow::bail!("cannot find the special token"),
|
|
||||||
};
|
|
||||||
let (bos_token, eos_token) = (special_token, special_token);
|
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
let mut load_t = std::time::Duration::from_secs_f64(0f64);
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = if index > 0 {
|
|
||||||
match self.model {
|
|
||||||
Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
|
|
||||||
Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = match self.model {
|
|
||||||
Model::Moondream(ref mut model) => {
|
|
||||||
model
|
|
||||||
.text_model
|
|
||||||
.forward_with_img(&bos_token, &input, image_embeds)?
|
|
||||||
}
|
|
||||||
Model::Quantized(ref mut model) => {
|
|
||||||
model
|
|
||||||
.text_model
|
|
||||||
.forward_with_img(&bos_token, &input, image_embeds)?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
load_t = start_gen.elapsed();
|
|
||||||
println!("load_t: {:?}", load_t);
|
|
||||||
logits
|
|
||||||
};
|
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
|
||||||
print!("{token}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let dt = start_gen.elapsed() - load_t;
|
|
||||||
println!(
|
|
||||||
"\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
dt.as_secs_f64(),
|
|
||||||
(generated_tokens - 1) as f64 / dt.as_secs_f64()
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Command-line arguments of the Moondream example.
// Added notes use plain `//` comments so that the clap-generated `--help`
// output (built from `///` doc comments) stays unchanged.
#[derive(Parser)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    // Question about the image; `main` wraps it in a "Question/Answer" template.
    #[arg(long)]
    prompt: String,

    // Path of the image, loaded through `load_image`.
    #[arg(long)]
    image: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 0)]
    seed: u64,

    // Upper bound on the number of generated tokens.
    #[arg(long, default_value_t = 5000)]
    sample_len: usize,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    // Hub repository id; when absent, `main` picks a default that depends on `quantized`.
    #[arg(long)]
    model_id: Option<String>,

    // Hub repository revision to use.
    #[arg(long, default_value = "main")]
    revision: String,

    // Load the quantized (gguf) model instead of the safetensors one.
    #[arg(long)]
    quantized: bool,

    /// Use f16 precision for all the computations rather than f32.
    #[arg(long)]
    f16: bool,

    // Local weights file, used instead of downloading from the hub.
    #[arg(long)]
    model_file: Option<String>,

    // Local tokenizer file, used instead of downloading `tokenizer.json`.
    #[arg(long)]
    tokenizer_file: Option<String>,
}
|
|
||||||
|
|
||||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
|
||||||
/// (3, 378, 378).
|
|
||||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
|
|
||||||
let img = image::io::Reader::open(p)?
|
|
||||||
.decode()
|
|
||||||
.map_err(candle::Error::wrap)?
|
|
||||||
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
|
|
||||||
let img = img.to_rgb8();
|
|
||||||
let data = img.into_raw();
|
|
||||||
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
|
||||||
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
|
||||||
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
|
||||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
|
||||||
.broadcast_sub(&mean)?
|
|
||||||
.broadcast_div(&std)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Entry point of the Moondream example: fetches the model and tokenizer
/// files, builds the (optionally quantized) model, encodes the input image and
/// runs the text generation loop on the user question.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    // The guard is bound to a variable so that it lives until `main` returns.
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = hf_hub::api::tokio::Api::new()?;
    // Default repository depends on whether the quantized weights are requested.
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => {
            if args.quantized {
                "santiagomed/candle-moondream".to_string()
            } else {
                "vikhyatk/moondream2".to_string()
            }
        }
    };
    let repo = api.repo(hf_hub::Repo::with_revision(
        model_id,
        hf_hub::RepoType::Model,
        args.revision,
    ));
    // Local files given on the command line take precedence over hub downloads.
    let model_file = match args.model_file {
        Some(m) => m.into(),
        None => {
            if args.quantized {
                repo.get("model-q4_0.gguf").await?
            } else {
                repo.get("model.safetensors").await?
            }
        }
    };
    let tokenizer = match args.tokenizer_file {
        Some(m) => m.into(),
        None => repo.get("tokenizer.json").await?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let config = moondream::Config::v2();
    // Quantized runs use F32 and reject `--f16`; otherwise F16 is used on CUDA
    // or when `--f16` is set, and F32 in the remaining cases.
    let dtype = if args.quantized {
        if args.f16 {
            anyhow::bail!("Quantized model does not support f16");
        }
        DType::F32
    } else if device.is_cuda() || args.f16 {
        DType::F16
    } else {
        DType::F32
    };
    let model = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            &model_file,
            &device,
        )?;
        let model = quantized_moondream::Model::new(&config, vb)?;
        Model::Quantized(model)
    } else {
        // NOTE(review): the weights are memory-mapped here; presumably the file
        // must not be modified while the model is in use.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
        let model = moondream::Model::new(&config, vb)?;
        Model::Moondream(model)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let start = std::time::Instant::now();
    let image = load_image(args.image)?
        .to_device(&device)?
        .to_dtype(dtype)?;
    // Add a leading batch dimension, then run the vision encoder on the image.
    let image_embeds = image.unsqueeze(0)?;
    let image_embeds = match model {
        Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
        Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
    };
    println!(
        "loaded and encoded the image {image:?} in {:?}",
        start.elapsed()
    );

    // Wrap the user question in the question/answer prompt template.
    let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    // NOTE(review): `run` is a synchronous call made from the async `main`.
    pipeline.run(&prompt, &image_embeds, args.sample_len)?;

    Ok(())
}
|
|
580
candle-examples/examples/musicgen/encodec_model.rs
Normal file
580
candle-examples/examples/musicgen/encodec_model.rs
Normal file
@ -0,0 +1,580 @@
|
|||||||
|
use crate::nn::conv1d_weight_norm;
|
||||||
|
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||||
|
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||||
|
|
||||||
|
// Encodec Model
|
||||||
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||||
|
|
||||||
|
/// Normalization variants selectable through `Config::norm_type`.
// NOTE(review): the names presumably mirror the `norm_type` option of the
// reference Encodec implementation linked above; confirm against it.
#[derive(Debug, Clone, PartialEq)]
enum NormType {
    WeightNorm,
    TimeGroupNorm,
    None,
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct Config {
|
||||||
|
target_bandwidths: Vec<f64>,
|
||||||
|
sampling_rate: usize,
|
||||||
|
audio_channels: usize,
|
||||||
|
normalize: bool,
|
||||||
|
chunk_length_s: Option<usize>,
|
||||||
|
overlap: Option<usize>,
|
||||||
|
hidden_size: usize,
|
||||||
|
num_filters: usize,
|
||||||
|
num_residual_layers: usize,
|
||||||
|
upsampling_ratios: Vec<usize>,
|
||||||
|
norm_type: NormType,
|
||||||
|
kernel_size: usize,
|
||||||
|
last_kernel_size: usize,
|
||||||
|
residual_kernel_size: usize,
|
||||||
|
dilation_growth_rate: usize,
|
||||||
|
use_causal_conv: bool,
|
||||||
|
pad_mode: &'static str,
|
||||||
|
compress: usize,
|
||||||
|
num_lstm_layers: usize,
|
||||||
|
trim_right_ratio: f64,
|
||||||
|
codebook_size: usize,
|
||||||
|
codebook_dim: Option<usize>,
|
||||||
|
use_conv_shortcut: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
|
||||||
|
sampling_rate: 24_000,
|
||||||
|
audio_channels: 1,
|
||||||
|
normalize: false,
|
||||||
|
chunk_length_s: None,
|
||||||
|
overlap: None,
|
||||||
|
hidden_size: 128,
|
||||||
|
num_filters: 32,
|
||||||
|
num_residual_layers: 1,
|
||||||
|
upsampling_ratios: vec![8, 5, 4, 2],
|
||||||
|
norm_type: NormType::WeightNorm,
|
||||||
|
kernel_size: 7,
|
||||||
|
last_kernel_size: 7,
|
||||||
|
residual_kernel_size: 3,
|
||||||
|
dilation_growth_rate: 2,
|
||||||
|
use_causal_conv: true,
|
||||||
|
pad_mode: "reflect",
|
||||||
|
compress: 2,
|
||||||
|
num_lstm_layers: 2,
|
||||||
|
trim_right_ratio: 1.0,
|
||||||
|
codebook_size: 1024,
|
||||||
|
codebook_dim: None,
|
||||||
|
use_conv_shortcut: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||||
|
pub fn musicgen_small() -> Self {
|
||||||
|
Self {
|
||||||
|
audio_channels: 1,
|
||||||
|
chunk_length_s: None,
|
||||||
|
codebook_dim: Some(128),
|
||||||
|
codebook_size: 2048,
|
||||||
|
compress: 2,
|
||||||
|
dilation_growth_rate: 2,
|
||||||
|
hidden_size: 128,
|
||||||
|
kernel_size: 7,
|
||||||
|
last_kernel_size: 7,
|
||||||
|
norm_type: NormType::WeightNorm,
|
||||||
|
normalize: false,
|
||||||
|
num_filters: 64,
|
||||||
|
num_lstm_layers: 2,
|
||||||
|
num_residual_layers: 1,
|
||||||
|
overlap: None,
|
||||||
|
pad_mode: "reflect",
|
||||||
|
residual_kernel_size: 3,
|
||||||
|
sampling_rate: 32_000,
|
||||||
|
target_bandwidths: vec![2.2],
|
||||||
|
trim_right_ratio: 1.0,
|
||||||
|
upsampling_ratios: vec![8, 5, 4, 4],
|
||||||
|
use_causal_conv: false,
|
||||||
|
use_conv_shortcut: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn codebook_dim(&self) -> usize {
|
||||||
|
self.codebook_dim.unwrap_or(self.codebook_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn frame_rate(&self) -> usize {
|
||||||
|
let hop_length: usize = self.upsampling_ratios.iter().product();
|
||||||
|
(self.sampling_rate + hop_length - 1) / hop_length
|
||||||
|
}
|
||||||
|
|
||||||
|
fn num_quantizers(&self) -> usize {
|
||||||
|
let num = 1000f64
|
||||||
|
* self
|
||||||
|
.target_bandwidths
|
||||||
|
.last()
|
||||||
|
.expect("empty target_bandwidths");
|
||||||
|
(num as usize) / (self.frame_rate() * 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecEuclideanCodebook {
|
||||||
|
inited: Tensor,
|
||||||
|
cluster_size: Tensor,
|
||||||
|
embed: Tensor,
|
||||||
|
embed_avg: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecEuclideanCodebook {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let inited = vb.get(1, "inited")?;
|
||||||
|
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||||
|
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||||
|
let embed = vb.get(e_shape, "embed")?;
|
||||||
|
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||||
|
Ok(Self {
|
||||||
|
inited,
|
||||||
|
cluster_size,
|
||||||
|
embed,
|
||||||
|
embed_avg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||||
|
let quantize = self.embed.embedding(embed_ind)?;
|
||||||
|
Ok(quantize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecVectorQuantization {
|
||||||
|
codebook: EncodecEuclideanCodebook,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecVectorQuantization {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
||||||
|
Ok(Self { codebook })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||||
|
let quantize = self.codebook.decode(embed_ind)?;
|
||||||
|
let quantize = quantize.transpose(1, 2)?;
|
||||||
|
Ok(quantize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecResidualVectorQuantizer {
|
||||||
|
layers: Vec<EncodecVectorQuantization>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecResidualVectorQuantizer {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let vb = &vb.pp("layers");
|
||||||
|
let layers = (0..cfg.num_quantizers())
|
||||||
|
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
Ok(Self { layers })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||||
|
if codes.dim(0)? != self.layers.len() {
|
||||||
|
candle::bail!(
|
||||||
|
"codes shape {:?} does not match the number of quantization layers {}",
|
||||||
|
codes.shape(),
|
||||||
|
self.layers.len()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
for (i, layer) in self.layers.iter().enumerate() {
|
||||||
|
let quantized = layer.decode(&codes.i(i)?)?;
|
||||||
|
quantized_out = quantized.broadcast_add(&quantized_out)?;
|
||||||
|
}
|
||||||
|
Ok(quantized_out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecLSTM {
|
||||||
|
layers: Vec<candle_nn::LSTM>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecLSTM {
|
||||||
|
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let vb = &vb.pp("lstm");
|
||||||
|
let mut layers = vec![];
|
||||||
|
for layer_idx in 0..cfg.num_lstm_layers {
|
||||||
|
let config = candle_nn::LSTMConfig {
|
||||||
|
layer_idx,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||||
|
layers.push(lstm)
|
||||||
|
}
|
||||||
|
Ok(Self { layers })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for EncodecLSTM {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
use candle_nn::RNN;
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
let states = layer.seq(&xs)?;
|
||||||
|
xs = layer.states_to_tensor(&states)?;
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecConvTranspose1d {
|
||||||
|
weight_g: Tensor,
|
||||||
|
weight_v: Tensor,
|
||||||
|
bias: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecConvTranspose1d {
|
||||||
|
fn load(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
k: usize,
|
||||||
|
_stride: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
_cfg: &Config,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vb = &vb.pp("conv");
|
||||||
|
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||||
|
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
||||||
|
let bias = vb.get(out_c, "bias")?;
|
||||||
|
Ok(Self {
|
||||||
|
weight_g,
|
||||||
|
weight_v,
|
||||||
|
bias,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for EncodecConvTranspose1d {
|
||||||
|
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecConv1d {
|
||||||
|
causal: bool,
|
||||||
|
conv: Conv1d,
|
||||||
|
norm: Option<candle_nn::GroupNorm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecConv1d {
|
||||||
|
fn load(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
stride: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
cfg: &Config,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let conv = match cfg.norm_type {
|
||||||
|
NormType::WeightNorm => conv1d_weight_norm(
|
||||||
|
in_c,
|
||||||
|
out_c,
|
||||||
|
kernel_size,
|
||||||
|
Conv1dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride,
|
||||||
|
groups: 1,
|
||||||
|
dilation: 1,
|
||||||
|
},
|
||||||
|
vb.pp("conv"),
|
||||||
|
)?,
|
||||||
|
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||||
|
in_c,
|
||||||
|
out_c,
|
||||||
|
kernel_size,
|
||||||
|
Conv1dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride,
|
||||||
|
groups: 1,
|
||||||
|
dilation: 1,
|
||||||
|
},
|
||||||
|
vb.pp("conv"),
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
let norm = match cfg.norm_type {
|
||||||
|
NormType::None | NormType::WeightNorm => None,
|
||||||
|
NormType::TimeGroupNorm => {
|
||||||
|
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||||
|
Some(gn)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
causal: cfg.use_causal_conv,
|
||||||
|
conv,
|
||||||
|
norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for EncodecConv1d {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
// TODO: padding, depending on causal.
|
||||||
|
let xs = self.conv.forward(xs)?;
|
||||||
|
match &self.norm {
|
||||||
|
None => Ok(xs),
|
||||||
|
Some(norm) => xs.apply(norm),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecResnetBlock {
|
||||||
|
block_conv1: EncodecConv1d,
|
||||||
|
block_conv2: EncodecConv1d,
|
||||||
|
shortcut: Option<EncodecConv1d>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecResnetBlock {
|
||||||
|
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let h = dim / cfg.compress;
|
||||||
|
let mut layer = Layer::new(vb.pp("block"));
|
||||||
|
if dilations.len() != 2 {
|
||||||
|
candle::bail!("expected dilations of size 2")
|
||||||
|
}
|
||||||
|
// TODO: Apply dilations!
|
||||||
|
layer.inc();
|
||||||
|
let block_conv1 =
|
||||||
|
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
||||||
|
layer.inc();
|
||||||
|
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
||||||
|
let shortcut = if cfg.use_conv_shortcut {
|
||||||
|
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||||
|
Some(conv)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
block_conv1,
|
||||||
|
block_conv2,
|
||||||
|
shortcut,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for EncodecResnetBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let residual = xs.clone();
|
||||||
|
let xs = xs.elu(1.)?;
|
||||||
|
let xs = self.block_conv1.forward(&xs)?;
|
||||||
|
let xs = xs.elu(1.)?;
|
||||||
|
let xs = self.block_conv2.forward(&xs)?;
|
||||||
|
let xs = match &self.shortcut {
|
||||||
|
None => (xs + residual)?,
|
||||||
|
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
|
||||||
|
};
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Layer<'a> {
|
||||||
|
vb: VarBuilder<'a>,
|
||||||
|
cnt: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Layer<'a> {
|
||||||
|
fn new(vb: VarBuilder<'a>) -> Self {
|
||||||
|
Self { vb, cnt: 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inc(&mut self) {
|
||||||
|
self.cnt += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next(&mut self) -> VarBuilder {
|
||||||
|
let vb = self.vb.pp(&self.cnt.to_string());
|
||||||
|
self.cnt += 1;
|
||||||
|
vb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecEncoder {
|
||||||
|
init_conv: EncodecConv1d,
|
||||||
|
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
|
||||||
|
final_lstm: EncodecLSTM,
|
||||||
|
final_conv: EncodecConv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecEncoder {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let mut layer = Layer::new(vb.pp("layers"));
|
||||||
|
let init_conv = EncodecConv1d::load(
|
||||||
|
cfg.audio_channels,
|
||||||
|
cfg.num_filters,
|
||||||
|
cfg.kernel_size,
|
||||||
|
1,
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
let mut sampling_layers = vec![];
|
||||||
|
let mut scaling = 1;
|
||||||
|
for &ratio in cfg.upsampling_ratios.iter().rev() {
|
||||||
|
let current_scale = scaling * cfg.num_filters;
|
||||||
|
let mut resnets = vec![];
|
||||||
|
for j in 0..(cfg.num_residual_layers as u32) {
|
||||||
|
let resnet = EncodecResnetBlock::load(
|
||||||
|
current_scale,
|
||||||
|
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
resnets.push(resnet)
|
||||||
|
}
|
||||||
|
layer.inc(); // ELU
|
||||||
|
let conv1d = EncodecConv1d::load(
|
||||||
|
current_scale,
|
||||||
|
current_scale * 2,
|
||||||
|
ratio * 2,
|
||||||
|
ratio,
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
sampling_layers.push((resnets, conv1d));
|
||||||
|
scaling *= 2;
|
||||||
|
}
|
||||||
|
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||||
|
layer.inc(); // ELU
|
||||||
|
let final_conv = EncodecConv1d::load(
|
||||||
|
cfg.num_filters * scaling,
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.last_kernel_size,
|
||||||
|
1,
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
init_conv,
|
||||||
|
sampling_layers,
|
||||||
|
final_conv,
|
||||||
|
final_lstm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.apply(&self.init_conv)?;
|
||||||
|
for (resnets, conv) in self.sampling_layers.iter() {
|
||||||
|
for resnet in resnets.iter() {
|
||||||
|
xs = xs.apply(resnet)?;
|
||||||
|
}
|
||||||
|
xs = xs.elu(1.0)?.apply(conv)?;
|
||||||
|
}
|
||||||
|
xs.apply(&self.final_lstm)?
|
||||||
|
.elu(1.0)?
|
||||||
|
.apply(&self.final_conv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EncodecDecoder {
|
||||||
|
init_conv: EncodecConv1d,
|
||||||
|
init_lstm: EncodecLSTM,
|
||||||
|
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
|
||||||
|
final_conv: EncodecConv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecDecoder {
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let mut layer = Layer::new(vb.pp("layers"));
|
||||||
|
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||||
|
let init_conv = EncodecConv1d::load(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.num_filters * scaling,
|
||||||
|
cfg.last_kernel_size,
|
||||||
|
1,
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||||
|
let mut sampling_layers = vec![];
|
||||||
|
for &ratio in cfg.upsampling_ratios.iter() {
|
||||||
|
let current_scale = scaling * cfg.num_filters;
|
||||||
|
layer.inc(); // ELU
|
||||||
|
let conv1d = EncodecConvTranspose1d::load(
|
||||||
|
current_scale,
|
||||||
|
current_scale / 2,
|
||||||
|
ratio * 2,
|
||||||
|
ratio,
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
let mut resnets = vec![];
|
||||||
|
for j in 0..(cfg.num_residual_layers as u32) {
|
||||||
|
let resnet = EncodecResnetBlock::load(
|
||||||
|
current_scale / 2,
|
||||||
|
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
resnets.push(resnet)
|
||||||
|
}
|
||||||
|
sampling_layers.push((conv1d, resnets));
|
||||||
|
scaling /= 2;
|
||||||
|
}
|
||||||
|
layer.inc(); // ELU
|
||||||
|
let final_conv = EncodecConv1d::load(
|
||||||
|
cfg.num_filters,
|
||||||
|
cfg.audio_channels,
|
||||||
|
cfg.last_kernel_size,
|
||||||
|
1,
|
||||||
|
layer.next(),
|
||||||
|
cfg,
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
init_conv,
|
||||||
|
init_lstm,
|
||||||
|
sampling_layers,
|
||||||
|
final_conv,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||||
|
for (conv, resnets) in self.sampling_layers.iter() {
|
||||||
|
xs = xs.elu(1.)?.apply(conv)?;
|
||||||
|
for resnet in resnets.iter() {
|
||||||
|
xs = xs.apply(resnet)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xs.elu(1.)?.apply(&self.final_conv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EncodecModel {
|
||||||
|
encoder: EncodecEncoder,
|
||||||
|
decoder: EncodecDecoder,
|
||||||
|
quantizer: EncodecResidualVectorQuantizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodecModel {
|
||||||
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
|
||||||
|
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
|
||||||
|
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
|
||||||
|
Ok(Self {
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
quantizer,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,9 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
mod encodec_model;
|
||||||
mod musicgen_model;
|
mod musicgen_model;
|
||||||
|
mod nn;
|
||||||
|
|
||||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
|
use crate::encodec_model;
|
||||||
use candle::{DType, Device, Result, Tensor, D};
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
use candle_nn::{
|
use candle_nn::{
|
||||||
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||||
VarBuilder,
|
VarBuilder,
|
||||||
};
|
};
|
||||||
use candle_transformers::models::{encodec, t5};
|
use candle_transformers::models::t5;
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MusicgenForConditionalGeneration {
|
pub struct MusicgenForConditionalGeneration {
|
||||||
pub text_encoder: t5::T5EncoderModel,
|
pub text_encoder: t5::T5EncoderModel,
|
||||||
pub audio_encoder: encodec::Model,
|
pub audio_encoder: crate::encodec_model::EncodecModel,
|
||||||
pub decoder: MusicgenForCausalLM,
|
pub decoder: MusicgenForCausalLM,
|
||||||
cfg: GenConfig,
|
cfg: GenConfig,
|
||||||
}
|
}
|
||||||
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
|
|||||||
pub struct GenConfig {
|
pub struct GenConfig {
|
||||||
musicgen: Config,
|
musicgen: Config,
|
||||||
t5: t5::Config,
|
t5: t5::Config,
|
||||||
encodec: encodec::Config,
|
encodec: crate::encodec_model::Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenConfig {
|
impl GenConfig {
|
||||||
pub fn small() -> Self {
|
pub fn small() -> Self {
|
||||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
|
||||||
let encodec = encodec::Config {
|
|
||||||
audio_channels: 1,
|
|
||||||
chunk_length_s: None,
|
|
||||||
codebook_dim: Some(128),
|
|
||||||
codebook_size: 2048,
|
|
||||||
compress: 2,
|
|
||||||
dilation_growth_rate: 2,
|
|
||||||
hidden_size: 128,
|
|
||||||
kernel_size: 7,
|
|
||||||
last_kernel_size: 7,
|
|
||||||
norm_type: encodec::NormType::WeightNorm,
|
|
||||||
normalize: false,
|
|
||||||
num_filters: 64,
|
|
||||||
num_lstm_layers: 2,
|
|
||||||
num_residual_layers: 1,
|
|
||||||
overlap: None,
|
|
||||||
// This should be Reflect and not Replicate but Reflect does not work yet.
|
|
||||||
pad_mode: encodec::PadMode::Replicate,
|
|
||||||
residual_kernel_size: 3,
|
|
||||||
sampling_rate: 32_000,
|
|
||||||
target_bandwidths: vec![2.2],
|
|
||||||
trim_right_ratio: 1.0,
|
|
||||||
upsampling_ratios: vec![8, 5, 4, 4],
|
|
||||||
use_causal_conv: false,
|
|
||||||
use_conv_shortcut: false,
|
|
||||||
};
|
|
||||||
Self {
|
Self {
|
||||||
musicgen: Config::musicgen_small(),
|
musicgen: Config::musicgen_small(),
|
||||||
t5: t5::Config::musicgen_small(),
|
t5: t5::Config::musicgen_small(),
|
||||||
encodec,
|
encodec: encodec_model::Config::musicgen_small(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {
|
|||||||
|
|
||||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||||
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
|
let audio_encoder =
|
||||||
|
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
text_encoder,
|
text_encoder,
|
||||||
|
20
candle-examples/examples/musicgen/nn.rs
Normal file
20
candle-examples/examples/musicgen/nn.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use candle::Result;
|
||||||
|
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
|
||||||
|
|
||||||
|
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||||
|
// does not apply to training.
|
||||||
|
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
|
||||||
|
pub fn conv1d_weight_norm(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: Conv1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||||
|
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||||
|
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||||
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
|
let bias = vb.get(out_c, "bias")?;
|
||||||
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
|
}
|
@ -1,39 +1,10 @@
|
|||||||
## Using ONNX models in Candle
|
## Using ONNX models in Candle
|
||||||
|
|
||||||
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle.
|
This example demonstrates how to run ONNX based models in Candle, the model
|
||||||
|
being used here is a small SqueezeNet variant.
|
||||||
|
|
||||||
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf).
|
You can run the example with the following command:
|
||||||
|
|
||||||
You can run the examples with the following commands:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
```
|
|
||||||
|
|
||||||
Use the `--which` flag to specify explicitly which network to use, i.e.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
|
||||||
|
|
||||||
Finished release [optimized] target(s) in 0.21s
|
|
||||||
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
|
||||||
unicycle, monocycle : 83.23%
|
|
||||||
ballplayer, baseball player : 3.68%
|
|
||||||
bearskin, busby, shako : 1.54%
|
|
||||||
military uniform : 0.78%
|
|
||||||
cowboy hat, ten-gallon hat : 0.76%
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
|
||||||
|
|
||||||
Finished release [optimized] target(s) in 0.20s
|
|
||||||
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
|
|
||||||
loaded image Tensor[dims 224, 224, 3; f32]
|
|
||||||
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
|
|
||||||
mountain bike, all-terrain bike, off-roader : 0.60%
|
|
||||||
unicycle, monocycle : 0.17%
|
|
||||||
crash helmet : 0.02%
|
|
||||||
alp : 0.02%
|
|
||||||
```
|
```
|
||||||
|
@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
|
|||||||
`tensor-tools` command line utility via:
|
`tensor-tools` command line utility via:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using custom models
|
## Using custom models
|
||||||
|
@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
use candle::quantized::{ggml_file, gguf_file};
|
use candle::quantized::{ggml_file, gguf_file};
|
||||||
use candle::Tensor;
|
use candle::Tensor;
|
||||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_transformers::models::quantized_llama as model;
|
use candle_transformers::models::quantized_llama as model;
|
||||||
@ -67,8 +67,6 @@ enum Which {
|
|||||||
Mixtral,
|
Mixtral,
|
||||||
#[value(name = "mixtral-instruct")]
|
#[value(name = "mixtral-instruct")]
|
||||||
MixtralInstruct,
|
MixtralInstruct,
|
||||||
#[value(name = "llama3-8b")]
|
|
||||||
L8b,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -84,8 +82,7 @@ impl Which {
|
|||||||
| Self::L13bCode
|
| Self::L13bCode
|
||||||
| Self::L34bCode
|
| Self::L34bCode
|
||||||
| Self::Leo7b
|
| Self::Leo7b
|
||||||
| Self::Leo13b
|
| Self::Leo13b => false,
|
||||||
| Self::L8b => false,
|
|
||||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||||
// same way. Starling is a fine tuned version of OpenChat.
|
// same way. Starling is a fine tuned version of OpenChat.
|
||||||
Self::OpenChat35
|
Self::OpenChat35
|
||||||
@ -119,8 +116,7 @@ impl Which {
|
|||||||
| Self::Mistral7bInstruct
|
| Self::Mistral7bInstruct
|
||||||
| Self::Mistral7bInstructV02
|
| Self::Mistral7bInstructV02
|
||||||
| Self::OpenChat35
|
| Self::OpenChat35
|
||||||
| Self::Starling7bAlpha
|
| Self::Starling7bAlpha => false,
|
||||||
| Self::L8b => false,
|
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -144,8 +140,7 @@ impl Which {
|
|||||||
| Self::Mistral7bInstruct
|
| Self::Mistral7bInstruct
|
||||||
| Self::Mistral7bInstructV02
|
| Self::Mistral7bInstructV02
|
||||||
| Self::Zephyr7bAlpha
|
| Self::Zephyr7bAlpha
|
||||||
| Self::Zephyr7bBeta
|
| Self::Zephyr7bBeta => false,
|
||||||
| Self::L8b => false,
|
|
||||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -172,7 +167,6 @@ impl Which {
|
|||||||
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
||||||
Which::OpenChat35 => "openchat/openchat_3.5",
|
Which::OpenChat35 => "openchat/openchat_3.5",
|
||||||
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
||||||
Self::L8b => "meta-llama/Meta-Llama-3-8B",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -206,10 +200,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
|
|
||||||
/// Only sample among the top K samples.
|
|
||||||
#[arg(long)]
|
|
||||||
top_k: Option<usize>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
@ -222,14 +212,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
verbose_prompt: bool,
|
verbose_prompt: bool,
|
||||||
|
|
||||||
/// Process prompt elements separately.
|
|
||||||
#[arg(long)]
|
|
||||||
split_prompt: bool,
|
|
||||||
|
|
||||||
/// Run on CPU rather than GPU even if a GPU is available.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
#[arg(long, default_value_t = 1.1)]
|
#[arg(long, default_value_t = 1.1)]
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
@ -245,10 +227,6 @@ struct Args {
|
|||||||
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
|
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
gqa: Option<usize>,
|
gqa: Option<usize>,
|
||||||
|
|
||||||
/// Use the slower dmmv cuda kernel.
|
|
||||||
#[arg(long)]
|
|
||||||
force_dmmv: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -328,11 +306,6 @@ impl Args {
|
|||||||
"TheBloke/Starling-LM-7B-alpha-GGUF",
|
"TheBloke/Starling-LM-7B-alpha-GGUF",
|
||||||
"starling-lm-7b-alpha.Q4_K_M.gguf",
|
"starling-lm-7b-alpha.Q4_K_M.gguf",
|
||||||
),
|
),
|
||||||
// TODO: swap to TheBloke model when available
|
|
||||||
Which::L8b => (
|
|
||||||
"QuantFactory/Meta-Llama-3-8B-GGUF",
|
|
||||||
"Meta-Llama-3-8B.Q4_K_S.gguf",
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -360,10 +333,11 @@ fn main() -> anyhow::Result<()> {
|
|||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let temperature = if args.temperature == 0. {
|
||||||
#[cfg(feature = "cuda")]
|
None
|
||||||
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
} else {
|
||||||
|
Some(args.temperature)
|
||||||
|
};
|
||||||
let _guard = if args.tracing {
|
let _guard = if args.tracing {
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
@ -387,7 +361,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let model_path = args.model()?;
|
let model_path = args.model()?;
|
||||||
let mut file = std::fs::File::open(&model_path)?;
|
let mut file = std::fs::File::open(&model_path)?;
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(false)?;
|
||||||
|
|
||||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||||
Some("gguf") => {
|
Some("gguf") => {
|
||||||
@ -431,8 +405,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L13bCode
|
| Which::L13bCode
|
||||||
| Which::L34bCode
|
| Which::L34bCode
|
||||||
| Which::Leo7b
|
| Which::Leo7b
|
||||||
| Which::Leo13b
|
| Which::Leo13b => 1,
|
||||||
| Which::L8b => 1,
|
|
||||||
Which::Mixtral
|
Which::Mixtral
|
||||||
| Which::MixtralInstruct
|
| Which::MixtralInstruct
|
||||||
| Which::Mistral7b
|
| Which::Mistral7b
|
||||||
@ -511,36 +484,14 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt_tokens
|
prompt_tokens
|
||||||
};
|
};
|
||||||
let mut all_tokens = vec![];
|
let mut all_tokens = vec![];
|
||||||
let mut logits_processor = {
|
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
||||||
let temperature = args.temperature;
|
|
||||||
let sampling = if temperature <= 0. {
|
|
||||||
Sampling::ArgMax
|
|
||||||
} else {
|
|
||||||
match (args.top_k, args.top_p) {
|
|
||||||
(None, None) => Sampling::All { temperature },
|
|
||||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
|
||||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
|
||||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
|
||||||
}
|
|
||||||
};
|
|
||||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
|
||||||
};
|
|
||||||
|
|
||||||
let start_prompt_processing = std::time::Instant::now();
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
let mut next_token = if !args.split_prompt {
|
let mut next_token = {
|
||||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, 0)?;
|
let logits = model.forward(&input, 0)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
logits_processor.sample(&logits)?
|
logits_processor.sample(&logits)?
|
||||||
} else {
|
|
||||||
let mut next_token = 0;
|
|
||||||
for (pos, token) in prompt_tokens.iter().enumerate() {
|
|
||||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
|
||||||
let logits = model.forward(&input, pos)?;
|
|
||||||
let logits = logits.squeeze(0)?;
|
|
||||||
next_token = logits_processor.sample(&logits)?
|
|
||||||
}
|
|
||||||
next_token
|
|
||||||
};
|
};
|
||||||
let prompt_dt = start_prompt_processing.elapsed();
|
let prompt_dt = start_prompt_processing.elapsed();
|
||||||
all_tokens.push(next_token);
|
all_tokens.push(next_token);
|
||||||
@ -549,14 +500,11 @@ fn main() -> anyhow::Result<()> {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let eos_token = match args.which {
|
let eos_token = if args.which.is_open_chat() {
|
||||||
Which::L8b => "<|end_of_text|>",
|
"<|end_of_turn|>"
|
||||||
_ => match args.which.is_open_chat() {
|
} else {
|
||||||
true => "<|end_of_turn|>",
|
"</s>"
|
||||||
false => "</s>",
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
let mut sampled = 0;
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
# candle-qwen: large language model series from Alibaba Cloud
|
|
||||||
|
|
||||||
Qwen 1.5 is a series of large language models that provide strong performances
|
|
||||||
on English and Chinese.
|
|
||||||
|
|
||||||
- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
|
|
||||||
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
|
|
||||||
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
|
|
||||||
mixture-of-experts (MoE) variant.
|
|
||||||
|
|
||||||
## Running the example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example qwen --release -- --prompt "Hello there "
|
|
||||||
```
|
|
||||||
|
|
||||||
Various model sizes are available via the `--model` argument, including the MoE
|
|
||||||
variant.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
|
|
||||||
def print_prime(n: int): # n is the number of primes to be printed
|
|
||||||
for i in range(2, n + 1):
|
|
||||||
if all(i % j != 0 for j in range(2, i)):
|
|
||||||
print(i)
|
|
||||||
```
|
|
||||||
|
|
@ -1,311 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
|
|
||||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
enum Model {
|
|
||||||
Base(ModelBase),
|
|
||||||
Moe(ModelMoe),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Model {
|
|
||||||
fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
|
|
||||||
match self {
|
|
||||||
Self::Moe(ref mut m) => m.forward(xs, s),
|
|
||||||
Self::Base(ref mut m) => m.forward(xs, s),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
|
||||||
};
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, start_pos)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
|
|
||||||
enum WhichModel {
|
|
||||||
#[value(name = "0.5b")]
|
|
||||||
W0_5b,
|
|
||||||
#[value(name = "1.8b")]
|
|
||||||
W1_8b,
|
|
||||||
#[value(name = "4b")]
|
|
||||||
W4b,
|
|
||||||
#[value(name = "7b")]
|
|
||||||
W7b,
|
|
||||||
#[value(name = "14b")]
|
|
||||||
W14b,
|
|
||||||
#[value(name = "72b")]
|
|
||||||
W72b,
|
|
||||||
#[value(name = "moe-a2.7b")]
|
|
||||||
MoeA27b,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
use_flash_attn: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
|
||||||
revision: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "0.5b")]
|
|
||||||
model: WhichModel,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id,
|
|
||||||
None => {
|
|
||||||
let size = match args.model {
|
|
||||||
WhichModel::W0_5b => "0.5B",
|
|
||||||
WhichModel::W1_8b => "1.8B",
|
|
||||||
WhichModel::W4b => "4B",
|
|
||||||
WhichModel::W7b => "7B",
|
|
||||||
WhichModel::W14b => "14B",
|
|
||||||
WhichModel::W72b => "72B",
|
|
||||||
WhichModel::MoeA27b => "MoE-A2.7B",
|
|
||||||
};
|
|
||||||
format!("Qwen/Qwen1.5-{size}")
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
model_id,
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision,
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => match args.model {
|
|
||||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
|
||||||
WhichModel::W4b
|
|
||||||
| WhichModel::W7b
|
|
||||||
| WhichModel::W14b
|
|
||||||
| WhichModel::W72b
|
|
||||||
| WhichModel::MoeA27b => {
|
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config_file = repo.get("config.json")?;
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let dtype = if device.is_cuda() {
|
|
||||||
DType::BF16
|
|
||||||
} else {
|
|
||||||
DType::F32
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
|
||||||
let model = match args.model {
|
|
||||||
WhichModel::MoeA27b => {
|
|
||||||
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
|
||||||
Model::Moe(ModelMoe::new(&config, vb)?)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
|
||||||
Model::Base(ModelBase::new(&config, vb)?)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,9 +0,0 @@
|
|||||||
# candle-recurrent-gemma
|
|
||||||
|
|
||||||
This model card corresponds to the 2B base version of the RecurrentGemma model
|
|
||||||
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --features cuda -r --example recurrent-gemma -- \
|
|
||||||
--prompt "Write me a poem about Machine Learning."
|
|
||||||
```
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user