Compare commits

..

11 Commits

Author SHA1 Message Date
67d93b4f42 More happy tests. 2024-01-15 18:46:18 +01:00
c35d7d50db Making the CI happy. 2024-01-15 18:31:09 +01:00
9694671bbf Not implementing quantized. 2024-01-15 18:00:43 +01:00
3dbf65ef20 Rebase after phi2 merge + fix replit default to CPU. 2024-01-15 17:52:49 +01:00
b2db5adf82 Bad code removal. 2024-01-15 17:43:00 +01:00
9ef040338d After rebase. 2024-01-15 17:43:00 +01:00
3aefc709c7 Cleanup the fence. 2024-01-15 17:43:00 +01:00
c8c603ce96 Removing the fences speeds everything up and *is* correct this time... 2024-01-15 17:43:00 +01:00
61ad8d91cc Fix the rebase. 2024-01-15 17:43:00 +01:00
2cd1e59c9e Cleanup. 2024-01-15 17:43:00 +01:00
9c4b4f0da0 Metal quantized modifications proposal.
- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).
2024-01-15 17:42:58 +01:00
268 changed files with 4702 additions and 40304 deletions

View File

@ -5,15 +5,49 @@ on:
pull_request: pull_request:
jobs: jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
# Don't run on forks, they won't have access to secrets anyway.
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
test-cuda: test-cuda:
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci] needs: start-runner # required to start the main job when the runner is ready
container: runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
permissions: permissions:
contents: write contents: write
packages: write packages: write
@ -24,10 +58,32 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable - name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1 run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda) - name: Test (cuda)
run: cargo test --features cuda run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

View File

@ -9,7 +9,6 @@ members = [
"candle-transformers", "candle-transformers",
"candle-wasm-examples/*", "candle-wasm-examples/*",
"candle-wasm-tests", "candle-wasm-tests",
"tensor-tools",
] ]
exclude = [ exclude = [
"candle-flash-attn", "candle-flash-attn",
@ -20,7 +19,7 @@ exclude = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "0.5.0" version = "0.3.3"
edition = "2021" edition = "2021"
description = "Minimalist ML framework." description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle" repository = "https://github.com/huggingface/candle"
@ -29,38 +28,37 @@ categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[workspace.dependencies] [workspace.dependencies]
ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" } accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3" byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" } candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets", version = "0.5.0" } candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" } candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels", version = "0.5.0" } candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" } candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn", version = "0.5.0" } candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx", version = "0.5.0" } candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers", version = "0.5.0" } candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false } criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] } cudarc = { version = "0.10.0", features = ["f16"] }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0" hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] } image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false } imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" } libc = { version = "0.2.147" }
log = "0.4" log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0" num_cpus = "1.15.0"
num-traits = "0.2.15" num-traits = "0.2.15"
parquet = { version = "51.0.0" } parquet = { version = "45.0.0" }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
rayon = "1.7.0" rayon = "1.7.0"
safetensors = "0.4.1" rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2" serde_plain = "1.0.2"
serde_json = "1.0.99" serde_json = "1.0.99"

View File

@ -63,25 +63,17 @@ We also provide a some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant. the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM. - [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports pre-trained on 1T tokens of English and code datasets.
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants. - [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model. implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28. better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of - [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference. much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion. - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters. (English/Chinese) general LLMs with 6b and 34b parameters.
@ -111,12 +103,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200"> <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model. - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/), - [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings. [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@ -124,16 +111,11 @@ We also provide a some command line based examples using state of the art models
evaluation, segmentation). evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/), - [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models. [RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image. generate captions for an image.
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
model.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text. model, generates the translated text from the input text.
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
that can answer real-world questions about images.
Run them using commands like: Run them using commands like:
``` ```
@ -177,11 +159,9 @@ And then head over to
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server. serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle. - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more. - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle. - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem. - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
If you have an addition to this list, please submit a pull request. If you have an addition to this list, please submit a pull request.
@ -202,18 +182,15 @@ If you have an addition to this list, please submit a pull request.
- Language Models. - Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B. - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- Falcon. - Falcon.
- StarCoder, StarCoder2. - StarCoder.
- Phi 1, 1.5, and 2. - Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba - Minimal Mamba
- Gemma 2b and 7b.
- Mistral 7b v0.1. - Mistral 7b v0.1.
- Mixtral 8x7b v0.1. - Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. - StableLM-3B-4E1T.
- Replit-code-v1.5-3B. - Replit-code-v1.5-3B.
- Bert. - Bert.
- Yi-6B and Yi-34B. - Yi-6B and Yi-34B.
- Qwen1.5, Qwen1.5 MoE.
- RWKV v5 and v6.
- Quantized LLMs. - Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants. - Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct. - Mistral 7b, and 7b instruct.
@ -223,22 +200,16 @@ If you have an addition to this list, please submit a pull request.
- Text to text. - Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction). - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation). - Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image. - Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0. - Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2. - Wurstchen v2.
- Image to text. - Image to text.
- BLIP. - BLIP.
- TrOCR.
- Audio.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Computer Vision Models. - Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
ConvNeXTv2, MobileOne, EfficientVit (MSRA).
- yolo-v3, yolo-v8. - yolo-v3, yolo-v8.
- Segment-Anything Model (SAM). - Segment-Anything Model (SAM).
- SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files. - File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments. - Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types. - Quantization support using the llama.cpp quantized types.
@ -375,9 +346,9 @@ git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...: /usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
``` ```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable. This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
``` ```
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ... env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
``` ```
#### Linking error on windows when running rustdoc or mdbook tests #### Linking error on windows when running rustdoc or mdbook tests

View File

@ -2,10 +2,7 @@ mod benchmarks;
use criterion::criterion_main; use criterion::criterion_main;
criterion_main!( criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches, benchmarks::matmul::benches,
benchmarks::random::benches, benchmarks::affine::benches,
benchmarks::where_cond::benches, benchmarks::where_cond::benches
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
); );

View File

@ -1,59 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(
x: &Tensor,
k: &Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) {
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
.unwrap();
}
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let t = Tensor::arange(0.0f32, 10000.0, device)
.unwrap()
.reshape((1, 4, 50, 50))
.unwrap()
.to_dtype(dtype)
.unwrap();
let kernel = Tensor::arange(0.0f32, 100.0, device)
.unwrap()
.reshape((4, 1, 5, 5))
.unwrap()
.to_dtype(dtype)
.unwrap();
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,8 +1,5 @@
pub(crate) mod affine; pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod matmul; pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod where_cond; pub(crate) mod where_cond;
use candle_core::{Device, Result}; use candle_core::{Device, Result};

View File

@ -1,72 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{
quantized::{self, GgmlDType, QMatMul},
Device, Module, Tensor,
};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
matmul.forward(&x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let b = 1;
let m = 1;
let n = 1024;
let k = 1024;
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&matmul), black_box(&lhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in vec![
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,
GgmlDType::Q4_1,
GgmlDType::Q5_0,
GgmlDType::Q5_1,
GgmlDType::Q8_0,
GgmlDType::Q2K,
GgmlDType::Q3K,
GgmlDType::Q4K,
GgmlDType::Q5K,
GgmlDType::Q6K,
] {
run_bench(c, &device, dtype);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -5,32 +5,25 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::Result; use anyhow::Result;
use candle_core::{Device, Module, Tensor}; use candle_core::{Device, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
fn main() -> Result<()> { fn main() -> Result<()> {
let device = Device::new_cuda(0)?; let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?; let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?; let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?; let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let q = QMatMul::from_qtensor(q)?; println!("{out_t}");
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?; let in_t = in_t.to_device(&Device::Cpu)?;
let res_q_cuda = q.forward(&x)?; let k_t = k_t.to_device(&Device::Cpu)?;
println!("{res_q_cuda}"); let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?; .sqr()?
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?; .sum_all()?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}"); println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(()) Ok(())
} }

View File

@ -1,5 +1,5 @@
use candle::quantized::{gguf_file, GgmlDType, QTensor}; use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle::{Device, Result}; use candle_core::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*; use rayon::prelude::*;
@ -117,24 +117,6 @@ enum Command {
verbose: bool, verbose: bool,
}, },
Print {
file: std::path::PathBuf,
names: Vec<String>,
/// The file format to use, if unspecified infer from the file extension.
#[arg(long, value_enum)]
format: Option<Format>,
/// Print the whole content of each tensor.
#[arg(long)]
full: bool,
/// Line width for printing the tensors.
#[arg(long)]
line_width: Option<usize>,
},
Quantize { Quantize {
/// The input file(s), in safetensors format. /// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>, in_file: Vec<std::path::PathBuf>,
@ -168,105 +150,6 @@ struct Args {
command: Command, command: Command,
} }
fn run_print(
file: &std::path::PathBuf,
names: Vec<String>,
format: Option<Format>,
full: bool,
line_width: Option<usize>,
device: &Device,
) -> Result<()> {
if full {
candle::display::set_print_options_full();
}
if let Some(line_width) = line_width {
candle::display::set_line_width(line_width)
}
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
Some(format) => format,
None => {
println!(
"{file:?}: cannot infer format from file extension, use the --format flag"
);
return Ok(());
}
},
};
match format {
Format::Npz => {
let tensors = candle::npy::NpzTensors::new(file)?;
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name)? {
Some(tensor) => println!("{tensor}"),
None => println!("not found"),
}
}
}
Format::Safetensors => {
use candle::safetensors::Load;
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name) {
Some(tensor_view) => {
let tensor = tensor_view.load(device)?;
println!("{tensor}")
}
None => println!("not found"),
}
}
}
Format::Pth => {
let pth_file = candle::pickle::PthTensors::new(file, None)?;
for name in names.iter() {
println!("==== {name} ====");
match pth_file.get(name)? {
Some(tensor) => {
println!("{tensor}")
}
None => println!("not found"),
}
}
}
Format::Pickle => {
candle::bail!("pickle format is not supported for print")
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
for name in names.iter() {
println!("==== {name} ====");
match content.tensors.get(name) {
Some(tensor) => {
let tensor = tensor.dequantize(device)?;
println!("{tensor}")
}
None => println!("not found"),
}
}
}
Format::Gguf => {
let mut file = std::fs::File::open(file)?;
let content = gguf_file::Content::read(&mut file)?;
for name in names.iter() {
println!("==== {name} ====");
match content.tensor(&mut file, name, device) {
Ok(tensor) => {
let tensor = tensor.dequantize(device)?;
println!("{tensor}")
}
Err(_) => println!("not found"),
}
}
}
}
Ok(())
}
fn run_ls( fn run_ls(
file: &std::path::PathBuf, file: &std::path::PathBuf,
format: Option<Format>, format: Option<Format>,
@ -287,7 +170,7 @@ fn run_ls(
}; };
match format { match format {
Format::Npz => { Format::Npz => {
let tensors = candle::npy::NpzTensors::new(file)?; let tensors = candle_core::npy::NpzTensors::new(file)?;
let mut names = tensors.names(); let mut names = tensors.names();
names.sort(); names.sort();
for name in names { for name in names {
@ -299,12 +182,12 @@ fn run_ls(
} }
} }
Format::Safetensors => { Format::Safetensors => {
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? }; let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors(); let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() { for (name, view) in tensors.iter() {
let dtype = view.dtype(); let dtype = view.dtype();
let dtype = match candle::DType::try_from(dtype) { let dtype = match candle_core::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"), Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"), Err(_) => format!("{dtype:?}"),
}; };
@ -313,7 +196,7 @@ fn run_ls(
} }
} }
Format::Pth => { Format::Pth => {
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?; let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name)); tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() { for tensor_info in tensors.iter() {
println!( println!(
@ -330,7 +213,7 @@ fn run_ls(
Format::Pickle => { Format::Pickle => {
let file = std::fs::File::open(file)?; let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file); let mut reader = std::io::BufReader::new(file);
let mut stack = candle::pickle::Stack::empty(); let mut stack = candle_core::pickle::Stack::empty();
stack.read_loop(&mut reader)?; stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() { for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}"); println!("{i} {obj:?}");
@ -338,7 +221,7 @@ fn run_ls(
} }
Format::Ggml => { Format::Ggml => {
let mut file = std::fs::File::open(file)?; let mut file = std::fs::File::open(file)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?; let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>(); let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() { for (name, qtensor) in tensors.iter() {
@ -374,7 +257,7 @@ fn run_quantize_safetensors(
let mut out_file = std::fs::File::create(out_file)?; let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new(); let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() { for in_file in in_files.iter() {
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?; let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors) tensors.extend(in_tensors)
} }
println!("tensors: {}", tensors.len()); println!("tensors: {}", tensors.len());
@ -416,7 +299,7 @@ fn run_dequantize(
let tensor = tensor.dequantize(device)?; let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor); tensors.insert(tensor_name.to_string(), tensor);
} }
candle::safetensors::save(&tensors, out_file)?; candle_core::safetensors::save(&tensors, out_file)?;
Ok(()) Ok(())
} }
@ -428,11 +311,11 @@ fn run_quantize(
device: &Device, device: &Device,
) -> Result<()> { ) -> Result<()> {
if in_files.is_empty() { if in_files.is_empty() {
candle::bail!("no specified input files") candle_core::bail!("no specified input files")
} }
if let Some(extension) = out_file.extension() { if let Some(extension) = out_file.extension() {
if extension == "safetensors" { if extension == "safetensors" {
candle::bail!("the generated file cannot use the safetensors extension") candle_core::bail!("the generated file cannot use the safetensors extension")
} }
} }
if let Some(extension) = in_files[0].extension() { if let Some(extension) = in_files[0].extension() {
@ -442,7 +325,7 @@ fn run_quantize(
} }
if in_files.len() != 1 { if in_files.len() != 1 {
candle::bail!("only a single in-file can be used when quantizing gguf files") candle_core::bail!("only a single in-file can be used when quantizing gguf files")
} }
// Open the out file early so as to fail directly on missing directories etc. // Open the out file early so as to fail directly on missing directories etc.
@ -494,13 +377,6 @@ fn main() -> anyhow::Result<()> {
run_ls(file, format.clone(), verbose, &device)? run_ls(file, format.clone(), verbose, &device)?
} }
} }
Command::Print {
file,
names,
format,
full,
line_width,
} => run_print(&file, names, format, full, line_width, &device)?,
Command::Quantize { Command::Quantize {
in_file, in_file,
out_file, out_file,

View File

@ -380,16 +380,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
} }
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline] #[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) { for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -412,28 +402,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
} }
} }
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op { macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => { ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline] #[inline]

View File

@ -98,19 +98,6 @@ pub trait BackendStorage: Sized {
) -> Result<Self>; ) -> Result<Self>;
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
#[allow(clippy::too_many_arguments)]
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
fn copy2d(
&self,
_: &mut Self,
_d1: usize,
_d2: usize,
_src_stride1: usize,
_dst_stride1: usize,
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
} }
pub trait BackendDevice: Sized + std::fmt::Debug + Clone { pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@ -127,22 +114,11 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>; fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible
/// after this call.
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>; fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>; fn set_seed(&self, _: u64) -> Result<()>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
} }

View File

@ -1,4 +1,3 @@
/// Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp}; use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId}; use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap; use std::collections::HashMap;
@ -112,10 +111,9 @@ impl Tensor {
} }
Op::Unary(_node, UnaryOp::Ceil) Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor) | Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) | Op::Unary(_node, UnaryOp::Round) => nodes,
| Op::Unary(_node, UnaryOp::Sign) => nodes,
Op::Reshape(node) Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. } | Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. }
@ -177,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second order // the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment. // derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b); let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach() }; let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() { if let Some(op) = node.op() {
match op { match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => { Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -252,7 +250,6 @@ impl Tensor {
out_padding, out_padding,
*stride, *stride,
*dilation, *dilation,
/* groups */ 1,
)?; )?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
@ -312,32 +309,9 @@ impl Tensor {
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported { Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d", op: "conv-transpose1d",
})?, })?,
Op::ConvTranspose2D { Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
arg, op: "conv-transpose2d",
kernel, })?,
padding,
stride,
dilation,
output_padding: _output_padding,
} => {
let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = grad
.transpose(0, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::AvgPool2D { Op::AvgPool2D {
arg, arg,
kernel_size, kernel_size,
@ -373,18 +347,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
} }
Op::UpsampleNearest1D { arg, target_size } => { Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
let (_n, c, size) = arg.dims3()?; op: "upsample-nearest1d",
if target_size % size != 0 { })?,
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale = target_size / size;
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest2D { Op::UpsampleNearest2D {
arg, arg,
target_h, target_h,
@ -489,6 +454,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?; *sum_grad = sum_grad.add(&grad)?;
} }
Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?; let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?; let grad = broadcast_back(arg, &grad, reduced_dims)?;
@ -578,18 +544,20 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Unary(_, UnaryOp::Floor) Op::Reduce(_, ReduceOp::ArgMin, _) => {}
| Op::Unary(_, UnaryOp::Round) Op::Reduce(_, ReduceOp::ArgMax, _) => {}
| Op::Reduce(_, ReduceOp::ArgMin, _)
| Op::Reduce(_, ReduceOp::ArgMax, _)
| Op::Unary(_, UnaryOp::Sign)
| Op::Cmp(_, _) => {}
Op::Reshape(arg) => { Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?; let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => { Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?; let cube = arg.powf(3.)?;
@ -621,13 +589,6 @@ impl Tensor {
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?; let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)? *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
} }
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => { Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0 // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
@ -712,38 +673,30 @@ impl Tensor {
} }
} }
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
#[derive(Debug)] #[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>); pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore { impl GradStore {
/// Create a new gradient store
fn new() -> Self { fn new() -> Self {
GradStore(HashMap::new()) GradStore(HashMap::new())
} }
/// Get the gradient tensor corresponding to the given tensor id
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> { pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
self.0.get(&id) self.0.get(&id)
} }
/// Get the gradient tensor associated with the given tensor
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> { pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
self.0.get(&tensor.id()) self.0.get(&tensor.id())
} }
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> { pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
self.0.remove(&tensor.id()) self.0.remove(&tensor.id())
} }
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> { pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
self.0.insert(tensor.id(), grad) self.0.insert(tensor.id(), grad)
} }
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> { fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
let grad = match self.0.entry(tensor.id()) { let grad = match self.0.entry(tensor.id()) {

View File

@ -187,16 +187,36 @@ impl Tensor {
} }
} }
fn conv_transpose1d_single_group( /// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self, &self,
kernel: &Self, kernel: &Self,
params: &ParamsConvTranspose1D, padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> { ) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d( let storage = self.storage().conv_transpose1d(
self.layout(), self.layout(),
&kernel.storage(), &kernel.storage(),
kernel.layout(), kernel.layout(),
params, &params,
)?; )?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg, arg,
@ -210,49 +230,6 @@ impl Tensor {
Ok(crate::tensor::from_storage(storage, out_dims, op, false)) Ok(crate::tensor::from_storage(storage, out_dims, op, false))
} }
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_in_k, c_out, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
if c_in % groups != 0 {
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in: c_in / groups,
padding,
output_padding,
stride,
dilation,
};
if groups == 1 {
self.conv_transpose1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage = let storage =
self.storage() self.storage()

View File

@ -4,13 +4,7 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16}; use half::{bf16, f16};
use rayon::prelude::*; use rayon::prelude::*;
mod utils;
pub use utils::{
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
};
const USE_IM2COL_CONV1D: bool = true; const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true; const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@ -29,6 +23,102 @@ pub enum CpuStorage {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CpuDevice; pub struct CpuDevice;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
&self,
vs: &[T],
layout: &Layout,
wrap: W,
) -> Result<CpuStorage>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
}
}
}
type C = CpuStorage;
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(
&self,
v1: &CpuStorage,
l1: &Layout,
v2: &CpuStorage,
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(
&self,
v1: &CpuStorage,
l1: &Layout,
v2: &CpuStorage,
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
struct Cmp(CmpOp); struct Cmp(CmpOp);
impl Map2U8 for Cmp { impl Map2U8 for Cmp {
const OP: &'static str = "cmp"; const OP: &'static str = "cmp";
@ -275,6 +365,275 @@ impl<'a> Map1 for ReduceSum<'a> {
} }
} }
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}
// This function maps over two strided index sequences.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
struct Affine(f64, f64); struct Affine(f64, f64);
impl Map1 for Affine { impl Map1 for Affine {
@ -663,26 +1022,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
} }
} }
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
src: &[T],
dst: &mut [T],
d1: usize,
d2: usize,
src_stride1: usize,
dst_stride1: usize,
src_offset: usize,
dst_offset: usize,
) {
for i1 in 0..d1 {
let dst_idx = i1 * dst_stride1 + dst_offset;
let src_idx = i1 * src_stride1 + src_offset;
let dst = &mut dst[dst_idx..dst_idx + d2];
let src = &src[src_idx..src_idx + d2];
dst.copy_from_slice(src)
}
}
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) { fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
match src_l.strided_blocks() { match src_l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => { crate::StridedBlocks::SingleBlock { start_offset, len } => {
@ -917,34 +1256,6 @@ impl Map1 for Im2Col {
} }
} }
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let mut im = vec![T::zero(); b_size * c_out * l_out];
let (dst_s0, dst_s1) = (c_out * l_out, l_out);
let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
for l_in_i in 0..l_in {
for k_i in 0..k_size {
let l_out_i = l_in_i * stride + k_i;
for b_i in 0..b_size {
for c_i in 0..c_out {
let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
im[dst_idx] += col[src_idx]
}
}
}
}
Ok(im)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> { impl<'a> Map2 for ConvTranspose1D<'a> {
@ -952,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> { fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0; let p = self.0;
let inp = &inp[inp_l.start_offset()..]; let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out(); let l_out = p.l_out();
@ -1204,30 +1514,6 @@ impl MatMul {
})) }))
.bt() .bt()
} }
fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (_b, m, n, k) = self.0;
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[_, stride] if lhs_l.dims()[0] == 1 => stride,
[stride, _] if lhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[_, stride] if rhs_l.dims()[0] == 1 => stride,
[stride, _] if rhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
Ok((a_skip, b_skip))
}
} }
impl Map2 for MatMul { impl Map2 for MatMul {
@ -1261,7 +1547,18 @@ impl Map2 for MatMul {
let rhs_cs = rhs_stride[rank - 1]; let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2]; let rhs_rs = rhs_stride[rank - 2];
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n; let c_skip: usize = m * n;
let dst_shape: Shape = (m, n).into(); let dst_shape: Shape = (m, n).into();
@ -1321,8 +1618,20 @@ impl Map2 for MatMul {
let lhs_stride = lhs_l.stride(); let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride(); let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n; let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1330,7 +1639,7 @@ impl Map2 for MatMul {
let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, b'N') (n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 { } else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T') (k as i32, b'T')
@ -1338,7 +1647,7 @@ impl Map2 for MatMul {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
}; };
// The b tensor has dims batching, m, k (lhs) // The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, b'N') (k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 { } else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T') (m as i32, b'T')
@ -1412,8 +1721,20 @@ impl Map2 for MatMul {
let lhs_stride = lhs_l.stride(); let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride(); let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n; let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1421,7 +1742,7 @@ impl Map2 for MatMul {
let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, b'N') (n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 { } else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T') (k as i32, b'T')
@ -1429,7 +1750,7 @@ impl Map2 for MatMul {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
}; };
// The b tensor has dims batching, m, k (lhs) // The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, b'N') (k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 { } else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T') (m as i32, b'T')
@ -2101,48 +2422,6 @@ impl BackendStorage for CpuStorage {
} }
} }
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
(Self::U32(src), Self::U32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I64(src), Self::I64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::BF16(src), Self::BF16(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F16(src), Self::F16(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F32(src), Self::F32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F64(src), Self::F64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(_, dst) => {
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: dst.dtype(),
op: "copy2d",
}
.bt());
}
}
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
match (self, dst) { match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@ -2211,10 +2490,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else { } else {
// Make the kernel contiguous if not already the case. // Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe { let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)? .transpose(1, 2)?
@ -2222,7 +2498,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
}; };
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?; res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t) Ok(res_t)
} }
@ -2234,52 +2510,7 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout, kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D, params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> { ) -> Result<Self> {
let can_use_col2im = kernel_l.is_contiguous() ConvTranspose1D(params).map(self, l, kernel, kernel_l)
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if USE_IM2COL_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col, &col_l)
} else {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
} }
fn conv2d( fn conv2d(
@ -2313,10 +2544,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else { } else {
// Make the kernel contiguous if not already the case. // Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe { let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)? .transpose(1, 2)?
@ -2326,7 +2554,7 @@ impl BackendStorage for CpuStorage {
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)? .transpose(1, 2)?
.transpose(1, 3)?; .transpose(1, 3)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?; res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t) Ok(res_t)
} }
@ -2346,7 +2574,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
} }
} }
@ -2355,7 +2583,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
} }
} }
@ -2372,7 +2600,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
} }
} }
@ -2449,10 +2677,6 @@ impl BackendDevice for CpuDevice {
Ok(s.clone()) Ok(s.clone())
} }
fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
Ok(s)
}
fn new(_: usize) -> Result<Self> { fn new(_: usize) -> Result<Self> {
Ok(Self) Ok(Self)
} }
@ -2554,53 +2778,6 @@ impl BackendDevice for CpuDevice {
} }
} }
#[allow(clippy::uninit_vec)]
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
// The code below is highly unsafe but hopefully not directly unsound as we only consider
// types that are Copy, not Drop, and for which all bit patterns are proper values.
// It's still pretty risky, see the following for more details:
// https://github.com/rust-lang/rust-clippy/issues/4483
let storage = match dtype {
DType::U8 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::U8(v)
}
DType::U32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::U32(v)
}
DType::I64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::I64(v)
}
DType::BF16 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::BF16(v)
}
DType::F16 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F16(v)
}
DType::F32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F32(v)
}
DType::F64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F64(v)
}
};
Ok(storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> { fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
let storage = match dtype { let storage = match dtype {
@ -2628,10 +2805,6 @@ impl BackendDevice for CpuDevice {
}; };
Ok(storage) Ok(storage)
} }
fn synchronize(&self) -> Result<()> {
Ok(())
}
} }
#[macro_export] #[macro_export]

View File

@ -1,350 +0,0 @@
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};
type C = super::CpuStorage;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
}
}
}
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}

View File

@ -5,41 +5,395 @@ pub use candle_kernels as kernels;
pub use cudarc; pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{ use cudarc::driver::{
CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
ValidAsZeroBits,
}; };
use half::{bf16, f16}; use half::{bf16, f16};
use std::sync::{Arc, Mutex};
#[cfg(feature = "cudnn")] /// cudarc related errors
pub mod cudnn; #[derive(thiserror::Error, Debug)]
mod device; pub enum CudaError {
mod error; #[error(transparent)]
mod utils; Cuda(#[from] cudarc::driver::DriverError),
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
enum SlicePtrOrNull<T> { #[error(transparent)]
Ptr(CudaSlice<T>), Compiler(#[from] cudarc::nvrtc::CompileError),
Null,
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
} }
unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> { impl From<CudaError> for crate::Error {
fn as_kernel_param(&self) -> *mut std::ffi::c_void { fn from(val: CudaError) -> Self {
match self { crate::Error::Cuda(Box::new(val)).bt()
SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
SlicePtrOrNull::Null => 0usize.as_kernel_param(),
}
} }
} }
impl SlicePtrOrNull<usize> { /// Unique identifier for cuda devices.
fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
let ds = if l.is_contiguous() { pub struct DeviceId(usize);
SlicePtrOrNull::Null
} else { impl DeviceId {
SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?) fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.device
}
}
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
}; };
Ok(ds) Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
let layout = Layout::contiguous(shape);
Affine(up - lo, lo).map(&slice, self, &layout)?
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
} }
} }
@ -53,6 +407,133 @@ pub enum CudaStorageSlice {
F32(CudaSlice<f32>), F32(CudaSlice<f32>),
F64(CudaSlice<f64>), F64(CudaSlice<f64>),
} }
type S = CudaStorageSlice;
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
}
struct Clone; struct Clone;
impl Map1 for Clone { impl Map1 for Clone {
@ -83,7 +564,7 @@ impl Map1 for Affine {
let dims = shape.dims(); let dims = shape.dims();
let el = shape.elem_count(); let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32); let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..); let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?; let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
@ -115,7 +596,7 @@ impl Map1 for Elu {
let dims = shape.dims(); let dims = shape.dims();
let el = shape.elem_count(); let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32); let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..); let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?; let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
@ -238,7 +719,7 @@ impl Map1 for Powf {
let dims = shape.dims(); let dims = shape.dims();
let el = shape.elem_count(); let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32); let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..); let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?; let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
@ -371,7 +852,7 @@ impl<U: UnaryOpT> Map1 for U {
let dims = shape.dims(); let dims = shape.dims();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32); let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..); let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?; let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
// SAFETY: Set later by running the kernel. // SAFETY: Set later by running the kernel.
@ -668,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
} }
} }
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, l_k)
// Input shape: (b_size, c_in, l_in)
let p = &self.0;
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
l_out,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> { impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -921,14 +1353,9 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
let dims = shape.dims(); let dims = shape.dims();
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32); let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { let dims_and_strides = dev
SlicePtrOrNull::Null .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
} else { .w()?;
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..); let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?; let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
@ -955,14 +1382,9 @@ impl Map2Any for Cmp {
let dims = shape.dims(); let dims = shape.dims();
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32); let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { let dims_and_strides = dev
SlicePtrOrNull::Null .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
} else { .w()?;
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..); let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..);
let name = match self.0 { let name = match self.0 {
@ -1070,30 +1492,26 @@ fn gemm_config<T>(
let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs) // The a tensor has dims batching, k, n (rhs)
// We also allow for the case where the stride on the minor dimension is not as expected but let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
// there is a single element.
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, cublasOperation_t::CUBLAS_OP_N) (n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { } else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, cublasOperation_t::CUBLAS_OP_T) (k as i32, cublasOperation_t::CUBLAS_OP_T)
} else { } else {
Err(CudaError::MatMulNonContiguous { Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(), lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_l.clone(), rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k), mnk: (m, n, k),
})? })?
}; };
// The b tensor has dims batching, m, k (lhs) // The b tensor has dims batching, m, k (lhs)
// We also allow for the case where the stride on the minor dimension is not as expected but let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
// there is a single element.
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_N) (k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { } else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, cublasOperation_t::CUBLAS_OP_T) (m as i32, cublasOperation_t::CUBLAS_OP_T)
} else { } else {
Err(CudaError::MatMulNonContiguous { Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(), lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_l.clone(), rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k), mnk: (m, n, k),
})? })?
}; };
@ -1114,25 +1532,21 @@ fn gemm_config<T>(
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[_, stride] if lhs_l.dims()[0] == 1 => stride,
[stride, _] if lhs_l.dims()[1] == 1 => stride,
[stride] => stride, [stride] => stride,
[] => m * k, [] => m * k,
_ => Err(CudaError::MatMulNonContiguous { _ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(), lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_l.clone(), rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k), mnk: (m, n, k),
})?, })?,
}; };
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[_, stride] if rhs_l.dims()[0] == 1 => stride,
[stride, _] if rhs_l.dims()[1] == 1 => stride,
[stride] => stride, [stride] => stride,
[] => n * k, [] => n * k,
_ => Err(CudaError::MatMulNonContiguous { _ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(), lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_l.clone(), rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k), mnk: (m, n, k),
})?, })?,
}; };
@ -1177,7 +1591,7 @@ impl BackendStorage for CudaStorage {
let el = shape.elem_count(); let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32); let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device(); let dev = self.device();
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let start_o = layout.start_offset(); let start_o = layout.start_offset();
// This returns an i64 rather than a &i64, this is useful to get around some temporary // This returns an i64 rather than a &i64, this is useful to get around some temporary
// lifetime issue and is safe as long as self.slice does not go out of scope before inp // lifetime issue and is safe as long as self.slice does not go out of scope before inp
@ -1381,10 +1795,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else { } else {
// Make the kernel contiguous if not already the case. // Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe { let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)? .transpose(1, 2)?
@ -1392,22 +1803,19 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
}; };
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?; res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t) Ok(res_t)
} }
fn conv_transpose1d( fn conv_transpose1d(
&self, &self,
l: &Layout, _: &Layout,
kernel: &Self, _: &Self,
kernel_l: &Layout, _: &Layout,
params: &crate::conv::ParamsConvTranspose1D, _: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> { ) -> Result<Self> {
let device = self.device().clone(); todo!()
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
} }
#[cfg(not(feature = "cudnn"))] #[cfg(not(feature = "cudnn"))]
@ -1449,10 +1857,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else { } else {
// Make the kernel contiguous if not already the case. // Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe { let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)? .transpose(1, 2)?
@ -1462,7 +1867,7 @@ impl BackendStorage for CudaStorage {
let res_l = Layout::contiguous((b, h_out, w_out, n)) let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)? .transpose(1, 2)?
.transpose(1, 3)?; .transpose(1, 3)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?; res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t) Ok(res_t)
} }
@ -1599,7 +2004,7 @@ impl BackendStorage for CudaStorage {
dim: usize, dim: usize,
) -> Result<Self> { ) -> Result<Self> {
let device = self.device().clone(); let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?; self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc) Ok(acc)
@ -1614,7 +2019,7 @@ impl BackendStorage for CudaStorage {
dim: usize, dim: usize,
) -> Result<Self> { ) -> Result<Self> {
let device = self.device().clone(); let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?; self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc) Ok(acc)
@ -1688,72 +2093,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
let dev = &self.device;
let d1 = d1 as u32;
let d2 = d2 as u32;
// Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
// argument with a null pointer.
if d1 == 0 || d2 == 0 {
return Ok(());
}
let dst_s = dst_s as u32;
let src_s = src_s as u32;
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
(S::U8(s), S::U8(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_u8",
),
(S::U32(s), S::U32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_u32",
),
(S::I64(s), S::I64(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_i64",
),
(S::BF16(s), S::BF16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_bf16",
),
(S::F16(s), S::F16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f16",
),
(S::F32(s), S::F32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f32",
),
(S::F64(s), S::F64(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f64",
),
_ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
};
let func = dev.get_or_load_func(kname, kernels::FILL)?;
let cfg = LaunchConfig::for_num_elems(d1 * d2);
let params = (src, dst, d1, d2, src_s, dst_s);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape(); let src_shape = src_l.shape();
let dims = src_shape.dims(); let dims = src_shape.dims();
@ -1763,7 +2102,7 @@ impl BackendStorage for CudaStorage {
} }
let cfg = LaunchConfig::for_num_elems(el_count as u32); let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device; let dev = &self.device;
let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?; let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
match (&self.slice, &mut dst.slice) { match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);

View File

@ -1,415 +0,0 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
use super::utils::Map1;
let layout = Layout::contiguous(shape);
super::Affine(up - lo, lo).map(&slice, self, &layout)?
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

View File

@ -1,62 +0,0 @@
use crate::{DType, Layout};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
#[error(transparent)]
Cuda(#[from] cudarc::driver::DriverError),
#[error(transparent)]
Compiler(#[from] cudarc::nvrtc::CompileError),
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Layout,
rhs_stride: Layout,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
}
impl From<CudaError> for crate::Error {
fn from(val: CudaError) -> Self {
crate::Error::Cuda(Box::new(val)).bt()
}
}
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
}
}

View File

@ -1,134 +0,0 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
use super::{CudaDevice, CudaError, WrapErr};
pub type S = super::CudaStorageSlice;
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
}

View File

@ -1,377 +0,0 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
impl Tensor {
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
// In place ops.
/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
-> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &mut CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &mut CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
impl Tensor {
/// Applies a unary custom op in place.
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
self.storage_mut().inplace_op1(self.layout(), c)
}
/// Applies a unary custom op in place (for the first tensor).
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
self.storage_mut()
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
}
/// Applies a ternary custom op in place (for the first tensor).
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
self.storage_mut().inplace_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)
}
}

View File

@ -289,34 +289,17 @@ impl Device {
} }
} }
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> { pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self { match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => { Device::Cuda(device) => {
let storage = array.to_cpu_storage(); let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?; let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
Device::Metal(device) => { Device::Metal(device) => {
let storage = array.to_cpu_storage(); let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?; let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage)) Ok(Storage::Metal(storage))
} }
} }
@ -327,22 +310,14 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))), Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => { Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data); let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?; let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage)) Ok(Storage::Cuda(storage))
} }
Device::Metal(device) => { Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data); let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?; let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage)) Ok(Storage::Metal(storage))
} }
} }
} }
pub fn synchronize(&self) -> Result<()> {
match self {
Self::Cpu => Ok(()),
Self::Cuda(d) => d.synchronize(),
Self::Metal(d) => d.synchronize(),
}
}
} }

View File

@ -65,13 +65,12 @@ impl std::fmt::Debug for Tensor {
} }
/// Options for Tensor pretty printing /// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions { pub struct PrinterOptions {
pub precision: usize, precision: usize,
pub threshold: usize, threshold: usize,
pub edge_items: usize, edge_items: usize,
pub line_width: usize, line_width: usize,
pub sci_mode: Option<bool>, sci_mode: Option<bool>,
} }
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> = static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@ -90,10 +89,6 @@ impl PrinterOptions {
} }
} }
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
&PRINT_OPTS
}
pub fn set_print_options(options: PrinterOptions) { pub fn set_print_options(options: PrinterOptions) {
*PRINT_OPTS.lock().unwrap() = options *PRINT_OPTS.lock().unwrap() = options
} }
@ -122,26 +117,6 @@ pub fn set_print_options_full() {
} }
} }
pub fn set_line_width(line_width: usize) {
PRINT_OPTS.lock().unwrap().line_width = line_width
}
pub fn set_precision(precision: usize) {
PRINT_OPTS.lock().unwrap().precision = precision
}
pub fn set_edge_items(edge_items: usize) {
PRINT_OPTS.lock().unwrap().edge_items = edge_items
}
pub fn set_threshold(threshold: usize) {
PRINT_OPTS.lock().unwrap().threshold = threshold
}
pub fn set_sci_mode(sci_mode: Option<bool>) {
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}
struct FmtSize { struct FmtSize {
current_size: usize, current_size: usize,
} }

View File

@ -23,15 +23,7 @@ pub enum DType {
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String); pub struct DTypeParseError;
impl std::fmt::Display for DTypeParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "cannot parse '{}' as a dtype", self.0)
}
}
impl std::error::Error for DTypeParseError {}
impl std::str::FromStr for DType { impl std::str::FromStr for DType {
type Err = DTypeParseError; type Err = DTypeParseError;
@ -44,7 +36,7 @@ impl std::str::FromStr for DType {
"f16" => Ok(Self::F16), "f16" => Ok(Self::F16),
"f32" => Ok(Self::F32), "f32" => Ok(Self::F32),
"f64" => Ok(Self::F64), "f64" => Ok(Self::F64),
_ => Err(DTypeParseError(s.to_string())), _ => Err(DTypeParseError),
} }
} }
} }

View File

@ -154,19 +154,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
@ -210,18 +197,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> { fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
@ -229,8 +208,4 @@ impl crate::backend::BackendDevice for CudaDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn synchronize(&self) -> Result<()> {
Ok(())
}
} }

View File

@ -166,19 +166,6 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
@ -222,18 +209,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> { fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
@ -241,8 +220,4 @@ impl crate::backend::BackendDevice for MetalDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> { fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
fn synchronize(&self) -> Result<()> {
Ok(())
}
} }

View File

@ -70,7 +70,7 @@ impl Layout {
self.shape.is_fortran_contiguous(&self.stride) self.shape.is_fortran_contiguous(&self.stride)
} }
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> { pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.shape().dims(); let dims = self.shape().dims();
if dim >= dims.len() { if dim >= dims.len() {
Err(Error::DimOutOfRange { Err(Error::DimOutOfRange {
@ -99,7 +99,7 @@ impl Layout {
}) })
} }
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> { pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
let rank = self.shape.rank(); let rank = self.shape.rank();
if rank <= dim1 || rank <= dim2 { if rank <= dim1 || rank <= dim2 {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
@ -120,7 +120,7 @@ impl Layout {
}) })
} }
pub fn permute(&self, idxs: &[usize]) -> Result<Self> { pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation = let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i)); idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation { if !is_permutation {

View File

@ -14,7 +14,7 @@
//! //!
//! ## Features //! ## Features
//! //!
//! - Simple syntax (looks and feels like PyTorch) //! - Simple syntax (looks and like PyTorch)
//! - CPU and Cuda backends (and M1 support) //! - CPU and Cuda backends (and M1 support)
//! - Enable serverless (CPU) small and fast deployments //! - Enable serverless (CPU) small and fast deployments
//! - Model training //! - Model training
@ -37,13 +37,14 @@
mod accelerate; mod accelerate;
pub mod backend; pub mod backend;
pub mod backprop; pub mod backprop;
pub mod conv; mod conv;
mod convert; mod convert;
pub mod cpu; pub mod cpu;
pub mod cpu_backend; pub mod cpu_backend;
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
pub mod cuda_backend; pub mod cuda_backend;
mod custom_op; #[cfg(feature = "cudnn")]
pub mod cudnn;
mod device; mod device;
pub mod display; pub mod display;
mod dtype; mod dtype;
@ -57,7 +58,7 @@ pub mod metal_backend;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
mod mkl; mod mkl;
pub mod npy; pub mod npy;
pub mod op; mod op;
pub mod pickle; pub mod pickle;
pub mod quantized; pub mod quantized;
pub mod safetensors; pub mod safetensors;
@ -66,21 +67,17 @@ pub mod shape;
mod storage; mod storage;
mod strided_index; mod strided_index;
mod tensor; mod tensor;
mod tensor_cat;
pub mod test_utils; pub mod test_utils;
pub mod utils; pub mod utils;
mod variable; mod variable;
#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;
pub use cpu_backend::CpuStorage; pub use cpu_backend::CpuStorage;
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray}; pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use indexer::IndexOp; pub use indexer::IndexOp;
pub use layout::Layout; pub use layout::Layout;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D}; pub use shape::{Shape, D};
pub use storage::Storage; pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex}; pub use strided_index::{StridedBlocks, StridedIndex};
@ -132,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
} }
} }
impl<M: Module> Module for Option<&M> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
None => Ok(xs.clone()),
Some(m) => m.forward(xs),
}
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to // A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors. // separate the training and evaluation behaviors.
pub trait ModuleT { pub trait ModuleT {

View File

@ -1,287 +0,0 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use super::MetalError;
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
pub(crate) fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
/// the device itself.
pub(crate) id: DeviceId,
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
/// (could be linked to FFI communication overhead).
///
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
/// to be reused. However, in order for this to work, we need to guarantee the order of
/// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.id)
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
pub fn id(&self) -> DeviceId {
self.id
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource origin in case of bugs
pub fn new_buffer(
&self,
element_count: usize,
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
descriptor.set_output_url(path);
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger
}

View File

@ -333,16 +333,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
} }
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline] #[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) { for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -365,28 +355,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
} }
} }
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op { macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => { ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline] #[inline]

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)] #![allow(clippy::redundant_closure_call)]
use crate::Tensor; use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16}; use half::{bf16, f16};
use num_traits::float::Float; use num_traits::float::Float;
@ -61,12 +61,10 @@ pub enum UnaryOp {
GeluErf, GeluErf,
Erf, Erf,
Relu, Relu,
Silu,
Tanh, Tanh,
Floor, Floor,
Ceil, Ceil,
Round, Round,
Sign,
} }
#[derive(Clone)] #[derive(Clone)]
@ -133,10 +131,7 @@ pub enum Op {
stride: (usize, usize), stride: (usize, usize),
}, },
UpsampleNearest1D { UpsampleNearest1D(Tensor),
arg: Tensor,
target_size: usize,
},
UpsampleNearest2D { UpsampleNearest2D {
arg: Tensor, arg: Tensor,
target_h: usize, target_h: usize,
@ -162,23 +157,168 @@ pub enum Op {
Permute(Tensor, Vec<usize>), Permute(Tensor, Vec<usize>),
Elu(Tensor, f64), Elu(Tensor, f64),
Powf(Tensor, f64), Powf(Tensor, f64),
CustomOp1( CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
),
CustomOp2( CustomOp2(
Tensor, Tensor,
Tensor, Tensor,
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>, std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
), ),
CustomOp3( CustomOp3(
Tensor, Tensor,
Tensor, Tensor,
Tensor, Tensor,
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>, std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
), ),
} }
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait UnaryOpT { pub trait UnaryOpT {
const NAME: &'static str; const NAME: &'static str;
const KERNEL: &'static str; const KERNEL: &'static str;
@ -250,12 +390,10 @@ pub(crate) struct Gelu;
pub(crate) struct GeluErf; pub(crate) struct GeluErf;
pub(crate) struct Erf; pub(crate) struct Erf;
pub(crate) struct Relu; pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh; pub(crate) struct Tanh;
pub(crate) struct Floor; pub(crate) struct Floor;
pub(crate) struct Ceil; pub(crate) struct Ceil;
pub(crate) struct Round; pub(crate) struct Round;
pub(crate) struct Sign;
macro_rules! bin_op { macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -459,13 +597,6 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
/// Tanh based approximation of the `gelu` operation /// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one. /// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@ -478,7 +609,7 @@ impl UnaryOpT for Gelu {
* v * v
* (bf16::ONE * (bf16::ONE
+ bf16::tanh( + bf16::tanh(
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32) (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
* v * v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
)) ))
@ -489,18 +620,22 @@ impl UnaryOpT for Gelu {
* v * v
* (f16::ONE * (f16::ONE
+ f16::tanh( + f16::tanh(
f16::from_f32_const(SQRT_TWO_OVER_PI_F32) (f16::from_f32_const(2.0) / f16::PI).sqrt()
* v * v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v), * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
)) ))
} }
#[inline(always)] #[inline(always)]
fn f32(v: f32) -> f32 { fn f32(v: f32) -> f32 {
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) 0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
} }
#[inline(always)] #[inline(always)]
fn f64(v: f64) -> f64 { fn f64(v: f64) -> f64 {
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v))) 0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
} }
#[inline(always)] #[inline(always)]
fn u8(_: u8) -> u8 { fn u8(_: u8) -> u8 {
@ -589,77 +724,6 @@ impl UnaryOpT for Erf {
} }
} }
/// Silu operation
impl UnaryOpT for Silu {
const NAME: &'static str = "silu";
const V: Self = Silu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v / (bf16::ONE + (-v).exp())
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v / (f16::ONE + (-v).exp())
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_silu(xs, ys)
}
}
impl UnaryOpT for Abs { impl UnaryOpT for Abs {
const NAME: &'static str = "abs"; const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs"; const KERNEL: &'static str = "uabs";
@ -927,37 +991,3 @@ impl std::ops::Deref for BackpropOp {
&self.0 &self.0
} }
} }
impl UnaryOpT for Sign {
const NAME: &'static str = "sign";
const KERNEL: &'static str = "usign";
const V: Self = Sign;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
f32::from(v > 0.) - f32::from(v < 0.)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
f64::from(v > 0.) - f64::from(v < 0.)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
u8::min(1, v)
}
#[inline(always)]
fn u32(v: u32) -> u32 {
u32::min(1, v)
}
#[inline(always)]
fn i64(v: i64) -> i64 {
(v > 0) as i64 - (v < 0) as i64
}
}

View File

@ -42,7 +42,7 @@ pub enum OpCode {
Stop = b'.', Stop = b'.',
NewObj = 0x81, NewObj = 0x81,
EmptyList = b']', EmptyList = b']',
BinFloat = b'G', BinFloat = b'g',
Append = b'a', Append = b'a',
Appends = b'e', Appends = b'e',
} }
@ -217,13 +217,6 @@ impl Object {
let args = args.remove(1); let args = args.remove(1);
(callable, args) (callable, args)
} }
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args), _ => (callable, args),
}; };
match callable { match callable {
@ -234,11 +227,13 @@ impl Object {
_ => return Ok(None), _ => return Ok(None),
}; };
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?; let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo { Ok(Some(TensorInfo {
name, name,
dtype, dtype,
layout, layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path), path: path.to_string_lossy().into_owned(),
storage_size, storage_size,
})) }))
} }
@ -350,10 +345,8 @@ impl Stack {
module_name, module_name,
class_name, class_name,
} => { } => {
if module_name == "collections" if module_name == "collections" && class_name == "OrderedDict" {
&& (class_name == "OrderedDict" || class_name == "defaultdict") // TODO: have a separate ordered dict.
{
// TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![])) Some(Object::Dict(vec![]))
} else { } else {
None None
@ -462,10 +455,7 @@ impl Stack {
self.push(Object::Int(arg)) self.push(Object::Int(arg))
} }
OpCode::BinFloat => { OpCode::BinFloat => {
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian. let arg = r.read_f64::<LittleEndian>()?;
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
let arg = r.read_f64::<byteorder::BigEndian>()?;
self.push(Object::Float(arg)) self.push(Object::Float(arg))
} }
OpCode::BinUnicode => { OpCode::BinUnicode => {
@ -637,16 +627,9 @@ pub struct TensorInfo {
pub storage_size: usize, pub storage_size: usize,
} }
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P, file: P,
verbose: bool, verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> { ) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?; let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file); let zip_reader = std::io::BufReader::new(file);
@ -668,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
stack.read_loop(&mut reader)?; stack.read_loop(&mut reader)?;
let obj = stack.finalize()?; let obj = stack.finalize()?;
if VERBOSE || verbose { if VERBOSE || verbose {
println!("{obj:#?}"); println!("{obj:?}");
} }
let obj = match obj { let obj = match obj {
Object::Build { callable, args } => match *callable { Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable { Object::Reduce { callable, args: _ } => match *callable {
@ -684,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
}, },
obj => obj, obj => obj,
}; };
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj { if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() { for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) { match value.into_tensor_info(name, &dir_name) {
@ -724,8 +688,8 @@ pub struct PthTensors {
} }
impl PthTensors { impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> { pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?; let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos let tensor_infos = tensor_infos
.into_iter() .into_iter()
.map(|ti| (ti.name.to_string(), ti)) .map(|ti| (ti.name.to_string(), ti))
@ -748,12 +712,10 @@ impl PthTensors {
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?); let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?; let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?; let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic // Reading the data is a bit tricky as it can be strided, for now only support the basic
// case and when the tensor is fortran contiguous. // case.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous { if !tensor_info.layout.is_contiguous() {
crate::bail!( crate::bail!(
"cannot retrieve non-contiguous tensors {:?}", "cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout tensor_info.layout
@ -771,33 +733,13 @@ impl PthTensors {
tensor_info.dtype, tensor_info.dtype,
&mut reader, &mut reader,
)?; )?;
Ok(Some(tensor))
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor))
}
} }
} }
/// Read all the tensors from a PyTorch pth file with a given key. /// Read all the tensors from a PyTorch pth file.
/// pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
/// # Arguments let pth = PthTensors::new(path)?;
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys(); let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len()); let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names { for name in tensor_names {
@ -807,11 +749,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
} }
Ok(tensors) Ok(tensors)
} }
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}

View File

@ -1,618 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn ceil_div(p: usize, q: usize) -> usize {
(p + q - 1) / q
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn dequantize(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
shared_mem_bytes: 0,
};
let params = (data, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
shared_mem_bytes: 0,
};
let params = (
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
let vec = slice.to_vec();
T::to_float(&vec, dst)
}
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
}
self.device
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
}
pub fn fwd(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
}
}
}
impl QCudaStorage {
fn dequantize_matmul_vec(
&self,
self_shape: &crate::Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(
&self.data,
&rhs,
self.dtype,
ncols,
nrows,
b_size,
self.device(),
)?
};
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into()))
}
fn dequantize_matmul(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage;
let (n, k) = self_shape.dims2()?;
let (b, m, k2) = match layout.shape().dims() {
&[b, m, k2] => (b, m, k2),
&[m, k2] => (1, m, k2),
s => crate::bail!("unexpected shape for input {s:?}"),
};
if k2 != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ b * m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
Ok((out, out_shape.into()))
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
) -> Result<super::QStorage> {
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let data = device.htod_sync_copy(data).w()?;
Ok(QStorage::Cuda(QCudaStorage {
data,
device: device.clone(),
dtype: T::DTYPE,
}))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
/* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);
let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}
#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
}

View File

@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};
pub struct QCudaStorage {
dtype: GgmlDType,
device: CudaDevice,
}
impl QCudaStorage {
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &CudaStorage,
_layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
Err(Error::NotCompiledWithCudaSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &CudaDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &MetalDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -1,5 +1,7 @@
//! Support for the GGML file format. //! Support for the GGML file format.
#[cfg(feature = "metal")]
use super::metal::load_quantized_metal;
use super::{k_quants, GgmlDType, QStorage}; use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result}; use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device { let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())), Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
Device::Metal(metal) => super::metal::load_quantized(metal, data)?, #[cfg(feature = "metal")]
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?, Device::Metal(metal) => load_quantized_metal(metal, data)?,
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal backend requires `metal` feature")
}
device => unimplemented!("Implement quantized tensor for device {device:?}"),
}; };
super::QTensor::new(data, dims) super::QTensor::new(data, dims)
} }
@ -226,7 +233,6 @@ pub struct Content {
pub hparams: HParams, pub hparams: HParams,
pub vocab: Vocab, pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>, pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
} }
impl Content { impl Content {
@ -246,13 +252,11 @@ impl Content {
let (name, tensor) = read_one_tensor(reader, magic, device)?; let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor); tensors.insert(name, tensor);
} }
let device = device.clone();
Ok(Self { Ok(Self {
magic, magic,
hparams, hparams,
vocab, vocab,
tensors, tensors,
device,
}) })
} }

View File

@ -1,6 +1,5 @@
use super::{GgmlDType, QStorage}; use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage; use crate::{DType, MetalDevice, MetalStorage, Result};
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use metal::Buffer; use metal::Buffer;
use std::sync::Arc; use std::sync::Arc;
@ -11,31 +10,23 @@ pub struct QMetalStorage {
} }
impl QMetalStorage { impl QMetalStorage {
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = device.allocate_zeros(size)?;
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType { pub fn dtype(&self) -> GgmlDType {
self.dtype self.dtype
} }
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer { pub fn buffer(&self) -> &Buffer {
&self.buffer &self.buffer
} }
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> { pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
use crate::quantized::k_quants::GgmlType; Self {
device,
buffer,
dtype,
}
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length())?; let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu"); command_buffer.set_label("to_cpu");
@ -45,73 +36,87 @@ impl QMetalStorage {
blit.end_encoding(); blit.end_encoding();
self.device.wait_until_completed()?; self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count]; let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype { match self.dtype {
GgmlDType::F32 => { GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, block_len); let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
f32::to_float(&vec, &mut out)?; f32::to_float(&vec, &mut out)?;
} }
GgmlDType::F16 => { GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len); let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
half::f16::to_float(&vec, &mut out)?; half::f16::to_float(&vec, &mut out)?;
} }
GgmlDType::Q4_0 => { GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
} }
GgmlDType::Q4_1 => { GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
} }
GgmlDType::Q5_0 => { GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
} }
GgmlDType::Q5_1 => { GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
} }
GgmlDType::Q8_0 => { GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
} }
GgmlDType::Q8_1 => { GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
} }
GgmlDType::Q2K => { GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ2K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q3K => { GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ3K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q4K => { GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ4K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q5K => { GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ5K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q6K => { GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ6K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q8K => { GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ8K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
} }
} }
let buffer = self.device.new_buffer_with_data(&out)?; let buffer = self.device.new_buffer_with_data(&out)?;
Ok(MetalStorage::new( Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
buffer,
self.device.clone(),
elem_count,
DType::F32,
))
} }
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> { pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
@ -125,65 +130,9 @@ impl QMetalStorage {
self.buffer = buffer; self.buffer = buffer;
Ok(()) Ok(())
} }
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
pub fn fwd(
&self,
self_shape: &Shape,
storage: &MetalStorage,
layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}
} }
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice, device: &MetalDevice,
data: &[T], data: &[T],
) -> Result<QStorage> { ) -> Result<QStorage> {
@ -202,24 +151,3 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec() slice.to_vec()
} }
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
}

View File

@ -1,27 +1,16 @@
#[cfg(feature = "metal")]
use crate::{backend::BackendStorage, DType};
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use k_quants::*; use k_quants::*;
use std::borrow::Cow; use std::borrow::Cow;
#[cfg(target_feature = "avx")] #[cfg(target_feature = "avx")]
pub mod avx; pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file; pub mod ggml_file;
pub mod gguf_file; pub mod gguf_file;
pub mod k_quants; pub mod k_quants;
#[cfg(feature = "metal")] #[cfg(feature = "metal")]
pub mod metal; pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")] #[cfg(target_feature = "simd128")]
@ -43,13 +32,22 @@ impl Device {
let storage = dtype.cpu_zeros(elem_count); let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage)) Ok(QStorage::Cpu(storage))
} }
#[cfg(feature = "metal")]
Device::Metal(metal) => { Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?; let size = elem_count * dtype.type_size() / dtype.block_size();
Ok(QStorage::Metal(storage)) let buffer = metal.allocate_zeros(size)?;
Ok(QStorage::Metal(metal::QMetalStorage::new(
buffer,
metal.clone(),
dtype,
)))
} }
Device::Cuda(cuda) => { #[cfg(not(feature = "metal"))]
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?; Device::Metal(_metal) => {
Ok(QStorage::Cuda(storage)) crate::bail!("Metal feature not activated");
}
Device::Cuda(_cuda) => {
crate::bail!("Cuda ggml quantization not supported");
} }
} }
} }
@ -57,40 +55,32 @@ impl Device {
pub enum QStorage { pub enum QStorage {
Cpu(Box<dyn QuantizedType>), Cpu(Box<dyn QuantizedType>),
#[cfg(feature = "metal")]
Metal(metal::QMetalStorage), Metal(metal::QMetalStorage),
Cuda(cuda::QCudaStorage),
} }
impl QStorage { impl QStorage {
fn block_size(&self) -> usize { fn block_size(&self) -> usize {
match self { match self {
QStorage::Cpu(storage) => storage.block_size(), QStorage::Cpu(storage) => storage.block_size(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.dtype().block_size(), QStorage::Metal(storage) => storage.dtype().block_size(),
QStorage::Cuda(storage) => storage.dtype().block_size(),
} }
} }
fn dtype(&self) -> GgmlDType { fn dtype(&self) -> GgmlDType {
match self { match self {
QStorage::Cpu(storage) => storage.dtype(), QStorage::Cpu(storage) => storage.dtype(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.dtype(), QStorage::Metal(storage) => storage.dtype(),
QStorage::Cuda(storage) => storage.dtype(),
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
} }
} }
fn size_in_bytes(&self) -> usize { fn size_in_bytes(&self) -> usize {
match self { match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(), QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(), #[cfg(feature = "metal")]
QStorage::Cuda(storage) => storage.storage_size_in_bytes(), QStorage::Metal(storage) => storage.buffer().length() as usize,
} }
} }
@ -99,8 +89,8 @@ impl QStorage {
(QStorage::Cpu(storage), Storage::Cpu(src)) => { (QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?; storage.from_float(src.as_slice::<f32>()?)?;
} }
#[cfg(feature = "metal")]
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"), _ => crate::bail!("Invalid dequantize storage locations do not match"),
} }
Ok(()) Ok(())
@ -109,8 +99,8 @@ impl QStorage {
fn dequantize(&self, elem_count: usize) -> Result<Storage> { fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self { match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
} }
} }
@ -122,7 +112,8 @@ impl QStorage {
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data)) Ok(Cow::from(data))
} }
QStorage::Metal(_) | QStorage::Cuda(_) => { #[cfg(feature = "metal")]
QStorage::Metal(_storage) => {
crate::bail!("not implemented"); crate::bail!("not implemented");
} }
} }
@ -345,10 +336,6 @@ impl QTensor {
self.storage.dtype() self.storage.dtype()
} }
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
self.shape.rank() self.shape.rank()
} }
@ -398,7 +385,7 @@ impl QMatMul {
_ => DEQUANTIZE_ALL.with(|b| *b), _ => DEQUANTIZE_ALL.with(|b| *b),
}; };
let t = if dequantize { let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?; let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor) Self::Tensor(tensor)
} else { } else {
Self::QTensor(qtensor) Self::QTensor(qtensor)
@ -440,7 +427,8 @@ impl crate::CustomOp1 for QTensor {
#[allow(clippy::infallible_destructuring_match)] #[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage { let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage, QStorage::Cpu(storage) => storage,
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"), #[cfg(feature = "metal")]
_ => crate::bail!("Invalid storage"),
}; };
let slice = storage.as_slice::<f32>()?; let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@ -449,28 +437,79 @@ impl crate::CustomOp1 for QTensor {
Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
} }
#[cfg(feature = "metal")]
fn metal_fwd( fn metal_fwd(
&self, &self,
storage: &crate::MetalStorage, storage: &crate::MetalStorage,
layout: &crate::Layout, layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> { ) -> Result<(crate::MetalStorage, Shape)> {
let self_storage = match &self.storage { use crate::MetalError;
QStorage::Metal(metal) => metal,
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self.shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let (buffer, dtype) = match &self.storage {
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
_ => unreachable!("Cannot call metal matmul on non metal QTensor"), _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
}; };
self_storage.fwd(&self.shape, storage, layout) let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
Ok((dst_storage, dst_shape))
} }
}
fn cuda_fwd( #[cfg(feature = "metal")]
&self, impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
storage: &crate::CudaStorage, fn from(value: GgmlDType) -> Self {
layout: &crate::Layout, match value {
) -> Result<(crate::CudaStorage, Shape)> { GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
let self_storage = match &self.storage { GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
QStorage::Cuda(cuda) => cuda, GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"), GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
}; GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
self_storage.fwd(&self.shape, storage, layout) GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
} }
} }

View File

@ -171,7 +171,7 @@ impl Shape {
} }
let mut acc = 1; let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() { for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
if dim > 1 && stride != acc { if stride != acc {
return false; return false;
} }
acc *= dim; acc *= dim;
@ -186,7 +186,7 @@ impl Shape {
} }
let mut acc = 1; let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()) { for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
if dim > 1 && stride != acc { if stride != acc {
return false; return false;
} }
acc *= dim; acc *= dim;

View File

@ -1,7 +1,6 @@
use crate::backend::BackendStorage; use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp}; use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape}; use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
// We do not want to implement Clone on Storage as cloning may fail because of // We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used. // out of memory. Instead try_clone should be used.
@ -44,19 +43,9 @@ impl Storage {
} }
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs_device = self.device(); let lhs = self.device().location();
let rhs_device = rhs.device(); let rhs = rhs.device().location();
let lhs = lhs_device.location(); if lhs != rhs {
let rhs = rhs_device.location();
let same_device = if self.device().is_metal() {
// On metal, we require the device to be exactly the same rather than
// having the same location. In cuda this is not necessary as all CudaDevice on the
// same GPU will use the same cuda stream.
lhs_device.same_device(&rhs_device)
} else {
lhs == rhs
};
if !same_device {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt()) Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
} else { } else {
Ok(()) Ok(())
@ -263,51 +252,6 @@ impl Storage {
} }
} }
pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
match self {
Self::Cpu(storage) => c.cpu_fwd(storage, l),
Self::Cuda(storage) => c.cuda_fwd(storage, l),
Self::Metal(storage) => c.metal_fwd(storage, l),
}
}
pub(crate) fn inplace_op2(
&mut self,
l1: &Layout,
t2: &Self,
l2: &Layout,
c: &dyn InplaceOp2,
) -> Result<()> {
self.same_device(t2, c.name())?;
match (self, t2) {
(Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
(Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
(Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
_ => unreachable!(),
}
}
pub(crate) fn inplace_op3(
&mut self,
l1: &Layout,
t2: &Self,
l2: &Layout,
t3: &Self,
l3: &Layout,
c: &dyn InplaceOp3,
) -> Result<()> {
self.same_device(t2, c.name())?;
self.same_device(t3, c.name())?;
match (self, t2, t3) {
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
c.metal_fwd(s1, l1, s2, l2, s3, l3)
}
_ => unreachable!(),
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> { pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
match self { match self {
Storage::Cpu(storage) => { Storage::Cpu(storage) => {
@ -408,10 +352,6 @@ impl Storage {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?; let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s)) Ok(Self::Cuda(s))
} }
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -757,32 +697,4 @@ impl Storage {
.bt()), .bt()),
} }
} }
#[allow(clippy::too_many_arguments)]
pub(crate) fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
(Self::Cuda(src), Self::Cuda(dst)) => {
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
}
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "copy2d",
}
.bt()),
}
}
} }

View File

@ -1,7 +1,9 @@
//! Tensors are N-dimensional matrixes of elements using a single data type. //! Tensors are N-dimensional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)] #![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage}; use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp}; use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::scalar::TensorOrScalar; use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims}; use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
@ -79,9 +81,6 @@ macro_rules! unary_op {
($fn_name:ident, $op_name:ident) => { ($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self) -> Result<Self> { pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape(); let shape = self.shape();
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self let storage = self
.storage() .storage()
.unary_impl::<crate::op::$op_name>(self.layout())?; .unary_impl::<crate::op::$op_name>(self.layout())?;
@ -95,9 +94,6 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => { ($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> { pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?; let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>( let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(), &*rhs.storage(),
self.layout(), self.layout(),
@ -120,9 +116,6 @@ macro_rules! binary_op_scalar {
.broadcast_as(self.shape())?, .broadcast_as(self.shape())?,
}; };
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?; let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>( let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(), &*rhs.storage(),
self.layout(), self.layout(),
@ -515,11 +508,9 @@ impl Tensor {
unary_op!(gelu_erf, GeluErf); unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf); unary_op!(erf, Erf);
unary_op!(relu, Relu); unary_op!(relu, Relu);
unary_op!(silu, Silu);
unary_op!(ceil, Ceil); unary_op!(ceil, Ceil);
unary_op!(floor, Floor); unary_op!(floor, Floor);
unary_op!(round, Round); unary_op!(round, Round);
unary_op!(sign, Sign);
/// Round element of the input tensor to the nearest integer. /// Round element of the input tensor to the nearest integer.
/// ///
@ -655,9 +646,6 @@ impl Tensor {
/// # Ok::<(), candle_core::Error>(()) /// # Ok::<(), candle_core::Error>(())
/// ``` /// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> { pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().affine(self.layout(), mul, add)?; let storage = self.storage().affine(self.layout(), mul, add)?;
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add }); let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
Ok(from_storage(storage, self.shape(), op, false)) Ok(from_storage(storage, self.shape(), op, false))
@ -665,9 +653,6 @@ impl Tensor {
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor. /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> { pub fn elu(&self, alpha: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().elu(self.layout(), alpha)?; let storage = self.storage().elu(self.layout(), alpha)?;
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha)); let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
Ok(from_storage(storage, self.shape(), op, false)) Ok(from_storage(storage, self.shape(), op, false))
@ -675,15 +660,12 @@ impl Tensor {
/// Raise the tensor to some float exponent `e`. /// Raise the tensor to some float exponent `e`.
pub fn powf(&self, e: f64) -> Result<Self> { pub fn powf(&self, e: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().powf(self.layout(), e)?; let storage = self.storage().powf(self.layout(), e)?;
let op = BackpropOp::new1(self, |t| Op::Powf(t, e)); let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
Ok(from_storage(storage, self.shape(), op, false)) Ok(from_storage(storage, self.shape(), op, false))
} }
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
if dim >= self.dims().len() { if dim >= self.dims().len() {
Err(Error::DimOutOfRange { Err(Error::DimOutOfRange {
shape: self.shape().clone(), shape: self.shape().clone(),
@ -822,35 +804,6 @@ impl Tensor {
} }
} }
/// Roll the tensor input along the given dimension.
/// Elements that are shifted beyond the last position are re-introduced at the first position.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(-1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
where
D: Dim + Clone,
{
let dim = dim.to_index(self.shape(), "roll")?;
let dim_size = self.dim(dim)?;
let shift = shift.rem_euclid(dim_size as i32) as usize;
if shift == 0 {
Ok(self.clone())
} else {
let a = self.narrow(dim, 0, dim_size - shift)?;
let b = self.narrow(dim, dim_size - shift, shift)?;
Tensor::cat(&[&b, &a], dim)
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the /// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions. /// input dimensions.
/// ///
@ -1032,7 +985,7 @@ impl Tensor {
/// tensor also has three dimensions, `(batch, channels, target_size)`. /// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> { pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?; let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size }); let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self let storage = self
.storage() .storage()
.upsample_nearest1d(self.layout(), target_size)?; .upsample_nearest1d(self.layout(), target_size)?;
@ -1172,9 +1125,6 @@ impl Tensor {
let n = b_dims[dim - 1]; let n = b_dims[dim - 1];
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
if c_shape.elem_count() == 0 || k == 0 {
return Tensor::zeros(c_shape, self.dtype(), self.device());
}
let batching: usize = a_dims[..dim - 2].iter().product(); let batching: usize = a_dims[..dim - 2].iter().product();
let batching_b: usize = b_dims[..dim - 2].iter().product(); let batching_b: usize = b_dims[..dim - 2].iter().product();
if k != k2 || batching != batching_b { if k != k2 || batching != batching_b {
@ -1371,7 +1321,7 @@ impl Tensor {
} }
.bt())? .bt())?
} }
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage() self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?; .copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>(); let offset = start * src.dims()[1..].iter().product::<usize>();
@ -1903,9 +1853,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor. /// this new node. The storage of this tensor is shared with the initial tensor.
/// ///
/// If the tensor is already detached from the computation graph, the same tensor is returned. /// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Tensor { pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable { if self.op.is_none() && !self.is_variable {
self.clone() Ok(self.clone())
} else { } else {
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
@ -1916,7 +1866,7 @@ impl Tensor {
dtype: self.dtype, dtype: self.dtype,
device: self.device.clone(), device: self.device.clone(),
}; };
Tensor(Arc::new(tensor_)) Ok(Tensor(Arc::new(tensor_)))
} }
} }
@ -2021,7 +1971,7 @@ impl Tensor {
Ok(self.clone()) Ok(self.clone())
} else { } else {
let shape = self.shape(); let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage() self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?; .copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy); let op = BackpropOp::new1(self, Op::Copy);
@ -2029,21 +1979,11 @@ impl Tensor {
} }
} }
/// Returns a tensor that is in row major order. This always makes a copy.
pub fn force_contiguous(&self) -> Result<Tensor> {
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
Ok(from_storage(storage, shape.clone(), op, false))
}
/// Create a variable based on the values currently stored in a tensor. The storage is always /// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied. /// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> { pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone(); let shape = self.shape().clone();
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? }; let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage() self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?; .copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true)) Ok(from_storage(storage, shape, BackpropOp::none(), true))
@ -2096,7 +2036,7 @@ impl Tensor {
}; };
Ok(Tensor(Arc::new(tensor_))) Ok(Tensor(Arc::new(tensor_)))
} else { } else {
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? }; let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage() self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?; .copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false)) Ok(from_storage(storage, shape, op, false))
@ -2123,19 +2063,8 @@ impl Tensor {
let dim = dim.to_index(self.shape(), "squeeze")?; let dim = dim.to_index(self.shape(), "squeeze")?;
if dims[dim] == 1 { if dims[dim] == 1 {
let mut dims = dims.to_vec(); let mut dims = dims.to_vec();
let mut strides = self.stride().to_vec();
dims.remove(dim); dims.remove(dim);
strides.remove(dim); self.reshape(dims)
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
} else { } else {
Ok(self.clone()) Ok(self.clone())
} }
@ -2156,24 +2085,10 @@ impl Tensor {
/// ``` /// ```
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> { pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
let mut dims = self.dims().to_vec(); let mut dims = self.dims().to_vec();
let mut strides = self.stride().to_vec();
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?; let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
// Cannot panic because to_index_plus_one already checks dimensions // Cannot panic because to_index_plus_one already checks dimensions
dims.insert(dim, 1); dims.insert(dim, 1);
// Any stride would work here, but we pick one so as to maximize the probability to remain self.reshape(dims)
// C contiguous.
let stride = if dim < strides.len() { strides[dim] } else { 1 };
strides.insert(dim, stride);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
} }
/// Stacks two or more tensors along a particular dimension. /// Stacks two or more tensors along a particular dimension.
@ -2204,6 +2119,152 @@ impl Tensor {
Self::cat(&args, dim) Self::cat(&args, dim)
} }
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {
// TODO: Avoid these transpositions and have an implementation that works
// for dim != 0...
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[0] = 0;
let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == 0 {
cat_dims[0] += v2;
}
if dim_idx != 0 && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(from_storage(storage, shape, op, false))
}
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after. /// input tensor values and `right` elements after.
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> { pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
@ -2286,10 +2347,6 @@ impl Tensor {
self.storage.read().unwrap() self.storage.read().unwrap()
} }
pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
self.storage.write().unwrap()
}
// If we extend the visibility of this function to be usable outside of this crate, we should // If we extend the visibility of this function to be usable outside of this crate, we should
// make it unsafe. // make it unsafe.
pub(crate) fn storage_mut_and_layout( pub(crate) fn storage_mut_and_layout(
@ -2311,6 +2368,96 @@ impl Tensor {
std::ptr::eq(lhs, rhs) std::ptr::eq(lhs, rhs)
} }
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative /// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back. /// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> { pub fn normalize_axis(&self, axis: i64) -> Result<usize> {

View File

@ -1,238 +0,0 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else if dim == 0 {
Self::cat0(args)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[0] = 0;
let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == 0 {
cat_dims[0] += v2;
}
if dim_idx != 0 && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[dim] = 0;
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == dim {
cat_dims[dim] += v2;
}
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
let cat_target_dim_len = cat_dims[dim];
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();
let arg_dims = arg.shape().dims();
let d1: usize = arg_dims.iter().take(dim).product();
let d2 = block_size * arg_dims[dim];
let dst_s = block_size * cat_target_dim_len;
let src_o = arg.layout().start_offset();
arg.storage().copy2d(
&mut storage,
d1,
d2,
/* src_s */ d2,
dst_s,
src_o,
dst_o,
)?;
dst_o += d2;
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
}

View File

@ -107,10 +107,6 @@ impl Var {
Ok(Self(inner)) Ok(Self(inner))
} }
pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}
pub fn as_tensor(&self) -> &Tensor { pub fn as_tensor(&self) -> &Tensor {
&self.0 &self.0
} }

View File

@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t) res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape) print(res.shape)
print(res) print(res)
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
print(res.shape)
print(res)
*/ */
fn conv1d(dev: &Device) -> Result<()> { fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new( let t = Tensor::new(
@ -53,11 +50,8 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
); );
if dev.is_cpu() {
let w = w.transpose(0, 1)?; let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
for w in [w.clone(), w.contiguous()?] {
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]); assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -66,17 +60,6 @@ fn conv1d(dev: &Device) -> Result<()> {
4.7076, -5.9745, -0.8276, 1.621 4.7076, -5.9745, -0.8276, 1.621
], ],
); );
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
assert_eq!(res.dims(), [1, 4, 7]);
assert_eq!(
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
[
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
]
);
} }
Ok(()) Ok(())
} }
@ -135,7 +118,7 @@ fn conv2d(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
], ],
dev, dev,
)?; )?;
@ -163,9 +146,7 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
] ]
); );
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]); assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!( assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?, test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -190,7 +171,6 @@ fn conv2d(dev: &Device) -> Result<()> {
] ]
] ]
); );
// Dilations. // Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?; let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]); assert_eq!(res.dims(), [1, 2, 1, 1]);
@ -229,7 +209,6 @@ fn conv2d(dev: &Device) -> Result<()> {
] ]
] ]
); );
Ok(()) Ok(())
} }
@ -276,13 +255,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[ [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
0.0, 0.0, 0.0, 0.0 -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
] ]
); );
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!( assert_eq!(
@ -384,7 +363,6 @@ print(w.grad.shape)
print(w.grad[0]) print(w.grad[0])
*/ */
fn conv2d_grad(dev: &Device) -> Result<()> { fn conv2d_grad(dev: &Device) -> Result<()> {
// conv-transposes are not implemented for metal
use candle_core::Var; use candle_core::Var;
let t = Var::from_slice( let t = Var::from_slice(
&[ &[
@ -397,7 +375,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
], ],
(1, 4, 5, 5), (1, 4, 5, 5),
dev, dev,
@ -582,154 +560,6 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
] ]
); );
// Conv Transpose 2d Test
//tested against following python
// import torch
// torch.manual_seed(4242)
// padding = 4
// outpadding = 2
// dilation = 3
// stride = 3
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
// print("input", input.flatten())
// print("kernel", kernel.flatten())
// res = torch.nn.functional.conv_transpose2d(
// input,
// kernel,
// stride=stride,
// padding=padding,
// dilation=dilation,
// output_padding=outpadding,
// )
// res.retain_grad()
// print(res.shape)
// loss = (res**2).sum()
// print(loss)
// loss.backward()
// print(input.grad.shape)
// print("input grad", torch.round(input.grad, decimals=1))
// print(kernel.grad.shape)
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
let padding = 4;
let outpadding = 2;
let dilation = 3;
let stride = 3;
let t = Var::from_slice(
&[
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
1.6475, 0.2219,
],
(1, 4, 7, 5),
dev,
)?;
#[rustfmt::skip]
let w = Var::from_slice(
&[
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
],
(4, 2, 3, 5),
dev,
)?;
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
[
// torch gets 89.1
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
[
[
[32.3, -41.6, -24.0, 14.1, 17.6],
[-11.8, 72.5, 87.6, 46.4, 61.5],
[115.0, 108.5, -48.6, -63.4, -50.0],
[51.3, 5.4, 31.3, 91.1, -30.9],
[52.7, 92.8, -68.0, -47.0, 83.0],
// pytorch gets -107.1
[-10.2, -107.0, -5.4, 213.1, -31.4],
[-2.4, 65.1, 9.2, -146.2, -24.2]
],
[
[-72.6, -63.9, -61.9, 45.3, 33.0],
[79.3, -0.5, -26.2, 78.2, 42.7],
[90.9, 141.6, 40.1, -62.7, 37.0],
[32.8, 198.2, -0.8, -31.1, 27.3],
// torch gets 48.0
[34.5, 34.9, -47.9, 127.6, -12.3],
[-61.4, -3.2, -2.9, -10.9, -16.6],
[74.6, 60.1, -68.9, 34.5, -50.4]
],
[
[37.5, -56.9, -43.6, -13.5, -9.9],
[40.0, 97.3, 28.6, 14.2, -30.1],
[-22.3, -126.3, -68.8, -8.2, 26.1],
[-32.9, 37.3, 108.5, -54.8, 29.6],
[34.9, -176.9, -125.0, -28.3, -13.9],
[-54.9, 142.6, 62.1, -80.4, -65.6],
[7.4, -91.1, -67.6, 35.0, 39.7]
],
[
[-57.2, -40.9, -10.1, 32.6, 29.4],
[18.7, -18.0, 29.5, -1.2, 59.2],
[-14.0, -74.4, 19.8, -117.0, 58.2],
[-21.8, 163.5, -71.1, -99.0, 80.9],
[-58.9, -10.9, 93.8, -139.6, 98.0],
// torch gets 54.5
[-54.4, 135.3, 6.0, -79.1, 134.6],
[27.5, -76.0, 43.4, -2.8, -7.8]
]
]
);
Ok(()) Ok(())
} }

View File

@ -112,34 +112,3 @@ fn custom_op1_with_backward() -> Result<()> {
Ok(()) Ok(())
} }
impl candle_core::InplaceOp1 for Elu {
fn name(&self) -> &'static str {
"elu"
}
fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
let alpha = self.alpha;
match s {
CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
_ => candle_core::bail!("unsupported dtype for inplace elu"),
}
Ok(())
}
}
#[test]
fn inplace_op1() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
t.inplace_op1(&Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
);
Ok(())
}

View File

@ -1,4 +1,3 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var}; use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
@ -97,24 +96,24 @@ fn unary_grad(device: &Device) -> Result<()> {
let grads = y.backward()?; let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?; let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&y, 4)?, y.to_vec1::<f32>()?,
[20.0855, 2.7183, 54.5982, 1.1618] [20.085537, 2.7182817, 54.59815, 1.1618342]
); );
assert_eq!( assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?, grad_x.to_vec1::<f32>()?,
[20.0855, 2.7183, 54.5982, 1.1618] [20.085537, 2.7182817, 54.59815, 1.1618342]
); );
let y = x.exp()?.sqr()?; let y = x.exp()?.sqr()?;
let grads = y.backward()?; let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?; let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&y, 3)?, y.to_vec1::<f32>()?,
[403.429, 7.389, 2980.958, 1.35] [403.4288, 7.3890557, 2980.9578, 1.3498588]
); );
// exp(x)^2 = exp(2*x) // exp(x)^2 = exp(2*x)
assert_eq!( assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?, grad_x.to_vec1::<f32>()?,
[806.86, 14.78, 5961.92, 2.7] [806.8576, 14.778111, 5961.9155, 2.6997175]
); );
let y = x.sin()?; let y = x.sin()?;
let grads = y.backward()?; let grads = y.backward()?;
@ -262,7 +261,6 @@ fn unary_grad(device: &Device) -> Result<()> {
let y = elu_x.elu(2.)?; let y = elu_x.elu(2.)?;
let grads = y.backward()?; let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?; let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&y, 4)?, test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000] [-1.2642, 0.0000, -1.7293, 3.0000]
@ -272,51 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000] [0.7358, 2.0000, 0.2707, 1.0000]
); );
// testing compared to pytorch nn.Silu()
let y = x.silu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.8577, 0.7311, 3.9281, 0.0806]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0881, 0.9277, 1.0527, 0.5747],
);
if device.is_cpu() {
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
let y = x.interpolate1d(12)?.reshape(36)?;
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
33., 34., 35., 36.,
],
device,
)?;
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(grad_x, 4)?,
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
);
}
// manually checked: see comments // manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?; let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?; let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new( let z = Tensor::new(
&[ &[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., 1_f32, 02., 03., 04., 05., 06.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 07., 08., 09., 10., 11., 12.,
35., 36., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
], ],
device, device,
)?; )?;
@ -347,11 +313,15 @@ fn unary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?; let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?; let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new( let z = Tensor::new(
&[ &[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., 1_f32, 02., 03., 04., 05., 06.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 07., 08., 09., 10., 11., 12.,
35., 36., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
], ],
device, device,
)?; )?;

View File

@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
} }
}; };
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?.contiguous()?; let tensor = tensor.i((.., 1))?;
match tensor.strided_blocks() { match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { start_offset, len } => { candle::StridedBlocks::SingleBlock { start_offset, len } => {
assert_eq!(start_offset, 0); assert_eq!(start_offset, 0);
@ -100,20 +100,6 @@ fn strided_blocks() -> Result<()> {
} }
}; };
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")
}
candle::StridedBlocks::MultipleBlocks {
block_len,
block_start_index,
} => {
assert_eq!(block_len, 4);
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
match tensor.t()?.strided_blocks() { match tensor.t()?.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => { candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure") panic!("unexpected block structure")

View File

@ -1,106 +0,0 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);
// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
let x = a.i((.., seq_len - 1, ..))?;
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
let x = x.matmul(&w)?;
assert_eq!(x.dims(), &[1, 32]);
Ok(())
}
// https://github.com/huggingface/candle/issues/1992
fn mm_layout(device: &Device) -> Result<()> {
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
let mm1 = a.matmul(&b)?;
// Forces the layout to be:
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
// This is still a contiguous matrix but matmul checks are only the two last dimensions have
// non 1 sizes but matmul check may be reluctant to handle it.
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
let mm2 = a.matmul(&b)?;
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);

View File

@ -43,9 +43,6 @@ res = torch.nn.functional.avg_pool2d(t, 2)
print(res) print(res)
*/ */
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let t = Tensor::new( let t = Tensor::new(
&[ &[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,

View File

@ -1,37 +0,0 @@
import torch
from collections import OrderedDict
# Write a trivial tensor to a pt file
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
# Write a trivial tensor to a pt file
torch.save(o, "test.pt")
############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")
############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)
# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
print("3D Tensor saved with Fortran layout.")

View File

@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_fortran_congiguous() {
let tensors =
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
assert_eq!(
tensor.to_vec3::<i64>().unwrap(),
[
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
]
);
}

View File

@ -3,7 +3,7 @@ use candle_core::{
quantized::{self, GgmlDType}, quantized::{self, GgmlDType},
test_device, test_device,
test_utils::to_vec2_round, test_utils::to_vec2_round,
Device, IndexOp, Module, Result, Tensor, Device, Module, Result, Tensor,
}; };
use quantized::{k_quants, GgmlType}; use quantized::{k_quants, GgmlType};
use rand::prelude::*; use rand::prelude::*;
@ -47,14 +47,18 @@ fn test_matmul(
} }
fn quantized_matmul(device: &Device) -> Result<()> { fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4); let (m, k, n) = (3, 64, 4);
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>(); let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?; let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let mut dst = vec![42.; 3 * 4]; let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>(); let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!( assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(), dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[ &[
@ -63,7 +67,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
] ]
); );
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = lhs.matmul(&tensor_rhs)?; let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!( assert_eq!(
mm.to_vec2::<f32>()?, mm.to_vec2::<f32>()?,
&[ &[
@ -75,7 +79,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&lhs)?; let res = matmul.forward(&tensor_lhs)?;
match device { match device {
Device::Metal(_) => assert_eq!( Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
@ -85,15 +89,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
[341970.0, 994574.0, 1656181.0, 2302182.0] [341970.0, 994574.0, 1656181.0, 2302182.0]
] ]
), ),
Device::Cuda(_) => assert_eq!( _ => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84866.0, 214045.0, 344676.0, 473707.0],
[213425.0, 604313.0, 1000431.0, 1387960.0],
[342030.0, 994630.0, 1656248.0, 2302250.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
&[ &[
[85120.0, 214562.0, 345455.0, 474748.0], [85120.0, 214562.0, 345455.0, 474748.0],
@ -102,16 +98,22 @@ fn quantized_matmul(device: &Device) -> Result<()> {
] ]
), ),
} }
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?; test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(()) Ok(())
} }
fn quantized_matmul_neg(device: &Device) -> Result<()> { fn quantized_matmul_neg(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4); let (m, k, n) = (3, 64, 4);
let lhs_s = (0..(m * k)) let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0) .map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?; let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let mut dst = vec![42.; 3 * 4]; let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n) let rhs = (0..k * n)
@ -119,7 +121,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!( assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(), dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[ &[
@ -127,7 +129,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
-196472.0, 63012.0, 324585.0, 587902.0 -196472.0, 63012.0, 324585.0, 587902.0
] ]
); );
let mm = lhs.matmul(&tensor_rhs)?; let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!( assert_eq!(
to_vec2_round(&mm, 0)?, to_vec2_round(&mm, 0)?,
&[ &[
@ -139,7 +141,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&lhs)?; let res = matmul.forward(&tensor_lhs)?;
match device { match device {
Device::Metal(_) => assert_eq!( Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
@ -149,15 +151,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
[-196102.0, 63022.0, 324233.0, 587191.0] [-196102.0, 63022.0, 324233.0, 587191.0]
] ]
), ),
Device::Cuda(_) => assert_eq!( _ => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243740.0, -19762.0, -285476.0, -550498.0],
[23774.0, 21645.0, 19395.0, 18364.0],
[-196045.0, 63030.0, 324120.0, 587079.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
&[ &[
[243524.0, -19596.0, -285051.0, -549815.0], [243524.0, -19596.0, -285051.0, -549815.0],
@ -166,60 +160,28 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
] ]
), ),
} }
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
}
Ok(()) Ok(())
} }
fn qmm_batch(dev: &Device) -> Result<()> { test_device!(
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?; quantized_matmul,
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?; quantized_matmul_cpu,
let rhs = quantized::QMatMul::from_qtensor(rhs)?; quantized_matmul_cuda,
let mm = rhs.forward(&lhs)?; quantized_matmul_metal
assert_eq!(mm.shape().dims(), [2, 6]); );
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?; test_device!(
let mm2 = rhs.forward(&lhs2)?; quantized_matmul_neg,
assert_eq!(mm2.shape().dims(), [4, 6]); quantized_matmul_neg_cpu,
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?; quantized_matmul_neg_cuda,
assert_eq!(diff2, 0.0); quantized_matmul_neg_metal
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?; );
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
} else {
assert_eq!(diff4, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
fn quantize_q4_0(device: &Device) -> Result<()> { fn quantize_q4_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
@ -247,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
} }
fn quantize_q4_1(device: &Device) -> Result<()> { fn quantize_q4_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
@ -273,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
} }
fn quantize_q5_0(device: &Device) -> Result<()> { fn quantize_q5_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@ -299,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
} }
fn quantize_q5_1(device: &Device) -> Result<()> { fn quantize_q5_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
@ -399,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
} }
fn quantize_q2k(device: &Device) -> Result<()> { fn quantize_q2k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q2K; let dtype = GgmlDType::Q2K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
@ -433,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
} }
fn quantize_q3k(device: &Device) -> Result<()> { fn quantize_q3k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q3K; let dtype = GgmlDType::Q3K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -466,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
} }
fn quantize_q4k(device: &Device) -> Result<()> { fn quantize_q4k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q4K; let dtype = GgmlDType::Q4K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -499,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
} }
fn quantize_q5k(device: &Device) -> Result<()> { fn quantize_q5k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q5K; let dtype = GgmlDType::Q5K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -532,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
} }
fn quantize_q6k(device: &Device) -> Result<()> { fn quantize_q6k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q6K; let dtype = GgmlDType::Q6K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -565,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
} }
fn quantize_q8k(device: &Device) -> Result<()> { fn quantize_q8k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q8K; let dtype = GgmlDType::Q8K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -780,6 +778,10 @@ macro_rules! quantized_matmul {
// stable. https://github.com/rust-lang/rust/issues/29599 // stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => { ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> { fn $fn_name(device: &Device) -> Result<()> {
if device.is_cuda() {
// TODO Enable Cuda GGML sometime maybe.
return Ok(());
}
test_matmul(device, (1, 3, 4, 256), $dtype)?; test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(()) Ok(())
} }

View File

@ -106,9 +106,6 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933] [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
] ]
); );
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
assert_eq!( assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?, test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[ [
@ -123,13 +120,6 @@ fn unary_op(device: &Device) -> Result<()> {
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999] [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
] ]
); );
assert_eq!(
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
[
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
]
);
assert_eq!( assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?, test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]] [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
@ -151,14 +141,6 @@ fn unary_op(device: &Device) -> Result<()> {
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?, test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.] [3000.0, 300.]
); );
let tensor = Tensor::new(
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
device,
)?;
assert_eq!(
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
Ok(()) Ok(())
} }
@ -683,31 +665,6 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0] [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
] ]
); );
// 3D
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
let t1 = t1.t()?.contiguous()?.t()?;
let t2 = t2.t()?.contiguous()?.t()?;
let t3 = t3.t()?.contiguous()?.t()?;
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
Ok(()) Ok(())
} }
@ -718,8 +675,6 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?; let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(()) Ok(())
} }
@ -747,47 +702,44 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0] [9.0, 10.0, 11.0]
] ]
); );
for dtype in [DType::U8, DType::U32, DType::I64] { let hs = t.index_select(&ids, 1)?;
let ids = ids.to_dtype(dtype)?; assert_eq!(
let hs = t.index_select(&ids, 1)?; hs.to_vec2::<f32>()?,
assert_eq!( &[
hs.to_vec2::<f32>()?, [0.0, 2.0, 1.0],
&[ [3.0, 5.0, 4.0],
[0.0, 2.0, 1.0], [6.0, 8.0, 7.0],
[3.0, 5.0, 4.0], [9.0, 11.0, 10.0]
[6.0, 8.0, 7.0], ]
[9.0, 11.0, 10.0] );
] let hs = t.index_select(&ids, 0)?;
); assert_eq!(
let hs = t.index_select(&ids, 0)?; hs.to_vec2::<f32>()?,
assert_eq!( &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
hs.to_vec2::<f32>()?, );
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] // Prior to https://github.com/huggingface/candle/pull/1022
); // There would be a bug where the last values in the result tensor would be set to 0.
// Prior to https://github.com/huggingface/candle/pull/1022 let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
// There would be a bug where the last values in the result tensor would be set to 0. let hs = t.index_select(&ids, 0)?;
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; assert_eq!(
let hs = t.index_select(&ids, 0)?; hs.to_vec2::<f32>()?,
assert_eq!( &[
hs.to_vec2::<f32>()?, [0.0, 1.0, 2.0],
&[ [6.0, 7.0, 8.0],
[0.0, 1.0, 2.0], [3.0, 4.0, 5.0],
[6.0, 7.0, 8.0], [0.0, 1.0, 2.0],
[3.0, 4.0, 5.0], [6.0, 7.0, 8.0],
[0.0, 1.0, 2.0], [3.0, 4.0, 5.0],
[6.0, 7.0, 8.0], ]
[3.0, 4.0, 5.0], );
]
);
// Test when selecting dim > 0 with ids size different from elem count of // Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input. // target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?; let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
}
Ok(()) Ok(())
} }
@ -949,6 +901,74 @@ fn gather(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);
// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> { fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?; let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?; let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1053,54 +1073,8 @@ fn broadcasting(device: &Device) -> Result<()> {
fn randn(device: &Device) -> Result<()> { fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]); assert_eq!(tensor.dims(), [5, 3]);
// Check that the seed gets updated by checking that
// a new series of numbers is generated each time
let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]); assert_eq!(tensor.dims(), [5, 3]);
// Check that the seed gets updated by checking that
// a new series of numbers is generated each time
let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
// We do not expect deterministic elements at any index.
// There once was a bug that had a deterministic zero element in evenly sized tensors.
const N: usize = 2;
let v = (0..100)
.map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
.collect::<Result<Vec<_>>>()?;
assert!(
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
"There are deterministic values in the randn tensors"
);
let v = (0..100)
.map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
.collect::<Result<Vec<_>>>()?;
assert!(
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
"There are deterministic values in the rand tensors"
);
Ok(())
}
fn zero_dim(device: &Device) -> Result<()> {
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
assert_eq!(t.dims3()?, (4, 0, 1));
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
assert_eq!(t_cat.dims3()?, (4, 3, 1));
let t_cat = Tensor::cat(&[&t, &t], 1)?;
assert_eq!(t_cat.dims3()?, (4, 0, 1));
let t_unary = t.sqrt()?;
assert_eq!(t_unary.dims3()?, (4, 0, 1));
let t_plus = (&t + 1.)?;
assert_eq!(t_plus.dims3()?, (4, 0, 1));
let t_mm = t2.matmul(&t.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 3, 0));
let t_mm = t.matmul(&t2.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 0, 3));
let t_mm = t.t()?.matmul(&t)?;
assert_eq!(t_mm.dims3()?, (4, 1, 1));
Ok(()) Ok(())
} }
@ -1123,6 +1097,13 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal); test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal); test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal); test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!( test_device!(
broadcasting, broadcasting,
broadcasting_cpu, broadcasting_cpu,
@ -1152,7 +1133,6 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal); test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal); test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal); test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
// There was originally a bug on the CPU implementation for randn // There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381 // https://github.com/huggingface/candle/issues/381
@ -1280,8 +1260,8 @@ fn pow() -> Result<()> {
let rhs = (&lhs - 2.)?; let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?; let res = lhs.pow(&rhs)?;
assert_eq!( assert_eq!(
test_utils::to_vec2_round(&res, 3)?, test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]] [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
); );
Ok(()) Ok(())
} }

Binary file not shown.

Binary file not shown.

View File

@ -12,7 +12,7 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { workspace = true } candle = { workspace = true }
candle-datasets = { workspace = true, optional = true } candle-datasets = { workspace = true }
candle-nn = { workspace = true } candle-nn = { workspace = true }
candle-transformers = { workspace = true } candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true } candle-flash-attn = { workspace = true, optional = true }
@ -21,19 +21,16 @@ candle-onnx = { workspace = true, optional = true }
csv = "1.3.0" csv = "1.3.0"
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
hf-hub = { workspace = true, features = ["tokio"] } hf-hub = { workspace = true, features=["tokio"]}
image = { workspace = true } image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true } num-traits = { workspace = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true } rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] } tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -42,10 +39,11 @@ clap = { workspace = true }
imageproc = { workspace = true } imageproc = { workspace = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
ab_glyph = { workspace = true } rusttype = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-chrome = { workspace = true } tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1" tokio = "1.29.1"
@ -63,8 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"] onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"] metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
[[example]] [[example]]
name = "llama_multiprocess" name = "llama_multiprocess"
@ -81,23 +77,3 @@ required-features = ["onnx"]
[[example]] [[example]]
name = "onnx_basics" name = "onnx_basics"
required-features = ["onnx"] required-features = ["onnx"]
[[example]]
name = "whisper"
required-features = ["symphonia"]
[[example]]
name = "whisper-microphone"
required-features = ["microphone"]
[[example]]
name = "mnist-training"
required-features = ["candle-datasets"]
[[example]]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["encodec"]

View File

@ -1,237 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::chatglm::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/chatglm3-6b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-chatglm".to_string())
.get("chatglm-tokenizer.json")?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::glm3_6b();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,46 +0,0 @@
Contrastive Language-Image Pre-Training
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.
https://github.com/openai/CLIP
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
## Running on an example on cpu
```
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```
## Running on an example with metal feature (mac)
```
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -1,202 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::io::Reader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
info!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = *tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -1,23 +0,0 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 84.09%
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
maillot : 0.74%
crash helmet : 0.54%
unicycle, monocycle : 0.44%
```

View File

@ -1,126 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convnext;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Atto,
Femto,
Pico,
Nano,
Tiny,
Small,
Base,
Large,
AttoV2,
FemtoV2,
PicoV2,
NanoV2,
TinyV2,
BaseV2,
LargeV2,
XLarge,
Huge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Atto => "convnext_atto.d2_in1k",
Self::Femto => "convnext_femto.d1_in1k",
Self::Pico => "convnext_pico.d1_in1k",
Self::Nano => "convnext_nano.d1h_in1k",
Self::Tiny => "convnext_tiny.fb_in1k",
Self::Small => "convnext_small.fb_in1k",
Self::Base => "convnext_base.fb_in1k",
Self::Large => "convnext_large.fb_in1k",
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
};
format!("timm/{name}")
}
fn config(&self) -> convnext::Config {
match self {
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
Self::Small => convnext::Config::small(),
Self::Base | Self::BaseV2 => convnext::Config::base(),
Self::Large | Self::LargeV2 => convnext::Config::large(),
Self::XLarge => convnext::Config::xlarge(),
Self::Huge => convnext::Config::huge(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::Tiny)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1 +0,0 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -1,20 +0,0 @@
# candle-efficientvit
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).
This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet : 2.25%
alp : 0.46%
```

View File

@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientvit;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
M0,
M1,
M2,
M3,
M4,
M5,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::M0 => "m0",
Self::M1 => "m1",
Self::M2 => "m2",
Self::M3 => "m3",
Self::M4 => "m4",
Self::M5 => "m5",
};
format!("timm/efficientvit_{}.r224_in1k", name)
}
fn config(&self) -> efficientvit::Config {
match self {
Self::M0 => efficientvit::Config::m0(),
Self::M1 => efficientvit::Config::m1(),
Self::M2 => efficientvit::Config::m2(),
Self::M3 => efficientvit::Config::m3(),
Self::M4 => efficientvit::Config::m4(),
Self::M5 => efficientvit::Config::m5(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::M0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,25 +0,0 @@
# candle-endocec
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.
## Running one example
```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data.
Instead of `code-to-audio` one can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
containing EnCodec tokens for the input audio file.
If the audio output file name is set to `-`, the audio content directly gets
played on default audio output device. If the audio input file is set to `-`, the audio
gets recorded from the default audio input.

View File

@ -1,275 +0,0 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -1,131 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some encodec tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::default();
let model = Model::new(&config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
println!("codes shape: {:?}", codes.shape());
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
}
Ok(())
}

View File

@ -1,27 +0,0 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
```

View File

@ -1,289 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -31,8 +31,6 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which { enum Which {
V1, V1,
V2, V2,
V3,
V3Instruct,
#[value(name = "solar-10.7b")] #[value(name = "solar-10.7b")]
Solar10_7B, Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")] #[value(name = "tiny-llama-1.1b-chat")]
@ -47,8 +45,8 @@ struct Args {
cpu: bool, cpu: bool,
/// The temperature used to generate samples. /// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)] #[arg(long)]
temperature: f64, temperature: Option<f64>,
/// Nucleus sampling probability cutoff. /// Nucleus sampling probability cutoff.
#[arg(long)] #[arg(long)]
@ -59,7 +57,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)] #[arg(long, default_value_t = 100)]
sample_len: usize, sample_len: usize,
/// Disable the key-value cache. /// Disable the key-value cache.
@ -92,11 +90,11 @@ struct Args {
use_flash_attn: bool, use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty. /// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)] #[arg(long, default_value_t = 1.0)]
repeat_penalty: f32, repeat_penalty: f32,
/// The context size to consider for the repeat penalty. /// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 128)] #[arg(long, default_value_t = 64)]
repeat_last_n: usize, repeat_last_n: usize,
} }
@ -120,18 +118,13 @@ fn main() -> Result<()> {
Some("bf16") => DType::BF16, Some("bf16") => DType::BF16,
Some("f32") => DType::F32, Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"), Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => match args.which { None => DType::F16,
Which::V3 | Which::V3Instruct => DType::BF16,
Which::V1 | Which::V2 | Which::Solar10_7B | Which::TinyLlama1_1BChat => DType::F16,
},
}; };
let (llama, tokenizer_filename, mut cache, config) = { let (llama, tokenizer_filename, cache) = {
let api = Api::new()?; let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which { let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(), Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(), Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
}); });
@ -145,31 +138,29 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn); let config = config.into_config(args.use_flash_attn);
let filenames = match args.which { let filenames = match args.which {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => { Which::V1 | Which::V2 | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} }
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
}; };
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache, config) (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}; };
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer let mut tokens = tokenizer
.encode(prompt, true) .encode(prompt, true)
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop"); println!("starting the inference loop");
print!("{prompt}"); print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
let mut index_pos = 0; let mut index_pos = 0;
let mut token_generated = 0; let mut token_generated = 0;
@ -181,7 +172,7 @@ fn main() -> Result<()> {
}; };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?; let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
logits logits
@ -199,16 +190,18 @@ fn main() -> Result<()> {
token_generated += 1; token_generated += 1;
tokens.push(next_token); tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id { if Some(next_token) == eos_token_id {
break; break;
} }
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
} }
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(

View File

@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use model::{Cache, Config, Llama}; use model::{Config, Llama};
use qmodel::QLlama; use qmodel::QLlama;
use weights::TransformerWeights; use weights::TransformerWeights;
@ -160,10 +160,10 @@ enum Model {
} }
impl Model { impl Model {
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> { fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self { match self {
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?), Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?), Self::QLlama(l) => Ok(l.forward(xs, pos)?),
} }
} }
} }
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?; let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?; let model = Llama::load(vb, &cache, config)?;
let tokens = match &args.pretokenized_dir { let tokens = match &args.pretokenized_dir {
None => { None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter { for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, &mut cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?); println!("{}", loss.to_vec0::<f32>()?);
} }
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path let is_safetensors = config_path
.extension() .extension()
.map_or(false, |v| v == "safetensors"); .map_or(false, |v| v == "safetensors");
let (model, config, mut cache) = if is_gguf { let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")? .get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device, &device,
); );
let cache = model::Cache::new(true, &config, fake_vb)?; let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, config.clone())?); let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
} else if is_safetensors { } else if is_safetensors {
let config = Config::tiny_15m(); let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?; let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?); let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
} else { } else {
let mut file = std::fs::File::open(config_path)?; let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?); let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
}; };
println!("starting the inference loop"); println!("starting the inference loop");
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0.. { for index in 0.. {
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos, &mut cache)?; let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits logits
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?; let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? { // Extracting the last token as a string is complicated, here we just apply some simple
print!("{t}"); // heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?; std::io::stdout().flush()?;
} }
} }
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(
"\n{} tokens generated ({:.2} token/s)\n", "\n{} tokens generated ({:.2} token/s)\n",

View File

@ -8,7 +8,6 @@ fn valid_loss(
model: &Llama, model: &Llama,
args: &crate::TrainingCmd, args: &crate::TrainingCmd,
device: &Device, device: &Device,
cache: &mut Cache,
) -> Result<f64> { ) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone()); let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -16,7 +15,7 @@ fn valid_loss(
let mut cnt = 0usize; let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) { for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64; sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1; cnt += 1;
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut cache = Cache::new(false, &config, vb.pp("rot"))?; let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?; let model = Llama::load(vb, &cache, config)?;
let params = candle_nn::ParamsAdamW { let params = candle_nn::ParamsAdamW {
lr: args.learning_rate, lr: args.learning_rate,
..Default::default() ..Default::default()
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?; let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() { for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?; let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0, &mut cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?; opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 { if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the // TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss. // validation loss.
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?; let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}"); println!("{batch_index} {loss}");
} }
if batch_index > 0 && batch_index % 1000 == 0 { if batch_index > 0 && batch_index % 1000 == 0 {

View File

@ -2,9 +2,6 @@
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal). This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
Compared to the mamba example, this version can handle training but is much
slower.
## Running the example ## Running the example
```bash ```bash

View File

@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation
Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
Compared to the mamba-minimal example, this version is far more efficient but
would only work for inference.
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
```

View File

@ -1,305 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mamba::{Config, Model, State};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let dtype = self.model.dtype();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, dtype, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let input = Tensor::new(&[next_token], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value = "f32")]
dtype: String,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use std::str::FromStr;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = DType::from_str(&args.dtype)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,18 +0,0 @@
# candle-metavoice
MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
details on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).
Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
probably because of some implementation discrepancies.
## Run an example
```bash
cargo run --example metavoice --release -- \\
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -1,277 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};
pub const ENCODEC_NTOKENS: u32 = 1024;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
F32,
F16,
Bf16,
}
enum Transformer {
Normal(transformer::Model),
Quantized(qtransformer::Model),
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
/// The guidance scale.
#[arg(long, default_value_t = 3.0)]
guidance_scale: f64,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
temperature: f64,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The maximum number of tokens to generate for the first stage.
#[arg(long, default_value_t = 2000)]
max_tokens: u64,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long)]
first_stage_meta: Option<String>,
#[arg(long)]
first_stage_weights: Option<String>,
#[arg(long)]
second_stage_weights: Option<String>,
#[arg(long)]
encodec_weights: Option<String>,
#[arg(long)]
spk_emb: Option<String>,
#[arg(long, default_value = "f32")]
dtype: ArgDType,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let device = candle_examples::device(args.cpu)?;
let api = Api::new()?;
let repo = api.model("lmz/candle-metavoice".to_string());
let first_stage_meta = match &args.first_stage_meta {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.meta.json")?,
};
let first_stage_meta: serde_json::Value =
serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
let first_stage_tokenizer = match first_stage_meta.as_object() {
None => anyhow::bail!("not a json object"),
Some(j) => match j.get("tokenizer") {
None => anyhow::bail!("no tokenizer key"),
Some(j) => j,
},
};
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
let second_stage_weights = match &args.second_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("second_stage.safetensors")?,
};
let encodec_weights = match args.encodec_weights {
Some(w) => std::path::PathBuf::from(w),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let dtype = match args.dtype {
ArgDType::F32 => DType::F32,
ArgDType::F16 => DType::F16,
ArgDType::Bf16 => DType::BF16,
};
let first_stage_config = transformer::Config::cfg1b_v0_1();
let mut first_stage_model = if args.quantized {
let filename = match &args.first_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage_q4k.gguf")?,
};
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
Transformer::Quantized(first_stage_model)
} else {
let first_stage_weights = match &args.first_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.safetensors")?,
};
let first_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
Transformer::Normal(first_stage_model)
};
let second_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
let second_stage_config = gpt::Config::cfg1b_v0_1();
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
let encodec_device = if device.is_metal() {
&candle::Device::Cpu
} else {
&device
};
let encodec_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
let encodec_config = encodec::Config::default();
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
println!("prompt: '{}'", args.prompt);
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
let mut tokens = prompt_tokens.clone();
println!("{tokens:?}");
let spk_emb_file = match &args.spk_emb {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("spk_emb.safetensors")?,
};
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
let spk_emb = match spk_emb.get("spk_emb") {
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
};
let spk_emb = spk_emb.to_device(&device)?;
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
// First stage generation.
for index in 0..args.max_tokens {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &device)?;
let input = Tensor::stack(&[&input, &input], 0)?;
let logits = match &mut first_stage_model {
Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
Transformer::Quantized(m) => {
m.forward(&input, &spk_emb, tokens.len() - context_size)?
}
};
let logits0 = logits.i((0, 0))?;
let logits1 = logits.i((1, 0))?;
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
let logits = logits.to_dtype(DType::F32)?;
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
print!(".");
std::io::stdout().flush()?;
if next_token == 2048 {
break;
}
}
println!();
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
println!("text ids len: {}", text_ids.len());
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
// TODO: Use the config rather than hardcoding the offset here.
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
let mut hierarchies_in1 =
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
let mut hierarchies_in2 = [
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
ids2.as_slice(),
&[ENCODEC_NTOKENS],
]
.concat();
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
let logits = second_stage_model.forward(&in_x)?;
println!("sampling from logits...");
let mut codes = vec![];
for logits in logits.iter() {
let logits = logits.squeeze(0)?;
let (seq_len, _) = logits.dims2()?;
let mut codes_ = Vec::with_capacity(seq_len);
for step in 0..seq_len {
let logits = logits.i(step)?.to_dtype(DType::F32)?;
let logits = &(&logits / 1.0)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
let sample = distr.sample(&mut rng) as u32;
codes_.push(sample)
}
codes.push(codes_)
}
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
let codes = Tensor::cat(&[in_x, codes], 1)?;
println!("codes: {codes}");
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
let codes = codes.i(0)?.to_vec2::<u32>()?;
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
println!("text_ids len: {:?}", text_ids.len());
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
println!("audio_ids shape: {:?}", audio_ids.shape());
let pcm = encodec_model.decode(&audio_ids)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

View File

@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
@ -39,26 +39,11 @@ impl TextGeneration {
seed: u64, seed: u64,
temp: Option<f64>, temp: Option<f64>,
top_p: Option<f64>, top_p: Option<f64>,
top_k: Option<usize>,
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
device: &Device, device: &Device,
) -> Self { ) -> Self {
let logits_processor = { let logits_processor = LogitsProcessor::new(seed, temp, top_p);
let temperature = temp.unwrap_or(0.);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
Self { Self {
model, model,
tokenizer: TokenOutputStream::new(tokenizer), tokenizer: TokenOutputStream::new(tokenizer),
@ -137,18 +122,6 @@ impl TextGeneration {
} }
} }
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "7b-v0.1")]
Mistral7bV01,
#[value(name = "7b-v0.2")]
Mistral7bV02,
#[value(name = "7b-instruct-v0.1")]
Mistral7bInstructV01,
#[value(name = "7b-instruct-v0.2")]
Mistral7bInstructV02,
}
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
struct Args { struct Args {
@ -174,22 +147,14 @@ struct Args {
#[arg(long)] #[arg(long)]
top_p: Option<f64>, top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "7b-v0.1")]
which: Which,
#[arg(long)] #[arg(long)]
model_id: Option<String>, model_id: Option<String>,
@ -199,9 +164,6 @@ struct Args {
#[arg(long)] #[arg(long)]
tokenizer_file: Option<String>, tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)] #[arg(long)]
weight_files: Option<String>, weight_files: Option<String>,
@ -215,10 +177,6 @@ struct Args {
/// The context size to consider for the repeat penalty. /// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)] #[arg(long, default_value_t = 64)]
repeat_last_n: usize, repeat_last_n: usize,
/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
@ -226,9 +184,6 @@ fn main() -> Result<()> {
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
let args = Args::parse(); let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
let _guard = if args.tracing { let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init(); tracing_subscriber::registry().with(chrome_layer).init();
@ -256,17 +211,9 @@ fn main() -> Result<()> {
Some(model_id) => model_id, Some(model_id) => model_id,
None => { None => {
if args.quantized { if args.quantized {
if args.which != Which::Mistral7bV01 {
anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
}
"lmz/candle-mistral".to_string() "lmz/candle-mistral".to_string()
} else { } else {
match args.which { "mistralai/Mistral-7B-v0.1".to_string()
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
}
} }
} }
}; };
@ -296,17 +243,7 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = match args.config_file { let config = Config::config_7b_v0_1(args.use_flash_attn);
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
if args.quantized {
Config::config_7b_v0_1(args.use_flash_attn)
} else {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
}
};
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized { let (model, device) = if args.quantized {
let filename = &filenames[0]; let filename = &filenames[0];
@ -333,7 +270,6 @@ fn main() -> Result<()> {
args.seed, args.seed,
args.temperature, args.temperature,
args.top_p, args.top_p,
args.top_k,
args.repeat_penalty, args.repeat_penalty,
args.repeat_last_n, args.repeat_last_n,
&device, &device,

View File

@ -143,7 +143,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")] #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]

View File

@ -1,22 +0,0 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%
```

View File

@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S0,
S1,
S2,
S3,
S4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::S0 => "s0",
Self::S1 => "s1",
Self::S2 => "s2",
Self::S3 => "s3",
Self::S4 => "s4",
};
format!("timm/mobileone_{}.apple_in1k", name)
}
fn config(&self) -> mobileone::Config {
match self {
Self::S0 => mobileone::Config::s0(),
Self::S1 => mobileone::Config::s1(),
Self::S2 => mobileone::Config::s2(),
Self::S3 => mobileone::Config::s3(),
Self::S4 => mobileone::Config::s4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,26 +0,0 @@
# candle-moondream
[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model can answer real-world questions about images. It's tiny by today's models, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.
## Running some examples
First download an example image
```bash
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
```
<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
avavx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
retrieved the files in 3.395583ms
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded the model in 5.485493792s
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
starting the inference loop
The girl is eating a hamburger.<
9 tokens generated (0.68 token/s)
```

View File

@ -1,343 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::LogitsProcessor,
models::{moondream, quantized_moondream},
};
use tokenizers::Tokenizer;
enum Model {
Moondream(moondream::Model),
Quantized(quantized_moondream::Model),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the Moondream model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
// Moondream tokenizer bos_token and eos_token is "<|endoftext|>"
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the special token"),
};
let (bos_token, eos_token) = (special_token, special_token);
let start_gen = std::time::Instant::now();
let mut load_t = std::time::Duration::from_secs_f64(0f64);
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = if index > 0 {
match self.model {
Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
}
} else {
let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
let logits = match self.model {
Model::Moondream(ref mut model) => {
model
.text_model
.forward_with_img(&bos_token, &input, image_embeds)?
}
Model::Quantized(ref mut model) => {
model
.text_model
.forward_with_img(&bos_token, &input, image_embeds)?
}
};
load_t = start_gen.elapsed();
println!("load_t: {:?}", load_t);
logits
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed() - load_t;
println!(
"\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
dt.as_secs_f64(),
(generated_tokens - 1) as f64 / dt.as_secs_f64()
);
Ok(())
}
}
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
#[arg(long)]
image: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 378, 378).
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::tokio::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
"santiagomed/candle-moondream".to_string()
} else {
"vikhyatk/moondream2".to_string()
}
}
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
args.revision,
));
let model_file = match args.model_file {
Some(m) => m.into(),
None => {
if args.quantized {
repo.get("model-q4_0.gguf").await?
} else {
repo.get("model.safetensors").await?
}
}
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json").await?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = moondream::Config::v2();
let dtype = if args.quantized {
if args.f16 {
anyhow::bail!("Quantized model does not support f16");
}
DType::F32
} else if device.is_cuda() || args.f16 {
DType::F16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&model_file,
&device,
)?;
let model = quantized_moondream::Model::new(&config, vb)?;
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let model = moondream::Model::new(&config, vb)?;
Model::Moondream(model)
};
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
let image = load_image(args.image)?
.to_device(&device)?
.to_dtype(dtype)?;
let image_embeds = image.unsqueeze(0)?;
let image_embeds = match model {
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
};
println!(
"loaded and encoded the image {image:?} in {:?}",
start.elapsed()
);
let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&prompt, &image_embeds, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,580 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
target_bandwidths: Vec<f64>,
sampling_rate: usize,
audio_channels: usize,
normalize: bool,
chunk_length_s: Option<usize>,
overlap: Option<usize>,
hidden_size: usize,
num_filters: usize,
num_residual_layers: usize,
upsampling_ratios: Vec<usize>,
norm_type: NormType,
kernel_size: usize,
last_kernel_size: usize,
residual_kernel_size: usize,
dilation_growth_rate: usize,
use_causal_conv: bool,
pad_mode: &'static str,
compress: usize,
num_lstm_layers: usize,
trim_right_ratio: f64,
codebook_size: usize,
codebook_dim: Option<usize>,
use_conv_shortcut: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
pad_mode: "reflect",
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
pub fn musicgen_small() -> Self {
Self {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
pad_mode: "reflect",
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
}
}
fn codebook_dim(&self) -> usize {
self.codebook_dim.unwrap_or(self.codebook_size)
}
fn frame_rate(&self) -> usize {
let hop_length: usize = self.upsampling_ratios.iter().product();
(self.sampling_rate + hop_length - 1) / hop_length
}
fn num_quantizers(&self) -> usize {
let num = 1000f64
* self
.target_bandwidths
.last()
.expect("empty target_bandwidths");
(num as usize) / (self.frame_rate() * 10)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Debug)]
struct EncodecEuclideanCodebook {
inited: Tensor,
cluster_size: Tensor,
embed: Tensor,
embed_avg: Tensor,
}
impl EncodecEuclideanCodebook {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inited = vb.get(1, "inited")?;
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, "embed")?;
let embed_avg = vb.get(e_shape, "embed_avg")?;
Ok(Self {
inited,
cluster_size,
embed,
embed_avg,
})
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.embed.embedding(embed_ind)?;
Ok(quantize)
}
}
#[derive(Debug)]
struct EncodecVectorQuantization {
codebook: EncodecEuclideanCodebook,
}
impl EncodecVectorQuantization {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
Ok(Self { codebook })
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.codebook.decode(embed_ind)?;
let quantize = quantize.transpose(1, 2)?;
Ok(quantize)
}
}
#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
layers: Vec<EncodecVectorQuantization>,
}
impl EncodecResidualVectorQuantizer {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("layers");
let layers = (0..cfg.num_quantizers())
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
if codes.dim(0)? != self.layers.len() {
candle::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
)
}
for (i, layer) in self.layers.iter().enumerate() {
let quantized = layer.decode(&codes.i(i)?)?;
quantized_out = quantized.broadcast_add(&quantized_out)?;
}
Ok(quantized_out)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
layers: Vec<candle_nn::LSTM>,
}
impl EncodecLSTM {
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("lstm");
let mut layers = vec![];
for layer_idx in 0..cfg.num_lstm_layers {
let config = candle_nn::LSTMConfig {
layer_idx,
..Default::default()
};
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
layers.push(lstm)
}
Ok(Self { layers })
}
}
impl Module for EncodecLSTM {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
let mut xs = xs.clone();
for layer in self.layers.iter() {
let states = layer.seq(&xs)?;
xs = layer.states_to_tensor(&states)?;
}
Ok(xs)
}
}
#[derive(Debug)]
struct EncodecConvTranspose1d {
weight_g: Tensor,
weight_v: Tensor,
bias: Tensor,
}
impl EncodecConvTranspose1d {
fn load(
in_c: usize,
out_c: usize,
k: usize,
_stride: usize,
vb: VarBuilder,
_cfg: &Config,
) -> Result<Self> {
let vb = &vb.pp("conv");
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
let bias = vb.get(out_c, "bias")?;
Ok(Self {
weight_g,
weight_v,
bias,
})
}
}
impl Module for EncodecConvTranspose1d {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
#[derive(Debug)]
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
fn load(
in_c: usize,
out_c: usize,
kernel_size: usize,
stride: usize,
vb: VarBuilder,
cfg: &Config,
) -> Result<Self> {
let conv = match cfg.norm_type {
NormType::WeightNorm => conv1d_weight_norm(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Debug)]
struct EncodecResnetBlock {
block_conv1: EncodecConv1d,
block_conv2: EncodecConv1d,
shortcut: Option<EncodecConv1d>,
}
impl EncodecResnetBlock {
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
candle::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();
let block_conv1 =
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
layer.inc();
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
Some(conv)
} else {
None
};
Ok(Self {
block_conv1,
block_conv2,
shortcut,
})
}
}
impl Module for EncodecResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs.clone();
let xs = xs.elu(1.)?;
let xs = self.block_conv1.forward(&xs)?;
let xs = xs.elu(1.)?;
let xs = self.block_conv2.forward(&xs)?;
let xs = match &self.shortcut {
None => (xs + residual)?,
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
};
Ok(xs)
}
}
struct Layer<'a> {
vb: VarBuilder<'a>,
cnt: usize,
}
impl<'a> Layer<'a> {
fn new(vb: VarBuilder<'a>) -> Self {
Self { vb, cnt: 0 }
}
fn inc(&mut self) {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb
}
}
#[derive(Debug)]
struct EncodecEncoder {
init_conv: EncodecConv1d,
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
final_lstm: EncodecLSTM,
final_conv: EncodecConv1d,
}
impl EncodecEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let init_conv = EncodecConv1d::load(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
layer.next(),
cfg,
)?;
let mut sampling_layers = vec![];
let mut scaling = 1;
for &ratio in cfg.upsampling_ratios.iter().rev() {
let current_scale = scaling * cfg.num_filters;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
layer.inc(); // ELU
let conv1d = EncodecConv1d::load(
current_scale,
current_scale * 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
sampling_layers,
final_conv,
final_lstm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?;
for (resnets, conv) in self.sampling_layers.iter() {
for resnet in resnets.iter() {
xs = xs.apply(resnet)?;
}
xs = xs.elu(1.0)?.apply(conv)?;
}
xs.apply(&self.final_lstm)?
.elu(1.0)?
.apply(&self.final_conv)
}
}
#[derive(Debug)]
struct EncodecDecoder {
init_conv: EncodecConv1d,
init_lstm: EncodecLSTM,
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
final_conv: EncodecConv1d,
}
impl EncodecDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::load(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
layer.inc(); // ELU
let conv1d = EncodecConvTranspose1d::load(
current_scale,
current_scale / 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
sampling_layers.push((conv1d, resnets));
scaling /= 2;
}
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters,
cfg.audio_channels,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
init_lstm,
sampling_layers,
final_conv,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
for (conv, resnets) in self.sampling_layers.iter() {
xs = xs.elu(1.)?.apply(conv)?;
for resnet in resnets.iter() {
xs = xs.apply(resnet)?
}
}
xs.elu(1.)?.apply(&self.final_conv)
}
}
#[derive(Debug)]
pub struct EncodecModel {
encoder: EncodecEncoder,
decoder: EncodecDecoder,
quantizer: EncodecResidualVectorQuantizer,
}
impl EncodecModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}

View File

@ -10,7 +10,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
mod encodec_model;
mod musicgen_model; mod musicgen_model;
mod nn;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

View File

@ -1,9 +1,10 @@
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D}; use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{ use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module, embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder, VarBuilder,
}; };
use candle_transformers::models::{encodec, t5}; use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)] #[derive(Debug)]
pub struct MusicgenForConditionalGeneration { pub struct MusicgenForConditionalGeneration {
pub text_encoder: t5::T5EncoderModel, pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: encodec::Model, pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM, pub decoder: MusicgenForCausalLM,
cfg: GenConfig, cfg: GenConfig,
} }
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
pub struct GenConfig { pub struct GenConfig {
musicgen: Config, musicgen: Config,
t5: t5::Config, t5: t5::Config,
encodec: encodec::Config, encodec: crate::encodec_model::Config,
} }
impl GenConfig { impl GenConfig {
pub fn small() -> Self { pub fn small() -> Self {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
let encodec = encodec::Config {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: encodec::NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
// This should be Reflect and not Replicate but Reflect does not work yet.
pad_mode: encodec::PadMode::Replicate,
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
};
Self { Self {
musicgen: Config::musicgen_small(), musicgen: Config::musicgen_small(),
t5: t5::Config::musicgen_small(), t5: t5::Config::musicgen_small(),
encodec, encodec: encodec_model::Config::musicgen_small(),
} }
} }
} }
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?; let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self { Ok(Self {
text_encoder, text_encoder,

View File

@ -0,0 +1,20 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}

View File

@ -1,39 +1,10 @@
## Using ONNX models in Candle ## Using ONNX models in Candle
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle. This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small sequeezenet variant.
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf). You can run the example with the following command:
You can run the examples with following commands:
```bash ```bash
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
Use the `--which` flag to specify explicitly which network to use, i.e.
```bash
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.21s
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 3, 224, 224; f32]
unicycle, monocycle : 83.23%
ballplayer, baseball player : 3.68%
bearskin, busby, shako : 1.54%
military uniform : 0.78%
cowboy hat, ten-gallon hat : 0.76%
```
```bash
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.20s
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 224, 224, 3; f32]
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
mountain bike, all-terrain bike, off-roader : 0.60%
unicycle, monocycle : 0.17%
crash helmet : 0.02%
alp : 0.02%
``` ```

View File

@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via: `tensor-tools` command line utility via:
```bash ```bash
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf $ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
``` ```
## Using custom models ## Using custom models

View File

@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file}; use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor; use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model; use candle_transformers::models::quantized_llama as model;
@ -67,8 +67,6 @@ enum Which {
Mixtral, Mixtral,
#[value(name = "mixtral-instruct")] #[value(name = "mixtral-instruct")]
MixtralInstruct, MixtralInstruct,
#[value(name = "llama3-8b")]
L8b,
} }
impl Which { impl Which {
@ -84,8 +82,7 @@ impl Which {
| Self::L13bCode | Self::L13bCode
| Self::L34bCode | Self::L34bCode
| Self::Leo7b | Self::Leo7b
| Self::Leo13b | Self::Leo13b => false,
| Self::L8b => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat. // same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35 Self::OpenChat35
@ -119,8 +116,7 @@ impl Which {
| Self::Mistral7bInstruct | Self::Mistral7bInstruct
| Self::Mistral7bInstructV02 | Self::Mistral7bInstructV02
| Self::OpenChat35 | Self::OpenChat35
| Self::Starling7bAlpha | Self::Starling7bAlpha => false,
| Self::L8b => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true, Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
} }
} }
@ -144,8 +140,7 @@ impl Which {
| Self::Mistral7bInstruct | Self::Mistral7bInstruct
| Self::Mistral7bInstructV02 | Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha | Self::Zephyr7bAlpha
| Self::Zephyr7bBeta | Self::Zephyr7bBeta => false,
| Self::L8b => false,
Self::OpenChat35 | Self::Starling7bAlpha => true, Self::OpenChat35 | Self::Starling7bAlpha => true,
} }
} }
@ -172,7 +167,6 @@ impl Which {
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1", | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5", Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha", Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
} }
} }
} }
@ -206,10 +200,6 @@ struct Args {
#[arg(long)] #[arg(long)]
top_p: Option<f64>, top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
seed: u64, seed: u64,
@ -222,14 +212,6 @@ struct Args {
#[arg(long)] #[arg(long)]
verbose_prompt: bool, verbose_prompt: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty. /// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)] #[arg(long, default_value_t = 1.1)]
repeat_penalty: f32, repeat_penalty: f32,
@ -245,10 +227,6 @@ struct Args {
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2. /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
#[arg(long)] #[arg(long)]
gqa: Option<usize>, gqa: Option<usize>,
/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
} }
impl Args { impl Args {
@ -328,11 +306,6 @@ impl Args {
"TheBloke/Starling-LM-7B-alpha-GGUF", "TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf", "starling-lm-7b-alpha.Q4_K_M.gguf",
), ),
// TODO: swap to TheBloke model when available
Which::L8b => (
"QuantFactory/Meta-Llama-3-8B-GGUF",
"Meta-Llama-3-8B.Q4_K_S.gguf",
),
}; };
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string()); let api = api.model(repo.to_string());
@ -360,10 +333,11 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
let args = Args::parse(); let args = Args::parse();
let temperature = if args.temperature == 0. {
#[cfg(feature = "cuda")] None
candle::quantized::cuda::set_force_dmmv(args.force_dmmv); } else {
Some(args.temperature)
};
let _guard = if args.tracing { let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init(); tracing_subscriber::registry().with(chrome_layer).init();
@ -387,7 +361,7 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?; let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?; let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) { let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => { Some("gguf") => {
@ -431,8 +405,7 @@ fn main() -> anyhow::Result<()> {
| Which::L13bCode | Which::L13bCode
| Which::L34bCode | Which::L34bCode
| Which::Leo7b | Which::Leo7b
| Which::Leo13b | Which::Leo13b => 1,
| Which::L8b => 1,
Which::Mixtral Which::Mixtral
| Which::MixtralInstruct | Which::MixtralInstruct
| Which::Mistral7b | Which::Mistral7b
@ -511,36 +484,14 @@ fn main() -> anyhow::Result<()> {
prompt_tokens prompt_tokens
}; };
let mut all_tokens = vec![]; let mut all_tokens = vec![];
let mut logits_processor = { let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now(); let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt { let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?; let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)? logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
}; };
let prompt_dt = start_prompt_processing.elapsed(); let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token); all_tokens.push(next_token);
@ -549,14 +500,11 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?; std::io::stdout().flush()?;
} }
let eos_token = match args.which { let eos_token = if args.which.is_open_chat() {
Which::L8b => "<|end_of_text|>", "<|end_of_turn|>"
_ => match args.which.is_open_chat() { } else {
true => "<|end_of_turn|>", "</s>"
false => "</s>",
},
}; };
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now(); let start_post_prompt = std::time::Instant::now();
let mut sampled = 0; let mut sampled = 0;

View File

@ -1,27 +0,0 @@
# candle-qwen: large language model series from Alibaba Cloud
Qwen 1.5 is a series of large language models that provide strong performances
on English and Chinese.
- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
mixture-of-experts (MoE) variant.
## Running the example
```bash
$ cargo run --example qwen --release -- --prompt "Hello there "
```
Various model sizes are available via the `--model` argument, including the MoE
variant.
```bash
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
def print_prime(n: int): # n is the number of primes to be printed
for i in range(2, n + 1):
if all(i % j != 0 for j in range(2, i)):
print(i)
```

View File

@ -1,311 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
Base(ModelBase),
Moe(ModelMoe),
}
impl Model {
fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
match self {
Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum WhichModel {
#[value(name = "0.5b")]
W0_5b,
#[value(name = "1.8b")]
W1_8b,
#[value(name = "4b")]
W4b,
#[value(name = "7b")]
W7b,
#[value(name = "14b")]
W14b,
#[value(name = "72b")]
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "0.5b")]
model: WhichModel,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
WhichModel::MoeA27b => "MoE-A2.7B",
};
format!("Qwen/Qwen1.5-{size}")
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W4b
| WhichModel::W7b
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.model {
WhichModel::MoeA27b => {
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe(ModelMoe::new(&config, vb)?)
}
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)
}
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,9 +0,0 @@
# candle-recurrent-gemma
This model card corresponds to the 2B base version of the RecurrentGemma model
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).
```bash
cargo run --features cuda -r --example recurrent-gemma -- \
--prompt "Write me a poem about Machine Learning."
```

Some files were not shown because too many files have changed in this diff Show More