Compare commits

..

4 Commits

Author SHA1 Message Date
c2261d0222 Merge. 2024-01-07 20:27:33 +01:00
06d186355b Change more consitently the test. 2024-01-06 15:20:55 +01:00
2bbd544832 Non random for better quantization quality 2024-01-06 15:16:01 +01:00
504d0b9ac7 Potential bug on q4k. 2024-01-05 14:15:47 +01:00
214 changed files with 2210 additions and 26304 deletions

View File

@ -1,7 +0,0 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5

View File

@ -5,15 +5,49 @@ on:
pull_request: pull_request:
jobs: jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
# Don't run on forks, they won't have access to secrets anyway.
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
test-cuda: test-cuda:
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci] needs: start-runner # required to start the main job when the runner is ready
container: runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
permissions: permissions:
contents: write contents: write
packages: write packages: write
@ -24,10 +58,32 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable - name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1 run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda) - name: Test (cuda)
run: cargo test --features cuda run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "0.4.2" version = "0.3.3"
edition = "2021" edition = "2021"
description = "Minimalist ML framework." description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle" repository = "https://github.com/huggingface/candle"
@ -31,19 +31,18 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" } accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3" byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" } candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets", version = "0.4.2" } candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" } candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels", version = "0.4.2" } candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" } candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn", version = "0.4.2" } candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx", version = "0.4.2" } candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers", version = "0.4.2" } candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false } criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] } cudarc = { version = "0.9.14", features = ["f16"] }
fancy-regex = "0.13.0" gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0" hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] } image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -51,20 +50,20 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" } libc = { version = "0.2.147" }
log = "0.4" log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0" num_cpus = "1.15.0"
num-traits = "0.2.15" num-traits = "0.2.15"
parquet = { version = "50.0.0" } parquet = { version = "45.0.0" }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
rayon = "1.7.0" rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false } rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1" safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2" serde_plain = "1.0.2"
serde_json = "1.0.99" serde_json = "1.0.99"
thiserror = "1" thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false } tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37" tracing = "0.1.37"
tracing-chrome = "0.7.1" tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7" tracing-subscriber = "0.3.7"

View File

@ -63,24 +63,17 @@ We also provide a some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant. the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM. - [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports pre-trained on 1T tokens of English and code datasets.
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants. - [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model. implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28. better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of - [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference. much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion. - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters. (English/Chinese) general LLMs with 6b and 34b parameters.
@ -110,23 +103,14 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200"> <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model. - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/), - [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings. [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation). evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image. generate captions for an image.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text. model, generates the translated text from the input text.
@ -175,7 +159,6 @@ And then head over to
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more. - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle. - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem. - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
If you have an addition to this list, please submit a pull request. If you have an addition to this list, please submit a pull request.
@ -196,18 +179,15 @@ If you have an addition to this list, please submit a pull request.
- Language Models. - Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B. - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- Falcon. - Falcon.
- StarCoder, StarCoder2. - StarCoder.
- Phi 1, 1.5, and 2. - Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba - Minimal Mamba
- Gemma 2b and 7b.
- Mistral 7b v0.1. - Mistral 7b v0.1.
- Mixtral 8x7b v0.1. - Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. - StableLM-3B-4E1T.
- Replit-code-v1.5-3B. - Replit-code-v1.5-3B.
- Bert. - Bert.
- Yi-6B and Yi-34B. - Yi-6B and Yi-34B.
- Qwen1.5.
- RWKV v5 and v6.
- Quantized LLMs. - Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants. - Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct. - Mistral 7b, and 7b instruct.
@ -217,22 +197,16 @@ If you have an addition to this list, please submit a pull request.
- Text to text. - Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction). - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation). - Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image. - Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0. - Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2. - Wurstchen v2.
- Image to text. - Image to text.
- BLIP. - BLIP.
- TrOCR.
- Audio.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Computer Vision Models. - Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
ConvNeXTv2, MobileOne, EfficientVit (MSRA).
- yolo-v3, yolo-v8. - yolo-v3, yolo-v8.
- Segment-Anything Model (SAM). - Segment-Anything Model (SAM).
- SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files. - File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments. - Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types. - Quantization support using the llama.cpp quantized types.

View File

@ -46,5 +46,6 @@ accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"] metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]] [[bench]]
name = "bench_main" name = "matmul"
harness = false harness = false

View File

@ -1,9 +0,0 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
);

View File

@ -1,43 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.affine(12.34, 56.78).unwrap();
}
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,66 +0,0 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;
use candle_core::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
fn sync(&self) -> Result<()> {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
}
}
}
fn bench_name<S: Into<String>>(&self, name: S) -> String {
match self {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
};
format!("{}_{}", cpu_type, name.into())
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
}
}
}
struct BenchDeviceHandler {
devices: Vec<Device>,
}
impl BenchDeviceHandler {
pub fn new() -> Result<Self> {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);
Ok(Self { devices })
}
}

View File

@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,64 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
a.where_cond(b, c).unwrap();
}
const fn create_cond_arr<const N: usize>() -> [u8; N] {
let mut arr = [0u8; N];
let mut i = 0;
while i < N {
arr[i] = (i % 2) as u8;
i += 1;
}
arr
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box(&tensor),
black_box(&on_true),
black_box(&on_false),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,25 +1,25 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor}; use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use std::time::Instant; use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) { fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap(); a.matmul(&b.t().unwrap()).unwrap();
} }
fn run_bench(c: &mut Criterion, device: &Device) { fn criterion_benchmark(c: &mut Criterion) {
let b = 1; let b = 1;
let m = 1; let m = 1;
let n = 2048; let n = 2048;
let k = 2048; let k = 2048;
let device = Device::new_metal(0).unwrap();
let dtype = DType::F32; let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap(); let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap(); let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
let flops = b * m * n * k; let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name("matmul")); let mut group = c.benchmark_group("matmul_metal");
group.throughput(Throughput::Bytes(flops as u64)); group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| { group.bench_function("iter", move |b| {
b.iter_custom(|iters| { b.iter_custom(|iters| {
@ -27,18 +27,16 @@ fn run_bench(c: &mut Criterion, device: &Device) {
for _i in 0..iters { for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs)); run(black_box(&lhs), black_box(&rhs));
} }
device.sync().unwrap(); if let Device::Metal(device) = &device {
device.wait_until_completed().unwrap();
} else {
panic!("Expected metal device");
}
start.elapsed() start.elapsed()
}) })
}); });
group.finish(); group.finish();
} }
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark); criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@ -5,32 +5,25 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::Result; use anyhow::Result;
use candle_core::{Device, Module, Tensor}; use candle_core::{Device, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
fn main() -> Result<()> { fn main() -> Result<()> {
let device = Device::new_cuda(0)?; let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?; let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?; let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?; let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let q = QMatMul::from_qtensor(q)?; println!("{out_t}");
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?; let in_t = in_t.to_device(&Device::Cpu)?;
let res_q_cuda = q.forward(&x)?; let k_t = k_t.to_device(&Device::Cpu)?;
println!("{res_q_cuda}"); let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?; .sqr()?
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?; .sum_all()?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}"); println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(()) Ok(())
} }

View File

@ -1,5 +1,5 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor}; use candle_core::quantized::{gguf_file, k_quants, QTensor};
use candle_core::{Device, Result}; use candle_core::{Device, Result, Tensor};
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*; use rayon::prelude::*;
@ -11,7 +11,12 @@ enum QuantizationMode {
} }
impl QuantizationMode { impl QuantizationMode {
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> { fn quantize(
&self,
name: &str,
tensor: QTensor,
default: fn(&Tensor) -> Result<QTensor>,
) -> Result<QTensor> {
match self { match self {
Self::Llama => { Self::Llama => {
// Same behavior as the llama.cpp quantization. // Same behavior as the llama.cpp quantization.
@ -19,9 +24,9 @@ impl QuantizationMode {
if should_quantize { if should_quantize {
let tensor = tensor.dequantize(&Device::Cpu)?; let tensor = tensor.dequantize(&Device::Cpu)?;
if name == "output.weight" { if name == "output.weight" {
QTensor::quantize(&tensor, GgmlDType::Q6K) QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
} else { } else {
QTensor::quantize(&tensor, dtype) default(&tensor)
} }
} else { } else {
Ok(tensor) Ok(tensor)
@ -55,27 +60,6 @@ enum Quantization {
F32, F32,
} }
impl Quantization {
fn dtype(&self) -> GgmlDType {
match self {
Quantization::Q4_0 => GgmlDType::Q4_0,
Quantization::Q4_1 => GgmlDType::Q4_1,
Quantization::Q5_0 => GgmlDType::Q5_0,
Quantization::Q5_1 => GgmlDType::Q5_1,
Quantization::Q8_0 => GgmlDType::Q8_0,
Quantization::Q8_1 => GgmlDType::Q8_1,
Quantization::Q2k => GgmlDType::Q2K,
Quantization::Q3k => GgmlDType::Q3K,
Quantization::Q4k => GgmlDType::Q4K,
Quantization::Q5k => GgmlDType::Q5K,
Quantization::Q6k => GgmlDType::Q6K,
Quantization::Q8k => GgmlDType::Q8K,
Quantization::F16 => GgmlDType::F16,
Quantization::F32 => GgmlDType::F32,
}
}
}
#[derive(ValueEnum, Debug, Clone)] #[derive(ValueEnum, Debug, Clone)]
enum Format { enum Format {
Safetensors, Safetensors,
@ -118,7 +102,7 @@ enum Command {
}, },
Quantize { Quantize {
/// The input file(s), in safetensors format. /// The input file, in gguf format.
in_file: Vec<std::path::PathBuf>, in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format. /// The output file, in gguf format.
@ -133,15 +117,6 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)] #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode, mode: QuantizationMode,
}, },
Dequantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in safetensors format.
#[arg(long)]
out_file: std::path::PathBuf,
},
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
@ -150,12 +125,7 @@ struct Args {
command: Command, command: Command,
} }
fn run_ls( fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
file: &std::path::PathBuf,
format: Option<Format>,
verbose: bool,
device: &Device,
) -> Result<()> {
let format = match format { let format = match format {
Some(format) => format, Some(format) => format,
None => match Format::infer(file) { None => match Format::infer(file) {
@ -196,7 +166,7 @@ fn run_ls(
} }
} }
Format::Pth => { Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?; let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name)); tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() { for tensor_info in tensors.iter() {
println!( println!(
@ -221,7 +191,7 @@ fn run_ls(
} }
Format::Ggml => { Format::Ggml => {
let mut file = std::fs::File::open(file)?; let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>(); let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() { for (name, qtensor) in tensors.iter() {
@ -262,8 +232,37 @@ fn run_quantize_safetensors(
} }
println!("tensors: {}", tensors.len()); println!("tensors: {}", tensors.len());
let dtype = q.dtype(); let quantize_fn = match q {
let block_size = dtype.block_size(); Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let block_size = match q {
Quantization::Q4_0 => k_quants::QK4_0,
Quantization::Q4_1 => k_quants::QK4_1,
Quantization::Q5_0 => k_quants::QK5_0,
Quantization::Q5_1 => k_quants::QK5_1,
Quantization::Q8_0 => k_quants::QK8_0,
Quantization::Q8_1 => k_quants::QK8_1,
Quantization::Q2k
| Quantization::Q3k
| Quantization::Q4k
| Quantization::Q5k
| Quantization::Q6k
| Quantization::Q8k => k_quants::QK_K,
Quantization::F16 | Quantization::F32 => 1,
};
let qtensors = tensors let qtensors = tensors
.into_par_iter() .into_par_iter()
@ -271,9 +270,9 @@ fn run_quantize_safetensors(
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}"); println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize { let tensor = if should_quantize {
QTensor::quantize(&tensor, dtype)? quantize_fn(&tensor)?
} else { } else {
QTensor::quantize(&tensor, GgmlDType::F32)? QTensor::quantize::<f32>(&tensor)?
}; };
Ok((name, tensor)) Ok((name, tensor))
}) })
@ -286,29 +285,11 @@ fn run_quantize_safetensors(
Ok(()) Ok(())
} }
fn run_dequantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
device: &Device,
) -> Result<()> {
let mut in_file = std::fs::File::open(in_file)?;
let content = gguf_file::Content::read(&mut in_file)?;
let mut tensors = std::collections::HashMap::new();
for (tensor_name, _) in content.tensor_infos.iter() {
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
Ok(())
}
fn run_quantize( fn run_quantize(
in_files: &[std::path::PathBuf], in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf, out_file: std::path::PathBuf,
q: Quantization, q: Quantization,
qmode: QuantizationMode, qmode: QuantizationMode,
device: &Device,
) -> Result<()> { ) -> Result<()> {
if in_files.is_empty() { if in_files.is_empty() {
candle_core::bail!("no specified input files") candle_core::bail!("no specified input files")
@ -334,15 +315,31 @@ fn run_quantize(
let content = gguf_file::Content::read(&mut in_)?; let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len()); println!("tensors: {}", content.tensor_infos.len());
let dtype = q.dtype(); let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let qtensors = content let qtensors = content
.tensor_infos .tensor_infos
.par_iter() .par_iter()
.map(|(name, _)| { .map(|(name, _)| {
println!(" quantizing {name}"); println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?; let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name, device)?; let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, dtype)?; let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor)) Ok((name, tensor))
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -362,7 +359,6 @@ fn run_quantize(
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
let device = Device::Cpu;
match args.command { match args.command {
Command::Ls { Command::Ls {
files, files,
@ -374,7 +370,7 @@ fn main() -> anyhow::Result<()> {
if multiple_files { if multiple_files {
println!("--- {file:?} ---"); println!("--- {file:?} ---");
} }
run_ls(file, format.clone(), verbose, &device)? run_ls(file, format.clone(), verbose)?
} }
} }
Command::Quantize { Command::Quantize {
@ -382,8 +378,7 @@ fn main() -> anyhow::Result<()> {
out_file, out_file,
quantization, quantization,
mode, mode,
} => run_quantize(&in_file, out_file, quantization, mode, &device)?, } => run_quantize(&in_file, out_file, quantization, mode)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
} }
Ok(()) Ok(())
} }

View File

@ -380,16 +380,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
} }
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline] #[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) { for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -412,28 +402,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
} }
} }
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op { macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => { ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline] #[inline]

View File

@ -98,19 +98,6 @@ pub trait BackendStorage: Sized {
) -> Result<Self>; ) -> Result<Self>;
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
#[allow(clippy::too_many_arguments)]
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
fn copy2d(
&self,
_: &mut Self,
_d1: usize,
_d2: usize,
_src_stride1: usize,
_dst_stride1: usize,
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
} }
pub trait BackendDevice: Sized + std::fmt::Debug + Clone { pub trait BackendDevice: Sized + std::fmt::Debug + Clone {

View File

@ -113,7 +113,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Floor) | Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes, | Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node) Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. } | Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. }
@ -175,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second order // the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment. // derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b); let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach() }; let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() { if let Some(op) = node.op() {
match op { match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => { Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -250,7 +250,6 @@ impl Tensor {
out_padding, out_padding,
*stride, *stride,
*dilation, *dilation,
/* groups */ 1,
)?; )?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
@ -348,18 +347,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
} }
Op::UpsampleNearest1D { arg, target_size } => { Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
let (_n, c, size) = arg.dims3()?; op: "upsample-nearest1d",
if target_size % size != 0 { })?,
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale = target_size / size;
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest2D { Op::UpsampleNearest2D {
arg, arg,
target_h, target_h,
@ -599,13 +589,6 @@ impl Tensor {
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?; let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)? *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
} }
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (*node / arg)?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => { Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0 // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;

View File

@ -187,16 +187,36 @@ impl Tensor {
} }
} }
fn conv_transpose1d_single_group( /// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self, &self,
kernel: &Self, kernel: &Self,
params: &ParamsConvTranspose1D, padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> { ) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d( let storage = self.storage().conv_transpose1d(
self.layout(), self.layout(),
&kernel.storage(), &kernel.storage(),
kernel.layout(), kernel.layout(),
params, &params,
)?; )?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg, arg,
@ -210,49 +230,6 @@ impl Tensor {
Ok(crate::tensor::from_storage(storage, out_dims, op, false)) Ok(crate::tensor::from_storage(storage, out_dims, op, false))
} }
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_in_k, c_out, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
if c_in % groups != 0 {
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in: c_in / groups,
padding,
output_padding,
stride,
dilation,
};
if groups == 1 {
self.conv_transpose1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage = let storage =
self.storage() self.storage()

View File

@ -5,7 +5,6 @@ use half::{bf16, f16};
use rayon::prelude::*; use rayon::prelude::*;
const USE_IM2COL_CONV1D: bool = true; const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true; const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@ -1023,26 +1022,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
} }
} }
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
src: &[T],
dst: &mut [T],
d1: usize,
d2: usize,
src_stride1: usize,
dst_stride1: usize,
src_offset: usize,
dst_offset: usize,
) {
for i1 in 0..d1 {
let dst_idx = i1 * dst_stride1 + dst_offset;
let src_idx = i1 * src_stride1 + src_offset;
let dst = &mut dst[dst_idx..dst_idx + d2];
let src = &src[src_idx..src_idx + d2];
dst.copy_from_slice(src)
}
}
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) { fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
match src_l.strided_blocks() { match src_l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => { crate::StridedBlocks::SingleBlock { start_offset, len } => {
@ -1277,34 +1256,6 @@ impl Map1 for Im2Col {
} }
} }
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let mut im = vec![T::zero(); b_size * c_out * l_out];
let (dst_s0, dst_s1) = (c_out * l_out, l_out);
let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
for l_in_i in 0..l_in {
for k_i in 0..k_size {
let l_out_i = l_in_i * stride + k_i;
for b_i in 0..b_size {
for c_i in 0..c_out {
let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
im[dst_idx] += col[src_idx]
}
}
}
}
Ok(im)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> { impl<'a> Map2 for ConvTranspose1D<'a> {
@ -1312,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> { fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0; let p = self.0;
let inp = &inp[inp_l.start_offset()..]; let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out(); let l_out = p.l_out();
@ -2472,48 +2422,6 @@ impl BackendStorage for CpuStorage {
} }
} }
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
(Self::U32(src), Self::U32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I64(src), Self::I64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::BF16(src), Self::BF16(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F16(src), Self::F16(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F32(src), Self::F32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::F64(src), Self::F64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(_, dst) => {
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: dst.dtype(),
op: "copy2d",
}
.bt());
}
}
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
match (self, dst) { match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@ -2602,52 +2510,7 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout, kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D, params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> { ) -> Result<Self> {
let can_use_col2im = kernel_l.is_contiguous() ConvTranspose1D(params).map(self, l, kernel, kernel_l)
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if USE_IM2COL_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col, &col_l)
} else {
ConvTranspose1D(params).map(self, l, kernel, kernel_l)
}
} }
fn conv2d( fn conv2d(
@ -2711,7 +2574,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
} }
} }
@ -2720,7 +2583,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
} }
} }
@ -2737,7 +2600,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
} }
} }

View File

@ -1149,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
} }
} }
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, l_k)
// Input shape: (b_size, c_in, l_in)
let p = &self.0;
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
l_out,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> { impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1859,15 +1810,12 @@ impl BackendStorage for CudaStorage {
fn conv_transpose1d( fn conv_transpose1d(
&self, &self,
l: &Layout, _: &Layout,
kernel: &Self, _: &Self,
kernel_l: &Layout, _: &Layout,
params: &crate::conv::ParamsConvTranspose1D, _: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> { ) -> Result<Self> {
let device = self.device().clone(); todo!()
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
} }
#[cfg(not(feature = "cudnn"))] #[cfg(not(feature = "cudnn"))]
@ -2145,67 +2093,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
let dev = &self.device;
let d1 = d1 as u32;
let d2 = d2 as u32;
let dst_s = dst_s as u32;
let src_s = src_s as u32;
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
(S::U8(s), S::U8(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_u8",
),
(S::U32(s), S::U32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_u32",
),
(S::I64(s), S::I64(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_i64",
),
(S::BF16(s), S::BF16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_bf16",
),
(S::F16(s), S::F16(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f16",
),
(S::F32(s), S::F32(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f32",
),
(S::F64(s), S::F64(d)) => (
*s.slice(src_o..).device_ptr(),
*d.slice(dst_o..).device_ptr(),
"copy2d_f64",
),
_ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
};
let func = dev.get_or_load_func(kname, kernels::FILL)?;
let cfg = LaunchConfig::for_num_elems(d1 * d2);
let params = (src, dst, d1, d2, src_s, dst_s);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape(); let src_shape = src_l.shape();
let dims = src_shape.dims(); let dims = src_shape.dims();

View File

@ -65,13 +65,12 @@ impl std::fmt::Debug for Tensor {
} }
/// Options for Tensor pretty printing /// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions { pub struct PrinterOptions {
pub precision: usize, precision: usize,
pub threshold: usize, threshold: usize,
pub edge_items: usize, edge_items: usize,
pub line_width: usize, line_width: usize,
pub sci_mode: Option<bool>, sci_mode: Option<bool>,
} }
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> = static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@ -90,10 +89,6 @@ impl PrinterOptions {
} }
} }
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
&PRINT_OPTS
}
pub fn set_print_options(options: PrinterOptions) { pub fn set_print_options(options: PrinterOptions) {
*PRINT_OPTS.lock().unwrap() = options *PRINT_OPTS.lock().unwrap() = options
} }
@ -122,26 +117,6 @@ pub fn set_print_options_full() {
} }
} }
pub fn set_line_width(line_width: usize) {
PRINT_OPTS.lock().unwrap().line_width = line_width
}
pub fn set_precision(precision: usize) {
PRINT_OPTS.lock().unwrap().precision = precision
}
pub fn set_edge_items(edge_items: usize) {
PRINT_OPTS.lock().unwrap().edge_items = edge_items
}
pub fn set_threshold(threshold: usize) {
PRINT_OPTS.lock().unwrap().threshold = threshold
}
pub fn set_sci_mode(sci_mode: Option<bool>) {
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}
struct FmtSize { struct FmtSize {
current_size: usize, current_size: usize,
} }

View File

@ -23,15 +23,7 @@ pub enum DType {
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String); pub struct DTypeParseError;
impl std::fmt::Display for DTypeParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "cannot parse '{}' as a dtype", self.0)
}
}
impl std::error::Error for DTypeParseError {}
impl std::str::FromStr for DType { impl std::str::FromStr for DType {
type Err = DTypeParseError; type Err = DTypeParseError;
@ -44,7 +36,7 @@ impl std::str::FromStr for DType {
"f16" => Ok(Self::F16), "f16" => Ok(Self::F16),
"f32" => Ok(Self::F32), "f32" => Ok(Self::F32),
"f64" => Ok(Self::F64), "f64" => Ok(Self::F64),
_ => Err(DTypeParseError(s.to_string())), _ => Err(DTypeParseError),
} }
} }
} }

View File

@ -154,19 +154,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }

View File

@ -166,19 +166,6 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }

View File

@ -70,7 +70,7 @@ impl Layout {
self.shape.is_fortran_contiguous(&self.stride) self.shape.is_fortran_contiguous(&self.stride)
} }
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> { pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.shape().dims(); let dims = self.shape().dims();
if dim >= dims.len() { if dim >= dims.len() {
Err(Error::DimOutOfRange { Err(Error::DimOutOfRange {
@ -99,7 +99,7 @@ impl Layout {
}) })
} }
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> { pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
let rank = self.shape.rank(); let rank = self.shape.rank();
if rank <= dim1 || rank <= dim2 { if rank <= dim1 || rank <= dim2 {
Err(Error::UnexpectedNumberOfDims { Err(Error::UnexpectedNumberOfDims {
@ -120,7 +120,7 @@ impl Layout {
}) })
} }
pub fn permute(&self, idxs: &[usize]) -> Result<Self> { pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation = let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i)); idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation { if !is_permutation {

View File

@ -67,13 +67,12 @@ pub mod shape;
mod storage; mod storage;
mod strided_index; mod strided_index;
mod tensor; mod tensor;
mod tensor_cat;
pub mod test_utils; pub mod test_utils;
pub mod utils; pub mod utils;
mod variable; mod variable;
pub use cpu_backend::CpuStorage; pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation, NdArray}; pub use device::{Device, DeviceLocation};
pub use dtype::{DType, FloatDType, IntDType, WithDType}; pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use indexer::IndexOp; pub use indexer::IndexOp;
@ -130,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
} }
} }
impl<M: Module> Module for Option<&M> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
None => Ok(xs.clone()),
Some(m) => m.forward(xs),
}
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to // A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors. // separate the training and evaluation behaviors.
pub trait ModuleT { pub trait ModuleT {

View File

@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError}; use std::sync::{Arc, RwLock, TryLockError};
/// Simple way to catch lock error without /// Simple way to catch lock error without
/// depending on T /// depending on T
@ -60,8 +59,7 @@ impl From<String> for MetalError {
} }
} }
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>; type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)] #[derive(Clone)]
pub struct MetalDevice { pub struct MetalDevice {
@ -69,7 +67,7 @@ pub struct MetalDevice {
device: metal::Device, device: metal::Device,
/// Single command queue for the entire device. /// Single command queue for the entire device.
command_queue: CommandQueue, command_queue: metal::CommandQueue,
/// One command buffer at a time. /// One command buffer at a time.
/// The scheduler works by allowing multiple /// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
@ -79,16 +77,21 @@ pub struct MetalDevice {
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before /// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there) /// command buffer2 starts (or there are metal bugs there)
command_buffer: Arc<RwLock<CommandBuffer>>, command_buffer: Arc<RwLock<metal::CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current /// Keeps track of the current amount of compute command encoders on the current
/// command buffer /// command buffer
/// Arc, RwLock because of the interior mutability. /// Arc, RwLock because of the interior mutability.
command_buffer_index: Arc<RwLock<usize>>, command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize, compute_per_buffer: usize,
/// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
/// execution order to be linear.
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
/// compute graph.
fence: metal::Fence,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`] /// Heavily used by [`candle_metal_kernels`], both fences need to match
kernels: Arc<Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
/// Simple allocator struct. /// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over. /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@ -100,11 +103,9 @@ pub struct MetalDevice {
/// operation, so that this buffer is not being used by another kernel at the same time. /// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
/// ///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1). /// (strong_count = 1).
buffers: AllocatedBuffers, buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<Buffer>>,
} }
impl std::fmt::Debug for MetalDevice { impl std::fmt::Debug for MetalDevice {
@ -146,8 +147,6 @@ impl MetalDevice {
command_buffer = self.command_queue.new_command_buffer().to_owned(); command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone(); *command_buffer_lock = command_buffer.clone();
*index = 0; *index = 0;
self.drop_unused_buffers()?;
} }
*index += 1; *index += 1;
Ok(command_buffer) Ok(command_buffer)
@ -166,7 +165,6 @@ impl MetalDevice {
command_buffer.commit(); command_buffer.commit();
command_buffer.wait_until_completed(); command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned(); *command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(()) Ok(())
} }
@ -203,80 +201,41 @@ impl MetalDevice {
} }
/// Creates a new buffer from data. /// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// ///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) /// This method will block the computation because of the
/// allocates the buffer and copies over the existing data before returning the MTLBuffer. /// lack of lifetime management through the GPU.
/// Internal comment for technical details.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> { pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger; let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data( let tmp = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void, data.as_ptr() as *const core::ffi::c_void,
size, size,
MTLResourceOptions::StorageModeManaged, metal::MTLResourceOptions::StorageModeManaged,
); );
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; let real = self.allocate_buffer(
let subbuffers = buffers size,
.entry((size, MTLResourceOptions::StorageModeManaged)) metal::MTLResourceOptions::StorageModePrivate,
.or_insert(vec![]); "with_data",
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?; )?;
let command_buffer = self.command_buffer()?; let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros"); command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer( blit.wait_for_fence(&self.fence);
&buffer, blit.set_label("with_data_blit");
metal::NSRange { blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
location: 0, blit.update_fence(&self.fence);
length: buffer.length(),
},
0,
);
blit.end_encoding(); blit.end_encoding();
Ok(buffer)
}
fn find_available_buffer( // This is necessary, for mmaped safetensors
&self, // Because of the unsafe slice cast we're doing.
size: NSUInteger, // The slice might not live long enough for metal
option: MTLResourceOptions, // To actually fill the GPU buffer.
buffers: &RwLockWriteGuard<BufferMap>, // Putting this wait forces the GPU buffer to be filled
) -> Option<Arc<Buffer>> { // with the actual data allowing the CPU storage todo
let mut best_buffer: Option<&Arc<Buffer>> = None; // deallocate properly.
let mut best_buffer_size: NSUInteger = NSUInteger::MAX; self.wait_until_completed()?;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() { Ok(real)
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
return best_buffer.map(|b| b.clone());
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
} }
/// The critical allocator algorithm /// The critical allocator algorithm
@ -287,18 +246,24 @@ impl MetalDevice {
_name: &str, _name: &str,
) -> Result<Arc<Buffer>> { ) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]); let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
for sub in &mut *subbuffers {
if Arc::strong_count(sub) == 1 {
return Ok(sub.clone());
}
}
let new_buffer = self.device.new_buffer(size as NSUInteger, option); let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer); let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone()); subbuffers.push(new_buffer.clone());
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(new_buffer) Ok(new_buffer)
} }
@ -323,8 +288,6 @@ pub struct MetalStorage {
buffer: Arc<metal::Buffer>, buffer: Arc<metal::Buffer>,
/// a reference to the device owning this buffer /// a reference to the device owning this buffer
device: MetalDevice, device: MetalDevice,
/// The count of allocated elements in the buffer
count: usize,
/// The dtype is kept since buffers are untyped. /// The dtype is kept since buffers are untyped.
dtype: DType, dtype: DType,
} }
@ -345,14 +308,35 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
let length = self.buffer.length() as usize;
let size = self.dtype.size_in_bytes();
if length % size != 0 {
crate::bail!(
"The Metal buffer length is not aligned with dtype {:?}",
self.dtype
);
}
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence);
blit.end_encoding();
}
self.device.wait_until_completed()?;
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
} }
} }
@ -369,7 +353,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "affine_f32", DType::F32 => "affine_f32",
DType::F16 => "affine_f16", DType::F16 => "affine_f16",
DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_affine( candle_metal_kernels::call_affine(
@ -388,7 +371,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "affine_f32_strided", DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided", DType::F16 => "affine_f16_strided",
DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_affine_strided( candle_metal_kernels::call_affine_strided(
@ -406,7 +388,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
Ok(Self::new(buffer, device.clone(), el, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> { fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
@ -422,7 +404,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "powf_f32", DType::F32 => "powf_f32",
DType::F16 => "powf_f16", DType::F16 => "powf_f16",
DType::BF16 => "powf_bf16",
dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"), dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_powf( candle_metal_kernels::call_powf(
@ -440,7 +421,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "powf_f32_strided", DType::F32 => "powf_f32_strided",
DType::F16 => "powf_f16_strided", DType::F16 => "powf_f16_strided",
DType::BF16 => "powf_bf16_strided",
dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"), dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_powf_strided( candle_metal_kernels::call_powf_strided(
@ -457,7 +437,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
Ok(Self::new(buffer, device.clone(), el, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
@ -473,7 +453,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "elu_f32", DType::F32 => "elu_f32",
DType::F16 => "elu_f16", DType::F16 => "elu_f16",
DType::BF16 => "elu_bf16",
dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"), dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_elu( candle_metal_kernels::call_elu(
@ -491,7 +470,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "elu_f32_strided", DType::F32 => "elu_f32_strided",
DType::F16 => "elu_f16_strided", DType::F16 => "elu_f16_strided",
DType::BF16 => "elu_bf16_strided",
dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"), dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_elu_strided( candle_metal_kernels::call_elu_strided(
@ -508,7 +486,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
Ok(Self::new(buffer, device.clone(), el, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
@ -586,7 +564,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(Self::new(buffer, device, dst_el, dtype)) Ok(Self::new(buffer, device, dtype))
} }
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> { fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
@ -612,26 +590,14 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::F32, DType::F16) => "cast_f32_f16", (DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32", (DType::F16, DType::F32) => "cast_f16_f32",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::BF16, DType::U8) => "cast_bf16_u8", (DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32", (DType::BF16, DType::F32) => "cast_bf16_f32",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
} }
@ -678,7 +644,7 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
command_buffer.set_label("to_dtype"); command_buffer.set_label("to_dtype");
Ok(Self::new(buffer, device.clone(), el_count, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@ -703,14 +669,12 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => contiguous::gelu::FLOAT, ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT, ("uerf", DType::F32) => contiguous::erf::FLOAT,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("uabs", DType::F32) => contiguous::abs::FLOAT, ("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT, ("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT, ("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF, ("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF, ("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF,
@ -721,14 +685,12 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => contiguous::gelu::HALF, ("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF, ("uerf", DType::F16) => contiguous::erf::HALF,
("usilu", DType::F16) => contiguous::silu::HALF,
("uabs", DType::F16) => contiguous::abs::HALF, ("uabs", DType::F16) => contiguous::abs::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF, ("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF, ("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF, ("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF,
("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
} }
@ -756,13 +718,10 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => strided::gelu::FLOAT, ("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT, ("uerf", DType::F32) => strided::erf::FLOAT,
("usilu", DType::F32) => strided::silu::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT, ("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT, ("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF, ("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF, ("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF, ("usqr", DType::F16) => strided::sqr::HALF,
@ -773,13 +732,10 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => strided::gelu::HALF, ("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF, ("uerf", DType::F16) => strided::erf::HALF,
("usilu", DType::F16) => strided::silu::HALF,
("uabs", DType::F16) => strided::abs::HALF, ("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF, ("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF, ("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF, ("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented") crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
} }
@ -798,7 +754,7 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
Ok(Self::new(buffer, device.clone(), el_count, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn binary_impl<B: BinaryOpT>( fn binary_impl<B: BinaryOpT>(
@ -834,7 +790,6 @@ impl BackendStorage for MetalStorage {
} }
let name = match (self.dtype, t.dtype()) { let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F32) => "where_u8_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::I64) => "where_u8_i64",
(DType::U8, DType::U32) => "where_u8_u32", (DType::U8, DType::U32) => "where_u8_u32",
@ -853,13 +808,13 @@ impl BackendStorage for MetalStorage {
layout.start_offset() * self.dtype.size_in_bytes(), layout.start_offset() * self.dtype.size_in_bytes(),
), ),
&t.buffer, &t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer, &f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(Self::new(buffer, device, el, dtype)) Ok(Self::new(buffer, device, dtype))
} }
fn conv1d( fn conv1d(
@ -904,7 +859,6 @@ impl BackendStorage for MetalStorage {
let col = Self { let col = Self {
buffer: dst, buffer: dst,
device, device,
count: dst_el,
dtype: self.dtype, dtype: self.dtype,
}; };
let l_out = params.l_out(); let l_out = params.l_out();
@ -989,7 +943,6 @@ impl BackendStorage for MetalStorage {
let col = Self { let col = Self {
buffer: dst, buffer: dst,
device, device,
count: dst_el,
dtype: self.dtype, dtype: self.dtype,
}; };
let h_out = params.out_h(); let h_out = params.out_h();
@ -1075,7 +1028,7 @@ impl BackendStorage for MetalStorage {
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) Ok(Self::new(buffer, self.device.clone(), self.dtype))
} }
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> { fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
@ -1109,7 +1062,7 @@ impl BackendStorage for MetalStorage {
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(Self::new(buffer, device.clone(), dst_el, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn scatter_add( fn scatter_add(
@ -1132,15 +1085,7 @@ impl BackendStorage for MetalStorage {
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
}; };
let name = match (ids.dtype, self.dtype) { let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
(DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
(DType::I64, DType::F32) => "sa_i64_f32",
(DType::I64, DType::F16) => "sa_i64_f16",
(DType::I64, DType::BF16) => "sa_i64_bf16",
_ => Err(MetalError::UnexpectedDType { _ => Err(MetalError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i64", msg: "scatter-add ids should be u8/u32/i64",
expected: DType::U32, expected: DType::U32,
@ -1182,12 +1127,8 @@ impl BackendStorage for MetalStorage {
let device = self.device(); let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) { let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
} }
@ -1206,7 +1147,7 @@ impl BackendStorage for MetalStorage {
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(Self::new(buffer, device.clone(), dst_el, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
fn index_add( fn index_add(
@ -1288,73 +1229,7 @@ impl BackendStorage for MetalStorage {
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
Ok(Self::new( Ok(Self::new(buffer, self.device.clone(), self.dtype()))
buffer,
self.device.clone(),
b * m * n,
self.dtype(),
))
}
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
if self.dtype() != dst.dtype() {
crate::bail!(
"copy2d with inconsistent dtypes {:?} {:?}",
self.dtype(),
dst.dtype()
)
}
let command_buffer = self.device.command_buffer()?;
if src_s == d2 && dst_s == d2 {
command_buffer.set_label("copy2d_contiguous");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("copy2d_contiguous");
let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let el_count = d1 * d2;
if el_count == 0 {
return Ok(());
}
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::copy2d::FLOAT,
DType::F16 => candle_metal_kernels::copy2d::HALF,
DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
DType::I64 => candle_metal_kernels::copy2d::I64,
DType::U32 => candle_metal_kernels::copy2d::U32,
DType::U8 => candle_metal_kernels::copy2d::U8,
dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_copy2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
&self.buffer,
&dst.buffer,
d1,
d2,
src_s,
dst_s,
src_o * self.dtype.size_in_bytes(),
dst_o * self.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy2d");
}
Ok(())
} }
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
@ -1403,11 +1278,10 @@ impl BackendStorage for MetalStorage {
} }
impl MetalStorage { impl MetalStorage {
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self { pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
Self { Self {
buffer, buffer,
device, device,
count,
dtype, dtype,
} }
} }
@ -1444,7 +1318,6 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype), ("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@ -1455,18 +1328,6 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
("add", DType::I64) => (contiguous::add::I64, self.dtype), ("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@ -1477,7 +1338,6 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
("add", DType::U32) => (contiguous::add::U32, self.dtype), ("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype), ("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype), ("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@ -1488,7 +1348,6 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8), ("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8), ("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8), ("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
("add", DType::U8) => (contiguous::add::U8, self.dtype), ("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype), ("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype), ("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@ -1499,7 +1358,6 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8), ("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8), ("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8), ("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
} }
@ -1533,7 +1391,6 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype), ("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype), ("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype), ("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@ -1546,20 +1403,6 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8),
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
("badd", DType::I64) => (strided::add::I64, self.dtype), ("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@ -1572,7 +1415,6 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8), ("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8),
("badd", DType::U32) => (strided::add::U32, self.dtype), ("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype), ("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype), ("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@ -1585,7 +1427,6 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8), ("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8), ("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8), ("gt", DType::U32) => (strided::gt::U32, DType::U8),
("badd", DType::U8) => (strided::add::U8, self.dtype), ("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype), ("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype), ("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@ -1598,7 +1439,6 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8), ("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8), ("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8), ("gt", DType::U8) => (strided::gt::U8, DType::U8),
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented") crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
} }
@ -1622,23 +1462,7 @@ impl MetalStorage {
(buffer, dtype) (buffer, dtype)
}; };
command_buffer.set_label("binary"); command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), el_count, dtype)) Ok(Self::new(buffer, device.clone(), dtype))
}
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
let buffer = self.device.new_buffer_managed(size)?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
blit.end_encoding();
}
self.device.wait_until_completed()?;
Ok(read_to_vec(&buffer, self.count))
} }
} }
@ -1652,29 +1476,29 @@ impl BackendDevice for MetalDevice {
command_buffer.enqueue(); command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer)); let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0)); let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new()); let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone()));
let buffers = Arc::new(RwLock::new(HashMap::new())); let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?, Ok(val) => val.parse()?,
_ => 50, _ => 20,
}; };
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self { Ok(Self {
device, device,
fence,
command_queue, command_queue,
command_buffer, command_buffer,
command_buffer_index, command_buffer_index,
compute_per_buffer, compute_per_buffer,
buffers, buffers,
kernels, kernels,
seed,
}) })
} }
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("Metal set_seed not implemented")
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal { crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize, gpu_id: self.registry_id() as usize,
@ -1686,14 +1510,22 @@ impl BackendDevice for MetalDevice {
} }
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let size = shape.elem_count() * dtype.size_in_bytes(); let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
let buffer = self.allocate_zeros(size)?; let command_buffer = self.command_buffer()?;
Ok(MetalStorage::new( command_buffer.set_label("zeros");
buffer, let blit = command_buffer.new_blit_command_encoder();
self.clone(), blit.wait_for_fence(&self.fence);
shape.elem_count(), blit.fill_buffer(
dtype, &buffer,
)) metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.update_fence(&self.fence);
blit.end_encoding();
Ok(MetalStorage::new(buffer, self.clone(), dtype))
} }
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@ -1703,57 +1535,28 @@ impl BackendDevice for MetalDevice {
} }
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let (count, buffer) = match storage { let buffer = match storage {
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
}; }?;
Ok(Self::Storage::new( Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
buffer?,
self.clone(),
count,
storage.dtype(),
))
} }
fn rand_uniform( fn rand_uniform(
&self, &self,
shape: &Shape, shape: &Shape,
dtype: DType, dtype: DType,
min: f64, mean: f64,
max: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
let name = match dtype { // TODO is there a better way ?
DType::F32 => "rand_uniform_f32", let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
DType::F16 => "rand_uniform_f16", self.storage_from_cpu_storage(&cpu_storage)
DType::BF16 => "rand_uniform_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_uniform(
&self.device,
&command_buffer,
&self.kernels,
name,
min as f32,
max as f32,
shape.elem_count(),
&self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
} }
fn rand_normal( fn rand_normal(
@ -1763,53 +1566,10 @@ impl BackendDevice for MetalDevice {
mean: f64, mean: f64,
stddev: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
let name = match dtype { // TODO is there a better way ?
DType::F32 => "rand_normal_f32", let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
DType::F16 => "rand_normal_f16", self.storage_from_cpu_storage(&cpu_storage)
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
mean as f32,
stddev as f32,
shape.elem_count(),
&self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
} }
fn set_seed(&self, seed: u64) -> Result<()> {
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
Ok(())
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
} }
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

View File

@ -333,16 +333,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
} }
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline] #[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) { for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -365,28 +355,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
} }
} }
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op { macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => { ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline] #[inline]

View File

@ -61,7 +61,6 @@ pub enum UnaryOp {
GeluErf, GeluErf,
Erf, Erf,
Relu, Relu,
Silu,
Tanh, Tanh,
Floor, Floor,
Ceil, Ceil,
@ -132,10 +131,7 @@ pub enum Op {
stride: (usize, usize), stride: (usize, usize),
}, },
UpsampleNearest1D { UpsampleNearest1D(Tensor),
arg: Tensor,
target_size: usize,
},
UpsampleNearest2D { UpsampleNearest2D {
arg: Tensor, arg: Tensor,
target_h: usize, target_h: usize,
@ -394,7 +390,6 @@ pub(crate) struct Gelu;
pub(crate) struct GeluErf; pub(crate) struct GeluErf;
pub(crate) struct Erf; pub(crate) struct Erf;
pub(crate) struct Relu; pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh; pub(crate) struct Tanh;
pub(crate) struct Floor; pub(crate) struct Floor;
pub(crate) struct Ceil; pub(crate) struct Ceil;
@ -729,77 +724,6 @@ impl UnaryOpT for Erf {
} }
} }
/// Silu operation
impl UnaryOpT for Silu {
const NAME: &'static str = "silu";
const V: Self = Silu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v / (bf16::ONE + (-v).exp())
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v / (f16::ONE + (-v).exp())
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_silu(xs, ys)
}
}
impl UnaryOpT for Abs { impl UnaryOpT for Abs {
const NAME: &'static str = "abs"; const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs"; const KERNEL: &'static str = "uabs";

View File

@ -42,7 +42,7 @@ pub enum OpCode {
Stop = b'.', Stop = b'.',
NewObj = 0x81, NewObj = 0x81,
EmptyList = b']', EmptyList = b']',
BinFloat = b'G', BinFloat = b'g',
Append = b'a', Append = b'a',
Appends = b'e', Appends = b'e',
} }
@ -217,13 +217,6 @@ impl Object {
let args = args.remove(1); let args = args.remove(1);
(callable, args) (callable, args)
} }
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args), _ => (callable, args),
}; };
match callable { match callable {
@ -234,11 +227,13 @@ impl Object {
_ => return Ok(None), _ => return Ok(None),
}; };
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?; let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo { Ok(Some(TensorInfo {
name, name,
dtype, dtype,
layout, layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path), path: path.to_string_lossy().into_owned(),
storage_size, storage_size,
})) }))
} }
@ -350,10 +345,8 @@ impl Stack {
module_name, module_name,
class_name, class_name,
} => { } => {
if module_name == "collections" if module_name == "collections" && class_name == "OrderedDict" {
&& (class_name == "OrderedDict" || class_name == "defaultdict") // TODO: have a separate ordered dict.
{
// TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![])) Some(Object::Dict(vec![]))
} else { } else {
None None
@ -462,10 +455,7 @@ impl Stack {
self.push(Object::Int(arg)) self.push(Object::Int(arg))
} }
OpCode::BinFloat => { OpCode::BinFloat => {
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian. let arg = r.read_f64::<LittleEndian>()?;
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
let arg = r.read_f64::<byteorder::BigEndian>()?;
self.push(Object::Float(arg)) self.push(Object::Float(arg))
} }
OpCode::BinUnicode => { OpCode::BinUnicode => {
@ -637,16 +627,9 @@ pub struct TensorInfo {
pub storage_size: usize, pub storage_size: usize,
} }
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P, file: P,
verbose: bool, verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> { ) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?; let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file); let zip_reader = std::io::BufReader::new(file);
@ -668,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
stack.read_loop(&mut reader)?; stack.read_loop(&mut reader)?;
let obj = stack.finalize()?; let obj = stack.finalize()?;
if VERBOSE || verbose { if VERBOSE || verbose {
println!("{obj:#?}"); println!("{obj:?}");
} }
let obj = match obj { let obj = match obj {
Object::Build { callable, args } => match *callable { Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable { Object::Reduce { callable, args: _ } => match *callable {
@ -684,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
}, },
obj => obj, obj => obj,
}; };
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj { if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() { for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) { match value.into_tensor_info(name, &dir_name) {
@ -724,8 +688,8 @@ pub struct PthTensors {
} }
impl PthTensors { impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> { pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?; let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos let tensor_infos = tensor_infos
.into_iter() .into_iter()
.map(|ti| (ti.name.to_string(), ti)) .map(|ti| (ti.name.to_string(), ti))
@ -739,7 +703,6 @@ impl PthTensors {
} }
pub fn get(&self, name: &str) -> Result<Option<Tensor>> { pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) { let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None), None => return Ok(None),
Some(tensor_info) => tensor_info, Some(tensor_info) => tensor_info,
@ -748,56 +711,27 @@ impl PthTensors {
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?); let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?; let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?; let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic // Reading the data is a bit tricky as it can be strided, use an offset, etc.
// case and when the tensor is fortran contiguous. // For now only support the basic case.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous { if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
crate::bail!( crate::bail!(
"cannot retrieve non-contiguous tensors {:?}", "cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout tensor_info.layout
) )
} }
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader( let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(), tensor_info.layout.shape().clone(),
tensor_info.dtype, tensor_info.dtype,
&mut reader, &mut reader,
)?; )?;
Ok(Some(tensor))
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor))
}
} }
} }
/// Read all the tensors from a PyTorch pth file with a given key. /// Read all the tensors from a PyTorch pth file.
/// pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
/// # Arguments let pth = PthTensors::new(path)?;
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys(); let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len()); let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names { for name in tensor_names {
@ -807,11 +741,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
} }
Ok(tensors) Ok(tensors)
} }
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}

View File

@ -1,343 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, DeviceSlice};
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
fn dequantize(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q5_1 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mut_mal_vec(
data: &CudaSlice<u8>,
y: &cudarc::driver::CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
shared_mem_bytes: 0,
};
let params = (data, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let slice =
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
out.copy_from_slice(slice)
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
self.device
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
}
pub fn fwd(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
}
}
}
impl QCudaStorage {
fn dequantize_matmul_vec(
&self,
self_shape: &crate::Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
vec![1, nrows]
};
Ok((out, out_shape.into()))
}
fn dequantize_matmul(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage;
let (n, k) = self_shape.dims2()?;
let (b, m, k2) = match layout.shape().dims() {
&[b, m, k2] => (b, m, k2),
&[m, k2] => (1, m, k2),
s => crate::bail!("unexpected shape for input {s:?}"),
};
if k2 != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
Ok((out, out_shape.into()))
}
}
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
slice.to_vec()
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
) -> Result<super::QStorage> {
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let data = device.htod_sync_copy(data).w()?;
Ok(QStorage::Cuda(QCudaStorage {
data,
device: device.clone(),
dtype: T::DTYPE,
}))
}

View File

@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};
pub struct QCudaStorage {
dtype: GgmlDType,
device: CudaDevice,
}
impl QCudaStorage {
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &CudaStorage,
_layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
Err(Error::NotCompiledWithCudaSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &CudaDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &MetalDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -1,7 +1,7 @@
//! Support for the GGML file format. //! Support for the GGML file format.
use super::{k_quants, GgmlDType, QStorage}; use super::{k_quants, GgmlDType};
use crate::{Device, Result}; use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -121,17 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8], raw_data: &[u8],
size_in_bytes: usize, size_in_bytes: usize,
dims: Vec<usize>, dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr(); let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>(); let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device { super::QTensor::new(data.to_vec(), dims)
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
};
super::QTensor::new(data, dims)
} }
/// Creates a [Tensor] from a raw GGML tensor. /// Creates a [Tensor] from a raw GGML tensor.
@ -139,50 +133,29 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType, ggml_dtype: GgmlDType,
raw_data: &[u8], raw_data: &[u8],
dims: Vec<usize>, dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>(); let tensor_elems = dims.iter().product::<usize>();
let block_size = ggml_dtype.block_size(); let blck_size = ggml_dtype.blck_size();
if tensor_elems % block_size != 0 { if tensor_elems % blck_size != 0 {
crate::bail!( crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}" "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
) )
} }
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size(); let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype { match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device), GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device), GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => { GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device) GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
} GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => { GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device) GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
} GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => { GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device) GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
} GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => { GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
} }
} }
@ -190,7 +163,6 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>( fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R, reader: &mut R,
magic: VersionedMagic, magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> { ) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?; let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?; let name_len = reader.read_u32::<LittleEndian>()?;
@ -211,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
} }
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>(); let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>(); let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size(); let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
// TODO: Mmap version to avoid copying the data around? // TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) { match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)), Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"), Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
} }
@ -226,14 +198,10 @@ pub struct Content {
pub hparams: HParams, pub hparams: HParams,
pub vocab: Vocab, pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>, pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
} }
impl Content { impl Content {
pub fn read<R: std::io::Seek + std::io::Read>( pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?; let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?; reader.seek(std::io::SeekFrom::Start(0))?;
@ -243,16 +211,14 @@ impl Content {
let mut tensors = HashMap::new(); let mut tensors = HashMap::new();
while reader.stream_position()? != last_position { while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?; let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.insert(name, tensor); tensors.insert(name, tensor);
} }
let device = device.clone();
Ok(Self { Ok(Self {
magic, magic,
hparams, hparams,
vocab, vocab,
tensors, tensors,
device,
}) })
} }

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor}; use super::{GgmlDType, QTensor};
use crate::{Device, Result}; use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -59,25 +59,19 @@ impl TensorInfo {
&self, &self,
reader: &mut R, reader: &mut R,
tensor_data_offset: u64, tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count(); let tensor_elems = self.shape.elem_count();
let block_size = self.ggml_dtype.block_size(); let blck_size = self.ggml_dtype.blck_size();
if tensor_elems % block_size != 0 { if tensor_elems % blck_size != 0 {
crate::bail!( crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}" "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
) )
} }
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size(); let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml( super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
} }
} }
@ -466,13 +460,12 @@ impl Content {
&self, &self,
reader: &mut R, reader: &mut R,
name: &str, name: &str,
device: &Device,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) { let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info, Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"), None => crate::bail!("cannot find tensor info for {name}"),
}; };
tensor_info.read(reader, self.tensor_data_offset, device) tensor_info.read(reader, self.tensor_data_offset)
} }
} }
@ -524,9 +517,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}" "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
) )
} }
let data = tensor.data()?; let data_ptr = tensor.as_ptr();
let size_in_bytes = data.len(); let size_in_bytes = tensor.storage_size_in_bytes();
w.write_all(&data)?; let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
w.write_all(data)?;
let padding = 31 - (31 + size_in_bytes) % 32; let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?; w.write_all(&vec![0u8; padding])?;
} }

View File

@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
let d2 = d * sc as f32; let d2 = d * sc as f32;
let m2 = min * m as f32; let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) { for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 }; let to_add = if qh & u1 != 0 { 16 } else { 1 };
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1; y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
ys_index += 1; ys_index += 1;
} }
for (ql, qh) in ql.iter().zip(qh) { for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 }; let to_add = if qh & u2 != 0 { 16 } else { 1 };
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2; y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
ys_index += 1; ys_index += 1;
} }
is += 2; is += 2;

View File

@ -1,222 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage;
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use metal::Buffer;
use std::sync::Arc;
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
buffer: Arc<Buffer>,
}
impl QMetalStorage {
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = device.allocate_zeros(size)?;
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
f32::to_float(&vec, &mut out)?;
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
let buffer = self.device.new_buffer_with_data(&out)?;
Ok(MetalStorage::new(
buffer,
self.device.clone(),
elem_count,
DType::F32,
))
}
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
// Quantization only happens on CPU for now.
let src = src.to_cpu::<f32>()?;
let elem_count = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
qcpu_storage.quantize(&src)?;
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
self.buffer = buffer;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
pub fn fwd(
&self,
self_shape: &Shape,
storage: &MetalStorage,
layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice,
data: &[T],
) -> Result<QStorage> {
let buffer = device.new_buffer_with_data(data)?;
let device = device.clone();
Ok(QStorage::Metal(QMetalStorage {
dtype: T::DTYPE,
device,
buffer,
}))
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
}

View File

@ -1,134 +1,23 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; use crate::{Device, Result, Shape, Tensor};
use k_quants::*;
use std::borrow::Cow;
#[cfg(target_feature = "avx")] #[cfg(target_feature = "avx")]
pub mod avx; pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file; pub mod ggml_file;
pub mod gguf_file; pub mod gguf_file;
pub mod k_quants; pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")] #[cfg(target_feature = "simd128")]
pub mod simd128; pub mod simd128;
pub mod utils; pub mod utils;
use half::f16;
pub use k_quants::GgmlType; pub use k_quants::GgmlType;
pub struct QTensor { pub struct QTensor {
storage: QStorage, data: Box<dyn QuantizedType>,
shape: Shape, shape: Shape,
} }
impl Device {
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
match self {
Device::Cpu => {
let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage))
}
Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage))
}
Device::Cuda(cuda) => {
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
Ok(QStorage::Cuda(storage))
}
}
}
}
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage),
Cuda(cuda::QCudaStorage),
}
impl QStorage {
fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(),
QStorage::Cuda(storage) => storage.dtype().block_size(),
}
}
fn dtype(&self) -> GgmlDType {
match self {
QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(),
QStorage::Cuda(storage) => storage.dtype(),
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
}
}
fn quantize(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
}
Ok(())
}
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
}
}
fn data(&self) -> Result<Cow<[u8]>> {
match self {
QStorage::Cpu(storage) => {
let data_ptr = storage.as_ptr();
let size_in_bytes = storage.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
QStorage::Metal(_) | QStorage::Cuda(_) => {
crate::bail!("not implemented");
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType { pub enum GgmlDType {
F32, F32,
@ -188,25 +77,6 @@ impl GgmlDType {
} }
} }
/// The block dtype
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
}
}
/// The type size for blocks in bytes. /// The type size for blocks in bytes.
pub fn type_size(&self) -> usize { pub fn type_size(&self) -> usize {
use k_quants::*; use k_quants::*;
@ -230,7 +100,7 @@ impl GgmlDType {
} }
/// The block size, i.e. the number of elements stored in each block. /// The block size, i.e. the number of elements stored in each block.
pub fn block_size(&self) -> usize { pub fn blck_size(&self) -> usize {
match self { match self {
Self::F32 => 1, Self::F32 => 1,
Self::F16 => 1, Self::F16 => 1,
@ -249,13 +119,9 @@ impl GgmlDType {
pub trait QuantizedType: Send + Sync { pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType; fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>; fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>; fn to_float(&self, ys: &mut [f32]) -> Result<()>;
fn storage_size_in_bytes(&self) -> usize; fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8; fn as_ptr(&self) -> *const u8;
fn block_size(&self) -> usize;
#[allow(clippy::wrong_self_convention)]
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
fn size(&self) -> usize;
} }
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> { impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@ -263,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst) k_quants::matmul(mkn, lhs, self.as_slice(), dst)
} }
fn size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
T::from_float(xs, self)
}
fn dtype(&self) -> GgmlDType { fn dtype(&self) -> GgmlDType {
T::DTYPE T::DTYPE
} }
fn block_size(&self) -> usize { fn to_float(&self, ys: &mut [f32]) -> Result<()> {
T::BLCK_SIZE T::to_float(self.as_slice(), ys)
}
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
let mut ys = vec![0.0f32; elem_count];
T::to_float(self.as_slice(), &mut ys)?;
Ok(CpuStorage::F32(ys))
} }
fn storage_size_in_bytes(&self) -> usize { fn storage_size_in_bytes(&self) -> usize {
@ -300,53 +152,56 @@ impl std::fmt::Debug for QTensor {
} }
} }
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> { fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
let dims = shape.dims(); let dims = shape.dims();
if dims.is_empty() { if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}") crate::bail!("scalar tensor cannot be quantized {shape:?}")
} }
if dims[dims.len() - 1] % block_size != 0 { if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
crate::bail!( crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}", "quantized tensor must have their last dim divisible by block size {shape:?} {}",
block_size T::BLCK_SIZE
) )
} }
Ok(()) Ok(())
} }
impl QTensor { impl QTensor {
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> { pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
) -> Result<Self> {
let shape = shape.into(); let shape = shape.into();
check_shape(&shape, storage.block_size())?; check_shape::<T>(&shape)?;
Ok(Self { storage, shape }) Ok(Self {
data: Box::new(data),
shape,
})
} }
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> { pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape(); let shape = src.shape();
let block_size = dtype.block_size(); check_shape::<T>(shape)?;
check_shape(shape, block_size)?; let src = src
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; .to_dtype(crate::DType::F32)?
let elem_count = shape.elem_count(); .flatten_all()?
if elem_count % block_size != 0 { .to_vec1::<f32>()?;
if src.len() % T::BLCK_SIZE != 0 {
crate::bail!( crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}", "tensor size ({shape:?}) is not divisible by block size {}",
block_size T::BLCK_SIZE
) )
} }
let mut storage = src.device().qzeros(elem_count, dtype)?; let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
storage.quantize(&src.storage())?; T::from_float(&src, &mut data)?;
Ok(Self { Ok(Self {
storage, data: Box::new(data),
shape: shape.clone(), shape: shape.clone(),
}) })
} }
pub fn dtype(&self) -> GgmlDType { pub fn dtype(&self) -> GgmlDType {
self.storage.dtype() self.data.dtype()
}
pub fn device(&self) -> Device {
self.storage.device()
} }
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
@ -358,19 +213,21 @@ impl QTensor {
} }
pub fn dequantize(&self, device: &Device) -> Result<Tensor> { pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?; let mut f32_data = vec![0f32; self.shape.elem_count()];
let none = crate::op::BackpropOp::none(); self.data.to_float(&mut f32_data)?;
let is_variable = false; Tensor::from_vec(f32_data, &self.shape, device)
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable) }
.to_device(device)
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
self.data.matmul_t(mkn, lhs, dst)
} }
pub fn storage_size_in_bytes(&self) -> usize { pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes() self.data.storage_size_in_bytes()
} }
pub fn data(&self) -> Result<Cow<'_, [u8]>> { pub fn as_ptr(&self) -> *const u8 {
self.storage.data() self.data.as_ptr()
} }
} }
@ -398,7 +255,7 @@ impl QMatMul {
_ => DEQUANTIZE_ALL.with(|b| *b), _ => DEQUANTIZE_ALL.with(|b| *b),
}; };
let t = if dequantize { let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?; let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor) Self::Tensor(tensor)
} else { } else {
Self::QTensor(qtensor) Self::QTensor(qtensor)
@ -437,41 +294,17 @@ impl crate::CustomOp1 for QTensor {
} }
dst_shape.push(n); dst_shape.push(n);
let dst_shape = Shape::from(dst_shape); let dst_shape = Shape::from(dst_shape);
#[allow(clippy::infallible_destructuring_match)] let storage = storage.as_slice::<f32>()?;
let self_storage = match &self.storage { let storage =
QStorage::Cpu(storage) => storage, &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()]; let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?; self.matmul_t(
(dst_shape.elem_count() / n, k, n),
storage,
&mut dst_storage,
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
} }
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Metal(metal) => metal,
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Cuda(cuda) => cuda,
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
} }
impl crate::Module for QMatMul { impl crate::Module for QMatMul {

View File

@ -352,10 +352,6 @@ impl Storage {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?; let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s)) Ok(Self::Cuda(s))
} }
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),
@ -701,32 +697,4 @@ impl Storage {
.bt()), .bt()),
} }
} }
#[allow(clippy::too_many_arguments)]
pub(crate) fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
(Self::Cuda(src), Self::Cuda(dst)) => {
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
}
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "copy2d",
}
.bt()),
}
}
} }

View File

@ -508,7 +508,6 @@ impl Tensor {
unary_op!(gelu_erf, GeluErf); unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf); unary_op!(erf, Erf);
unary_op!(relu, Relu); unary_op!(relu, Relu);
unary_op!(silu, Silu);
unary_op!(ceil, Ceil); unary_op!(ceil, Ceil);
unary_op!(floor, Floor); unary_op!(floor, Floor);
unary_op!(round, Round); unary_op!(round, Round);
@ -666,7 +665,7 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false)) Ok(from_storage(storage, self.shape(), op, false))
} }
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
if dim >= self.dims().len() { if dim >= self.dims().len() {
Err(Error::DimOutOfRange { Err(Error::DimOutOfRange {
shape: self.shape().clone(), shape: self.shape().clone(),
@ -805,35 +804,6 @@ impl Tensor {
} }
} }
/// Roll the tensor input along the given dimension.
/// Elements that are shifted beyond the last position are re-introduced at the first position.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(-1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
where
D: Dim + Clone,
{
let dim = dim.to_index(self.shape(), "roll")?;
let dim_size = self.dim(dim)?;
let shift = shift.rem_euclid(dim_size as i32) as usize;
if shift == 0 {
Ok(self.clone())
} else {
let a = self.narrow(dim, 0, dim_size - shift)?;
let b = self.narrow(dim, dim_size - shift, shift)?;
Tensor::cat(&[&b, &a], dim)
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the /// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions. /// input dimensions.
/// ///
@ -1015,7 +985,7 @@ impl Tensor {
/// tensor also has three dimensions, `(batch, channels, target_size)`. /// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> { pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?; let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size }); let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self let storage = self
.storage() .storage()
.upsample_nearest1d(self.layout(), target_size)?; .upsample_nearest1d(self.layout(), target_size)?;
@ -1883,9 +1853,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor. /// this new node. The storage of this tensor is shared with the initial tensor.
/// ///
/// If the tensor is already detached from the computation graph, the same tensor is returned. /// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Tensor { pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable { if self.op.is_none() && !self.is_variable {
self.clone() Ok(self.clone())
} else { } else {
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
@ -1896,7 +1866,7 @@ impl Tensor {
dtype: self.dtype, dtype: self.dtype,
device: self.device.clone(), device: self.device.clone(),
}; };
Tensor(Arc::new(tensor_)) Ok(Tensor(Arc::new(tensor_)))
} }
} }
@ -2149,6 +2119,152 @@ impl Tensor {
Self::cat(&args, dim) Self::cat(&args, dim)
} }
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {
// TODO: Avoid these transpositions and have an implementation that works
// for dim != 0...
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[0] = 0;
let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == 0 {
cat_dims[0] += v2;
}
if dim_idx != 0 && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(from_storage(storage, shape, op, false))
}
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after. /// input tensor values and `right` elements after.
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> { pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
@ -2462,21 +2578,11 @@ impl Tensor {
} }
/// Returns log(sum(exp(tensor), dim)). /// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?; let exp = self.exp()?;
let sum = exp.sum(sum_dims)?; let sum = exp.sum(sum_dims)?;
sum.log() sum.log()
} }
/// Pointwise pow operation.
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.mul(&self.log()?)?.exp()
}
/// Broadcasting version of `pow`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
} }
macro_rules! bin_trait { macro_rules! bin_trait {

View File

@ -1,240 +0,0 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle_core::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let dim = dim.to_index(arg0.shape(), "cat")?;
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
}
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[0] = 0;
let mut offsets = vec![0usize];
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == 0 {
cat_dims[0] += v2;
}
if dim_idx != 0 && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
if args.is_empty() {
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
}
let arg0 = args[0].as_ref();
if args.len() == 1 {
return Ok(arg0.clone());
}
let rank = arg0.rank();
let device = arg0.device();
let dtype = arg0.dtype();
let first_dims = arg0.shape().dims();
let mut cat_dims = first_dims.to_vec();
cat_dims[dim] = 0;
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg.dtype() != dtype {
Err(Error::DTypeMismatchBinaryOp {
lhs: dtype,
rhs: arg.dtype(),
op: "cat",
}
.bt())?
}
if arg.device().location() != device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: device.location(),
rhs: arg.device().location(),
op: "cat",
}
.bt())?
}
if rank != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: rank,
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx == dim {
cat_dims[dim] += v2;
}
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
let cat_target_dim_len = cat_dims[dim];
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
let mut storage = device.zeros(&shape, dtype)?;
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();
let arg_dims = arg.shape().dims();
let d1: usize = arg_dims.iter().take(dim).product();
let d2 = block_size * arg_dims[dim];
let dst_s = block_size * cat_target_dim_len;
let src_o = arg.layout().start_offset();
arg.storage().copy2d(
&mut storage,
d1,
d2,
/* src_s */ d2,
dst_s,
src_o,
dst_o,
)?;
dst_o += d2;
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
}

View File

@ -107,10 +107,6 @@ impl Var {
Ok(Self(inner)) Ok(Self(inner))
} }
pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}
pub fn as_tensor(&self) -> &Tensor { pub fn as_tensor(&self) -> &Tensor {
&self.0 &self.0
} }

View File

@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t) res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape) print(res.shape)
print(res) print(res)
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
print(res.shape)
print(res)
*/ */
fn conv1d(dev: &Device) -> Result<()> { fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new( let t = Tensor::new(
@ -53,16 +50,8 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
); );
if dev.is_cpu() {
// conv-transposes are not implemented for metal. let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
if dev.is_metal() {
return Ok(());
}
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
for w in [w.clone(), w.contiguous()?] {
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]); assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -71,17 +60,6 @@ fn conv1d(dev: &Device) -> Result<()> {
4.7076, -5.9745, -0.8276, 1.621 4.7076, -5.9745, -0.8276, 1.621
], ],
); );
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
assert_eq!(res.dims(), [1, 4, 7]);
assert_eq!(
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
[
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
]
);
} }
Ok(()) Ok(())
} }
@ -168,33 +146,31 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
] ]
); );
if !dev.is_metal() { let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(res.dims(), [1, 2, 7, 7]); assert_eq!(
assert_eq!( test_utils::to_vec3_round(&res.i(0)?, 4)?,
test_utils::to_vec3_round(&res.i(0)?, 4)?, [
[ [
[ [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] ],
], [
[ [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059], [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
] ]
); ]
} );
// Dilations. // Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?; let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]); assert_eq!(res.dims(), [1, 2, 1, 1]);
@ -203,44 +179,36 @@ fn conv2d(dev: &Device) -> Result<()> {
[2.45, -2.3504], [2.45, -2.3504],
); );
if !dev.is_metal() { // Transpose and dilations.
// Transpose and dilations. let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(res.dims(), [1, 2, 9, 9]); assert_eq!(
assert_eq!( test_utils::to_vec3_round(&res.i(0)?, 4)?,
test_utils::to_vec3_round(&res.i(0)?, 4)?, [
[ [
[ [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
[ [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
-3.5024 ],
], [
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
], [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[ [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[
-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
1.0171
]
]
] ]
); ]
} );
Ok(()) Ok(())
} }
@ -294,12 +262,6 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
] ]
); );
// conv-transposes are not implemented for metal
if dev.is_metal() {
return Ok(());
}
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!( assert_eq!(
@ -401,10 +363,6 @@ print(w.grad.shape)
print(w.grad[0]) print(w.grad[0])
*/ */
fn conv2d_grad(dev: &Device) -> Result<()> { fn conv2d_grad(dev: &Device) -> Result<()> {
// conv-transposes are not implemented for metal
if dev.is_metal() {
return Ok(());
}
use candle_core::Var; use candle_core::Var;
let t = Var::from_slice( let t = Var::from_slice(
&[ &[

View File

@ -1,4 +1,3 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var}; use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
@ -97,24 +96,24 @@ fn unary_grad(device: &Device) -> Result<()> {
let grads = y.backward()?; let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?; let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&y, 4)?, y.to_vec1::<f32>()?,
[20.0855, 2.7183, 54.5982, 1.1618] [20.085537, 2.7182817, 54.59815, 1.1618342]
); );
assert_eq!( assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?, grad_x.to_vec1::<f32>()?,
[20.0855, 2.7183, 54.5982, 1.1618] [20.085537, 2.7182817, 54.59815, 1.1618342]
); );
let y = x.exp()?.sqr()?; let y = x.exp()?.sqr()?;
let grads = y.backward()?; let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?; let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&y, 3)?, y.to_vec1::<f32>()?,
[403.429, 7.389, 2980.958, 1.35] [403.4288, 7.3890557, 2980.9578, 1.3498588]
); );
// exp(x)^2 = exp(2*x) // exp(x)^2 = exp(2*x)
assert_eq!( assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?, grad_x.to_vec1::<f32>()?,
[806.86, 14.78, 5961.92, 2.7] [806.8576, 14.778111, 5961.9155, 2.6997175]
); );
let y = x.sin()?; let y = x.sin()?;
let grads = y.backward()?; let grads = y.backward()?;
@ -262,7 +261,6 @@ fn unary_grad(device: &Device) -> Result<()> {
let y = elu_x.elu(2.)?; let y = elu_x.elu(2.)?;
let grads = y.backward()?; let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?; let grad_x = grads.get(&elu_x).context("no grad for x")?;
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&y, 4)?, test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000] [-1.2642, 0.0000, -1.7293, 3.0000]
@ -272,51 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000] [0.7358, 2.0000, 0.2707, 1.0000]
); );
// testing compared to pytorch nn.Silu()
let y = x.silu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.8577, 0.7311, 3.9281, 0.0806]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0881, 0.9277, 1.0527, 0.5747],
);
if device.is_cpu() {
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
let y = x.interpolate1d(12)?.reshape(36)?;
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
33., 34., 35., 36.,
],
device,
)?;
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(grad_x, 4)?,
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
);
}
// manually checked: see comments // manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?; let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?; let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new( let z = Tensor::new(
&[ &[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., 1_f32, 02., 03., 04., 05., 06.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 07., 08., 09., 10., 11., 12.,
35., 36., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
], ],
device, device,
)?; )?;
@ -347,11 +313,15 @@ fn unary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?; let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?; let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new( let z = Tensor::new(
&[ &[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., 1_f32, 02., 03., 04., 05., 06.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 07., 08., 09., 10., 11., 12.,
35., 36., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
], ],
device, device,
)?; )?;

View File

@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364 // https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> { fn avg_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![ let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]; ];
@ -22,9 +19,6 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
} }
fn max_pool2d(dev: &Device) -> Result<()> { fn max_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![ let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
]; ];
@ -49,9 +43,6 @@ res = torch.nn.functional.avg_pool2d(t, 2)
print(res) print(res)
*/ */
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let t = Tensor::new( let t = Tensor::new(
&[ &[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,

View File

@ -1,37 +0,0 @@
import torch
from collections import OrderedDict
# Write a trivial tensor to a pt file
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
# Write a trivial tensor to a pt file
torch.save(o, "test.pt")
############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")
############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)
# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
print("3D Tensor saved with Fortran layout.")

View File

@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_fortran_congiguous() {
let tensors =
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
assert_eq!(
tensor.to_vec3::<i64>().unwrap(),
[
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
]
);
}

View File

@ -1,7 +1,6 @@
use candle_core::{ use candle_core::{
bail, bail,
quantized::{self, GgmlDType}, quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round, test_utils::to_vec2_round,
Device, Module, Result, Tensor, Device, Module, Result, Tensor,
}; };
@ -15,48 +14,16 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040; const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02; const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
fn test_matmul( #[test]
device: &Device, fn quantized_matmul() -> Result<()> {
(b, m, n, k): (usize, usize, usize, usize), let cpu = &Device::Cpu;
dtype: GgmlDType,
) -> Result<()> {
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
let mm = lhs.matmul(&rhs)?;
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&lhs)?;
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
.sum_all()?
.to_scalar()?;
let error = error / (b * m * n) as f32;
assert!(
error <= 0.02,
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
);
Ok(())
}
fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4); let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>(); let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4]; let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>(); let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!( assert_eq!(
@ -66,7 +33,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
341876.0, 994283.0, 1655709.0, 2301518.0 341876.0, 994283.0, 1655709.0, 2301518.0
] ]
); );
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?; let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!( assert_eq!(
mm.to_vec2::<f32>()?, mm.to_vec2::<f32>()?,
@ -77,49 +43,35 @@ fn quantized_matmul(device: &Device) -> Result<()> {
] ]
); );
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?; let res = matmul.forward(&tensor_lhs)?;
match device { assert_eq!(
Device::Metal(_) => assert_eq!( to_vec2_round(&res, 0)?,
to_vec2_round(&res, 0)?, &[
&[ [85120.0, 214562.0, 345455.0, 474748.0],
[84946.0, 214126.0, 344757.0, 473798.0], [213475.0, 604465.0, 1000686.0, 1388317.0],
[213458.0, 604350.0, 1000469.0, 1387990.0], [341876.0, 994283.0, 1655709.0, 2301518.0]
[341970.0, 994574.0, 1656181.0, 2302182.0] ]
] );
),
_ => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
[213475.0, 604465.0, 1000686.0, 1388317.0],
[341876.0, 994283.0, 1655709.0, 2301518.0]
]
),
}
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(()) Ok(())
} }
fn quantized_matmul_neg(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantized_matmul_neg() -> Result<()> {
if device.is_cuda() { let cpu = &Device::Cpu;
return Ok(());
}
let (m, k, n) = (3, 64, 4); let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)) let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0) .map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4]; let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n) let rhs = (0..k * n)
.map(|v| v as f32 - (k * n) as f32 / 3.0) .map(|v| v as f32 - (k * n) as f32 / 3.0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!( assert_eq!(
@ -139,52 +91,32 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
] ]
); );
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?; let res = matmul.forward(&tensor_lhs)?;
match device { assert_eq!(
Device::Metal(_) => assert_eq!( to_vec2_round(&res, 0)?,
to_vec2_round(&res, 0)?, &[
&[ [243524.0, -19596.0, -285051.0, -549815.0],
[243666.0, -19714.0, -285433.0, -550453.0], [23777.0, 21651.0, 19398.0, 18367.0],
[23782.0, 21654.0, 19400.0, 18369.0], [-196472.0, 63012.0, 324585.0, 587902.0]
[-196102.0, 63022.0, 324233.0, 587191.0] ]
] );
),
_ => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
[23777.0, 21651.0, 19398.0, 18367.0],
[-196472.0, 63012.0, 324585.0, 587902.0]
]
),
}
Ok(()) Ok(())
} }
test_device!( #[test]
quantized_matmul, fn quantize_q4_0() -> Result<()> {
quantized_matmul_cpu, use k_quants::BlockQ4_0;
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let mut dst = vec![0f32; 32 * 4];
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut quant = vec![BlockQ4_0::zeros(); 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; BlockQ4_0::from_float(&src, &mut quant)?;
let dst = quant.dequantize(device)?; BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
dst.to_vec1::<f32>()?, dst,
&[ &[
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625, -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25, 11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
@ -200,17 +132,21 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
127.0, 127.0 127.0, 127.0
] ]
); );
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q4_1(device: &Device) -> Result<()> { #[test]
fn quantize_q4_1() -> Result<()> {
use k_quants::BlockQ4_1;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut dst = vec![0f32; 32 * 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; let mut quant = vec![BlockQ4_1::zeros(); 4];
let dst = quant.dequantize(device)?; BlockQ4_1::from_float(&src, &mut quant)?;
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst),
&[ &[
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332, 0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73, 12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
@ -226,17 +162,21 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996 118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
] ]
); );
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q5_0(device: &Device) -> Result<()> { #[test]
fn quantize_q5_0() -> Result<()> {
use k_quants::BlockQ5_0;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut dst = vec![0f32; 32 * 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; let mut quant = vec![BlockQ5_0::zeros(); 4];
let dst = quant.dequantize(device)?; BlockQ5_0::from_float(&src, &mut quant)?;
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst),
&[ &[
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625, -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313, 11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
@ -252,17 +192,21 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0 119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
] ]
); );
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q5_1(device: &Device) -> Result<()> { #[test]
fn quantize_q5_1() -> Result<()> {
use k_quants::BlockQ5_1;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut dst = vec![0f32; 32 * 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; let mut quant = vec![BlockQ5_1::zeros(); 4];
let dst = quant.dequantize(device)?; BlockQ5_1::from_float(&src, &mut quant)?;
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), dst,
&[ &[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
@ -276,11 +220,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
124.0, 125.0, 126.0, 127.0 124.0, 125.0, 126.0, 127.0
] ]
); );
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> { /// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
assert!( assert!(
size % crate::quantized::k_quants::QK_K == 0, size % crate::quantized::k_quants::QK_K == 0,
"size must be a multiple of {}", "size must be a multiple of {}",
@ -290,8 +236,10 @@ fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor>
let src = (0..size) let src = (0..size)
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.)) .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let dst = vec![0f32; size];
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]); assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
Tensor::from_vec(src, (size,), device) (src, dst)
} }
/// Round a vector /// Round a vector
@ -340,12 +288,11 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
/// Similar to the GGML quantization unit test: /// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 /// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> { fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0); let src = create_ggml_like_vector(0.0);
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; let mut dst = vec![0.0; GGML_TEST_SIZE];
let quant = quantized::QTensor::quantize(&src, dtype)?; let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let dst = quant.dequantize(device)?; let error = calculate_rmse(src.as_slice(), dst.as_slice());
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
if error > max_error { if error > max_error {
bail!( bail!(
"Quantization error {} exceeds max error {}", "Quantization error {} exceeds max error {}",
@ -356,15 +303,19 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
Ok(()) Ok(())
} }
fn quantize_q2k(device: &Device) -> Result<()> { fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
let dtype = GgmlDType::Q2K; let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
T::from_float(src, &mut quant)?;
T::to_float(&quant, dst)?;
Ok(quant)
}
let src = get_test_vector2(0.5, 1024, device)?; #[test]
let quant = quantized::QTensor::quantize(&src, dtype)?; fn quantize_q2k() -> Result<()> {
let dst = quant.dequantize(device)?; use k_quants::BlockQ2K;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.1); compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
// Test some specific values // Test some specific values
@ -378,26 +329,20 @@ fn quantize_q2k(device: &Device) -> Result<()> {
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492] [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?; ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
Ok(()) Ok(())
} }
fn quantize_q3k(device: &Device) -> Result<()> { #[test]
let dtype = GgmlDType::Q3K; fn quantize_q3k() -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; use k_quants::BlockQ3K;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.03); compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
// Test some specific values // Test some specific values
@ -411,26 +356,20 @@ fn quantize_q3k(device: &Device) -> Result<()> {
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492] [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?; ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
Ok(()) Ok(())
} }
fn quantize_q4k(device: &Device) -> Result<()> { #[test]
let dtype = GgmlDType::Q4K; fn quantize_q4k() -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; use k_quants::BlockQ4K;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.017); compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
// Test some specific values // Test some specific values
@ -444,27 +383,21 @@ fn quantize_q4k(device: &Device) -> Result<()> {
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q5k(device: &Device) -> Result<()> { #[test]
let dtype = GgmlDType::Q5K; fn quantize_q5k() -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; use k_quants::BlockQ5K;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.009); compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values // Test some specific values
assert_eq!( assert_eq!(
@ -474,29 +407,24 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let dst = round_vector(&dst); let dst = round_vector(&dst);
assert_eq!( assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q6k(device: &Device) -> Result<()> { #[test]
let dtype = GgmlDType::Q6K; fn quantize_q6k() -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; use k_quants::BlockQ6K;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008); compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values // Test some specific values
@ -510,27 +438,22 @@ fn quantize_q6k(device: &Device) -> Result<()> {
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5] [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q8k(device: &Device) -> Result<()> { #[test]
let dtype = GgmlDType::Q8K; fn quantize_q8k() -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?; use k_quants::BlockQ8K;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008); compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
// Test some specific values // Test some specific values
assert_eq!( assert_eq!(
@ -543,79 +466,15 @@ fn quantize_q8k(device: &Device) -> Result<()> {
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499] [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
test_device!(
quantize_q4_0,
quantize_q4_0_cpu,
quantize_q4_0_cuda,
quantize_q4_0_metal
);
test_device!(
quantize_q4_1,
quantize_q4_1_cpu,
quantize_q4_1_cuda,
quantize_q4_1_metal
);
test_device!(
quantize_q5_0,
quantize_q5_0_cpu,
quantize_q5_0_cuda,
quantize_q5_0_metal
);
test_device!(
quantize_q5_1,
quantize_q5_1_cpu,
quantize_q5_1_cuda,
quantize_q5_1_metal
);
test_device!(
quantize_q2k,
quantize_q2k_cpu,
quantize_q2k_cuda,
quantize_q2k_metal
);
test_device!(
quantize_q3k,
quantize_q3k_cpu,
quantize_q3k_cuda,
quantize_q3k_metal
);
test_device!(
quantize_q4k,
quantize_q4k_cpu,
quantize_q4k_cuda,
quantize_q4k_metal
);
test_device!(
quantize_q5k,
quantize_q5k_cpu,
quantize_q5k_cuda,
quantize_q5k_metal
);
test_device!(
quantize_q6k,
quantize_q6k_cpu,
quantize_q6k_cuda,
quantize_q6k_metal
);
test_device!(
quantize_q8k,
quantize_q8k_cpu,
quantize_q8k_cuda,
quantize_q8k_metal
);
/// Very simple dot product implementation /// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum() a.iter().zip(b).map(|(a, b)| a * b).sum()
@ -699,6 +558,26 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
Ok(()) Ok(())
} }
fn get_small_tensors(
m: usize,
k: usize,
n: usize,
device: &Device,
) -> Result<(Tensor, Tensor, Tensor)> {
let lhs = (0..m * k)
.map(|i| i as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..n * k)
.map(|i| i as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
let mm = lhs.matmul(&rhs.t()?)?;
Ok((lhs, rhs, mm))
}
#[test] #[test]
fn quantized_mm() -> Result<()> { fn quantized_mm() -> Result<()> {
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?; ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
@ -732,108 +611,6 @@ fn get_random_tensors(
Ok((lhs, rhs, mm)) Ok((lhs, rhs, mm))
} }
#[macro_export]
macro_rules! quantized_matmul {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> {
test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(())
}
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
};
}
quantized_matmul!(
quantized_matmul_q4_0_bis,
quantized_matmul_q4_0_cpu,
quantized_matmul_q4_0_cuda,
quantized_matmul_q4_0_metal,
GgmlDType::Q4_0
);
quantized_matmul!(
quantized_matmul_q4_1_bis,
quantized_matmul_q4_1_cpu,
quantized_matmul_q4_1_cuda,
quantized_matmul_q4_1_metal,
GgmlDType::Q4_1
);
quantized_matmul!(
quantized_matmul_q5_0_bis,
quantized_matmul_q5_0_cpu,
quantized_matmul_q5_0_cuda,
quantized_matmul_q5_0_metal,
GgmlDType::Q5_0
);
quantized_matmul!(
quantized_matmul_q5_1_bis,
quantized_matmul_q5_1_cpu,
quantized_matmul_q5_1_cuda,
quantized_matmul_q5_1_metal,
GgmlDType::Q5_1
);
quantized_matmul!(
quantized_matmul_q8_0_bis,
quantized_matmul_q8_0_cpu,
quantized_matmul_q8_0_cuda,
quantized_matmul_q8_0_metal,
GgmlDType::Q8_0
);
// Not implemented in Ggml
// quantized_matmul!(
// quantized_matmul_q8_1_bis,
// quantized_matmul_q8_1_cpu,
// quantized_matmul_q8_1_cuda,
// quantized_matmul_q8_1_metal,
// GgmlDType::Q8_1
// );
// TODO This is bugged (also bugged in GGML
quantized_matmul!(
quantized_matmul_q2k_bis,
quantized_matmul_q2k_cpu,
quantized_matmul_q2k_cuda,
quantized_matmul_q2k_metal,
GgmlDType::Q2K
);
quantized_matmul!(
quantized_matmul_q3k_bis,
quantized_matmul_q3k_cpu,
quantized_matmul_q3k_cuda,
quantized_matmul_q3k_metal,
GgmlDType::Q3K
);
quantized_matmul!(
quantized_matmul_q4k_bis,
quantized_matmul_q4k_cpu,
quantized_matmul_q4k_cuda,
quantized_matmul_q4k_metal,
GgmlDType::Q4K
);
quantized_matmul!(
quantized_matmul_q5k_bis,
quantized_matmul_q5k_cpu,
quantized_matmul_q5k_cuda,
quantized_matmul_q5k_metal,
GgmlDType::Q5K
);
quantized_matmul!(
quantized_matmul_q6k_bis,
quantized_matmul_q6k_cpu,
quantized_matmul_q6k_cuda,
quantized_matmul_q6k_metal,
GgmlDType::Q6K
);
// Not implemented on metal
// quantized_matmul!(
// quantized_matmul_q8k_bis,
// quantized_matmul_q8k_cpu,
// quantized_matmul_q8k_cuda,
// quantized_matmul_q8k_metal,
// GgmlDType::Q8K
// );
#[test] #[test]
fn quantized_matmul_q2k() -> Result<()> { fn quantized_matmul_q2k() -> Result<()> {
use k_quants::BlockQ2K; use k_quants::BlockQ2K;
@ -846,7 +623,7 @@ fn quantized_matmul_q2k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?; let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
@ -866,20 +643,30 @@ fn quantized_matmul_q3k() -> Result<()> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21); let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]); // assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?; // let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); // let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); // assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?; let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
let dst = mm.flatten_all()?.to_vec1::<f32>()?; .sum_all()?
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); .to_scalar()?;
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]); let error = error / (m * n) as f32;
// assert_eq!(qmm.dims(), [m, n]);
// let dst = qmm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
ggml_matmul_error_test::<BlockQ3K>()?; ggml_matmul_error_test::<BlockQ3K>()?;
@ -892,20 +679,30 @@ fn quantized_matmul_q4k() -> Result<()> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21); let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]); // assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?; // let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); // let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); // assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?; let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
let dst = mm.flatten_all()?.to_vec1::<f32>()?; .sum_all()?
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); .to_scalar()?;
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]); let error = error / (m * n) as f32;
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
// assert_eq!(mm.dims(), [m, n]);
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
ggml_matmul_error_test::<BlockQ4K>()?; ggml_matmul_error_test::<BlockQ4K>()?;
@ -924,7 +721,7 @@ fn quantized_matmul_q5k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?; let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
@ -951,7 +748,7 @@ fn quantized_matmul_q6k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?; let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
@ -976,7 +773,7 @@ fn quantized_matmul_q8k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?; let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;

View File

@ -120,13 +120,6 @@ fn unary_op(device: &Device) -> Result<()> {
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999] [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
] ]
); );
assert_eq!(
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
[
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
]
);
assert_eq!( assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?, test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]] [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
@ -672,31 +665,6 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0] [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
] ]
); );
// 3D
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
let t1 = t1.t()?.contiguous()?.t()?;
let t2 = t2.t()?.contiguous()?.t()?;
let t3 = t3.t()?.contiguous()?.t()?;
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
Ok(()) Ok(())
} }
@ -1105,33 +1073,8 @@ fn broadcasting(device: &Device) -> Result<()> {
fn randn(device: &Device) -> Result<()> { fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]); assert_eq!(tensor.dims(), [5, 3]);
// Check that the seed gets updated by checking that
// a new series of numbers is generated each time
let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]); assert_eq!(tensor.dims(), [5, 3]);
// Check that the seed gets updated by checking that
// a new series of numbers is generated each time
let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
// We do not expect deterministic elements at any index.
// There once was a bug that had a deterministic zero element in evenly sized tensors.
const N: usize = 2;
let v = (0..100)
.map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
.collect::<Result<Vec<_>>>()?;
assert!(
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
"There are deterministic values in the randn tensors"
);
let v = (0..100)
.map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
.collect::<Result<Vec<_>>>()?;
assert!(
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
"There are deterministic values in the rand tensors"
);
Ok(()) Ok(())
} }
@ -1302,23 +1245,11 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
} }
#[test] #[test]
fn log_sum_exp() -> Result<()> { fn logsumexp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.log_sum_exp(D::Minus1)?; let output = input.logsumexp(D::Minus1)?;
// The expectations obtained from pytorch. // The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?; assert_close(&output, &expected, 0.00001)?;
Ok(()) Ok(())
} }
#[test]
fn pow() -> Result<()> {
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}

Binary file not shown.

Binary file not shown.

View File

@ -12,7 +12,7 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { workspace = true } candle = { workspace = true }
candle-datasets = { workspace = true, optional = true } candle-datasets = { workspace = true }
candle-nn = { workspace = true } candle-nn = { workspace = true }
candle-transformers = { workspace = true } candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true } candle-flash-attn = { workspace = true, optional = true }
@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
csv = "1.3.0" csv = "1.3.0"
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
hf-hub = { workspace = true, features = ["tokio"] } hf-hub = { workspace = true, features=["tokio"]}
image = { workspace = true } image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true } num-traits = { workspace = true }
@ -30,9 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] } tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -45,6 +43,7 @@ rusttype = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-chrome = { workspace = true } tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1" tokio = "1.29.1"
@ -62,7 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"] onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"] metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
[[example]] [[example]]
name = "llama_multiprocess" name = "llama_multiprocess"
@ -79,25 +77,3 @@ required-features = ["onnx"]
[[example]] [[example]]
name = "onnx_basics" name = "onnx_basics"
required-features = ["onnx"] required-features = ["onnx"]
[[example]]
name = "whisper"
required-features = ["symphonia"]
[[example]]
name = "whisper-microphone"
required-features = ["microphone"]
[[example]]
name = "mnist-training"
required-features = ["candle-datasets"]
[[example]]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["symphonia"]

View File

@ -27,5 +27,11 @@ fn main() -> Result<()> {
bindings.write(kdir.rust_target).unwrap() bindings.write(kdir.rust_target).unwrap()
} }
} }
#[cfg(not(feature = "cuda"))]
{
for kdir in KERNEL_DIRS.iter() {
let _file = std::fs::File::create(kdir.rust_target)?;
}
}
Ok(()) Ok(())
} }

View File

@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
let config = blip::Config::image_captioning_large(); let config = blip::Config::image_captioning_large();
let device = candle_examples::device(args.cpu)?;
let (image_embeds, device, mut model) = if args.quantized { let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu; let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?; let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?; let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?; let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?; let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model)) (image_embeds, device, Model::Q(model))
} else { } else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?; let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");

View File

@ -1,237 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::chatglm::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/chatglm3-6b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-chatglm".to_string())
.get("chatglm-tokenizer.json")?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::glm3_6b();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -1,23 +0,0 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 84.09%
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
maillot : 0.74%
crash helmet : 0.54%
unicycle, monocycle : 0.44%
```

View File

@ -1,126 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convnext;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Atto,
Femto,
Pico,
Nano,
Tiny,
Small,
Base,
Large,
AttoV2,
FemtoV2,
PicoV2,
NanoV2,
TinyV2,
BaseV2,
LargeV2,
XLarge,
Huge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Atto => "convnext_atto.d2_in1k",
Self::Femto => "convnext_femto.d1_in1k",
Self::Pico => "convnext_pico.d1_in1k",
Self::Nano => "convnext_nano.d1h_in1k",
Self::Tiny => "convnext_tiny.fb_in1k",
Self::Small => "convnext_small.fb_in1k",
Self::Base => "convnext_base.fb_in1k",
Self::Large => "convnext_large.fb_in1k",
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
};
format!("timm/{name}")
}
fn config(&self) -> convnext::Config {
match self {
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
Self::Small => convnext::Config::small(),
Self::Base | Self::BaseV2 => convnext::Config::base(),
Self::Large | Self::LargeV2 => convnext::Config::large(),
Self::XLarge => convnext::Config::xlarge(),
Self::Huge => convnext::Config::huge(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::Tiny)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -1,20 +0,0 @@
# candle-efficientvit
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).
This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet : 2.25%
alp : 0.46%
```

View File

@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientvit;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
M0,
M1,
M2,
M3,
M4,
M5,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::M0 => "m0",
Self::M1 => "m1",
Self::M2 => "m2",
Self::M3 => "m3",
Self::M4 => "m4",
Self::M5 => "m5",
};
format!("timm/efficientvit_{}.r224_in1k", name)
}
fn config(&self) -> efficientvit::Config {
match self {
Self::M0 => efficientvit::Config::m0(),
Self::M1 => efficientvit::Config::m1(),
Self::M2 => efficientvit::Config::m2(),
Self::M3 => efficientvit::Config::m3(),
Self::M4 => efficientvit::Config::m4(),
Self::M5 => efficientvit::Config::m5(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::M0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-endocec
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.
## Running one example
```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
containing EnCodec tokens for the input audio file.

View File

@ -1,143 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some encodec tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::default();
let model = Model::new(&config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
codes
}
Action::AudioToCode | Action::AudioToAudio => {
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
}
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
println!("codes shape: {:?}", codes.shape());
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
Ok(())
}

View File

@ -1,27 +0,0 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
```

View File

@ -1,256 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => match model_id.as_str() {
"7b-it" => "google/gemma-7b-it".to_string(),
"7b" => "google/gemma-7b".to_string(),
"2b-it" => "google/gemma-2b-it".to_string(),
"2b" => "google/gemma-2b".to_string(),
_ => model_id.to_string(),
},
None => "google/gemma-2b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -57,7 +57,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)] #[arg(long, default_value_t = 100)]
sample_len: usize, sample_len: usize,
/// Disable the key-value cache. /// Disable the key-value cache.
@ -120,7 +120,7 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"), Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16, None => DType::F16,
}; };
let (llama, tokenizer_filename, mut cache) = { let (llama, tokenizer_filename, cache) = {
let api = Api::new()?; let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which { let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(), Which::V1 => "Narsil/amall-7b".to_string(),
@ -143,10 +143,11 @@ fn main() -> Result<()> {
} }
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
}; };
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache) (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}; };
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN); let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -156,7 +157,6 @@ fn main() -> Result<()> {
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop"); println!("starting the inference loop");
print!("{prompt}"); print!("{prompt}");
@ -172,7 +172,7 @@ fn main() -> Result<()> {
}; };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?; let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
logits logits
@ -190,16 +190,18 @@ fn main() -> Result<()> {
token_generated += 1; token_generated += 1;
tokens.push(next_token); tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id { if Some(next_token) == eos_token_id {
break; break;
} }
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
} }
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(

View File

@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use model::{Cache, Config, Llama}; use model::{Config, Llama};
use qmodel::QLlama; use qmodel::QLlama;
use weights::TransformerWeights; use weights::TransformerWeights;
@ -160,10 +160,10 @@ enum Model {
} }
impl Model { impl Model {
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> { fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self { match self {
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?), Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?), Self::QLlama(l) => Ok(l.forward(xs, pos)?),
} }
} }
} }
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?; let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?; let model = Llama::load(vb, &cache, config)?;
let tokens = match &args.pretokenized_dir { let tokens = match &args.pretokenized_dir {
None => { None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter { for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, &mut cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?); println!("{}", loss.to_vec0::<f32>()?);
} }
@ -261,8 +261,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path let is_safetensors = config_path
.extension() .extension()
.map_or(false, |v| v == "safetensors"); .map_or(false, |v| v == "safetensors");
let (model, config, mut cache) = if is_gguf { let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")? .get_no_shape("model.embed_tokens.weight")?
.shape() .shape()
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
(config.seq_len, config.head_size() / 2), (config.seq_len, config.head_size() / 2),
"rot.freq_cis_real", "rot.freq_cis_real",
)? )?
.dequantize(&device)?; .dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb let freq_cis_imag = vb
.get( .get(
(config.seq_len, config.head_size() / 2), (config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag", "rot.freq_cis_imag",
)? )?
.dequantize(&device)?; .dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors( let fake_vb = candle_nn::VarBuilder::from_tensors(
[ [
@ -295,18 +295,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.into_iter() .into_iter()
.collect(), .collect(),
candle::DType::F32, candle::DType::F32,
&device, &candle::Device::Cpu,
); );
let cache = model::Cache::new(true, &config, fake_vb)?; let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, config.clone())?); let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
} else if is_safetensors { } else if is_safetensors {
let config = Config::tiny_15m(); let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?; let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?); let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
} else { } else {
let mut file = std::fs::File::open(config_path)?; let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?); let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
}; };
println!("starting the inference loop"); println!("starting the inference loop");
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0.. { for index in 0.. {
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos, &mut cache)?; let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits logits
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?; let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? { // Extracting the last token as a string is complicated, here we just apply some simple
print!("{t}"); // heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?; std::io::stdout().flush()?;
} }
} }
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(
"\n{} tokens generated ({:.2} token/s)\n", "\n{} tokens generated ({:.2} token/s)\n",

View File

@ -8,7 +8,6 @@ fn valid_loss(
model: &Llama, model: &Llama,
args: &crate::TrainingCmd, args: &crate::TrainingCmd,
device: &Device, device: &Device,
cache: &mut Cache,
) -> Result<f64> { ) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone()); let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -16,7 +15,7 @@ fn valid_loss(
let mut cnt = 0usize; let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) { for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64; sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1; cnt += 1;
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut cache = Cache::new(false, &config, vb.pp("rot"))?; let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?; let model = Llama::load(vb, &cache, config)?;
let params = candle_nn::ParamsAdamW { let params = candle_nn::ParamsAdamW {
lr: args.learning_rate, lr: args.learning_rate,
..Default::default() ..Default::default()
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?; let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() { for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?; let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0, &mut cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?; opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 { if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the // TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss. // validation loss.
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?; let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}"); println!("{batch_index} {loss}");
} }
if batch_index > 0 && batch_index % 1000 == 0 { if batch_index > 0 && batch_index % 1000 == 0 {

View File

@ -2,9 +2,6 @@
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal). This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
Compared to the mamba example, this version can handle training but is much
slower.
## Running the example ## Running the example
```bash ```bash

View File

@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation
Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
Compared to the mamba-minimal example, this version is far more efficient but
would only work for inference.
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
```

View File

@ -1,299 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mamba::{Config, Model, State};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let input = Tensor::new(&[next_token], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,18 +0,0 @@
# candle-metavoice
MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
details on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).
Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
probably because of some implementation discrepancies.
## Run an example
```bash
cargo run --example metavoice --release -- \\
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -1,277 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};
pub const ENCODEC_NTOKENS: u32 = 1024;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
F32,
F16,
Bf16,
}
enum Transformer {
Normal(transformer::Model),
Quantized(qtransformer::Model),
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
/// The guidance scale.
#[arg(long, default_value_t = 3.0)]
guidance_scale: f64,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
temperature: f64,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The maximum number of tokens to generate for the first stage.
#[arg(long, default_value_t = 2000)]
max_tokens: u64,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long)]
first_stage_meta: Option<String>,
#[arg(long)]
first_stage_weights: Option<String>,
#[arg(long)]
second_stage_weights: Option<String>,
#[arg(long)]
encodec_weights: Option<String>,
#[arg(long)]
spk_emb: Option<String>,
#[arg(long, default_value = "f32")]
dtype: ArgDType,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let device = candle_examples::device(args.cpu)?;
let api = Api::new()?;
let repo = api.model("lmz/candle-metavoice".to_string());
let first_stage_meta = match &args.first_stage_meta {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.meta.json")?,
};
let first_stage_meta: serde_json::Value =
serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
let first_stage_tokenizer = match first_stage_meta.as_object() {
None => anyhow::bail!("not a json object"),
Some(j) => match j.get("tokenizer") {
None => anyhow::bail!("no tokenizer key"),
Some(j) => j,
},
};
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
let second_stage_weights = match &args.second_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("second_stage.safetensors")?,
};
let encodec_weights = match args.encodec_weights {
Some(w) => std::path::PathBuf::from(w),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let dtype = match args.dtype {
ArgDType::F32 => DType::F32,
ArgDType::F16 => DType::F16,
ArgDType::Bf16 => DType::BF16,
};
let first_stage_config = transformer::Config::cfg1b_v0_1();
let mut first_stage_model = if args.quantized {
let filename = match &args.first_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage_q4k.gguf")?,
};
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
Transformer::Quantized(first_stage_model)
} else {
let first_stage_weights = match &args.first_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.safetensors")?,
};
let first_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
Transformer::Normal(first_stage_model)
};
let second_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
let second_stage_config = gpt::Config::cfg1b_v0_1();
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
let encodec_device = if device.is_metal() {
&candle::Device::Cpu
} else {
&device
};
let encodec_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
let encodec_config = encodec::Config::default();
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
println!("prompt: '{}'", args.prompt);
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
let mut tokens = prompt_tokens.clone();
println!("{tokens:?}");
let spk_emb_file = match &args.spk_emb {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("spk_emb.safetensors")?,
};
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
let spk_emb = match spk_emb.get("spk_emb") {
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
};
let spk_emb = spk_emb.to_device(&device)?;
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
// First stage generation.
for index in 0..args.max_tokens {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &device)?;
let input = Tensor::stack(&[&input, &input], 0)?;
let logits = match &mut first_stage_model {
Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
Transformer::Quantized(m) => {
m.forward(&input, &spk_emb, tokens.len() - context_size)?
}
};
let logits0 = logits.i((0, 0))?;
let logits1 = logits.i((1, 0))?;
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
let logits = logits.to_dtype(DType::F32)?;
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
print!(".");
std::io::stdout().flush()?;
if next_token == 2048 {
break;
}
}
println!();
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
println!("text ids len: {}", text_ids.len());
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
// TODO: Use the config rather than hardcoding the offset here.
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
let mut hierarchies_in1 =
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
let mut hierarchies_in2 = [
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
ids2.as_slice(),
&[ENCODEC_NTOKENS],
]
.concat();
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
let logits = second_stage_model.forward(&in_x)?;
println!("sampling from logits...");
let mut codes = vec![];
for logits in logits.iter() {
let logits = logits.squeeze(0)?;
let (seq_len, _) = logits.dims2()?;
let mut codes_ = Vec::with_capacity(seq_len);
for step in 0..seq_len {
let logits = logits.i(step)?.to_dtype(DType::F32)?;
let logits = &(&logits / 1.0)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
let sample = distr.sample(&mut rng) as u32;
codes_.push(sample)
}
codes.push(codes_)
}
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
let codes = Tensor::cat(&[in_x, codes], 1)?;
println!("codes: {codes}");
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
let codes = codes.i(0)?.to_vec2::<u32>()?;
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
println!("text_ids len: {:?}", text_ids.len());
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
println!("audio_ids shape: {:?}", audio_ids.shape());
let pcm = encodec_model.decode(&audio_ids)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

View File

@ -152,7 +152,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long)] #[arg(long)]
@ -244,14 +244,13 @@ fn main() -> Result<()> {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn); let config = Config::config_7b_v0_1(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized { let (model, device) = if args.quantized {
let filename = &filenames[0]; let filename = &filenames[0];
let vb = let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?; let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), device) (Model::Quantized(model), Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() { let dtype = if device.is_cuda() {
DType::BF16 DType::BF16
} else { } else {

View File

@ -143,7 +143,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")] #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]

View File

@ -1,22 +0,0 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%
```

View File

@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S0,
S1,
S2,
S3,
S4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::S0 => "s0",
Self::S1 => "s1",
Self::S2 => "s2",
Self::S3 => "s3",
Self::S4 => "s4",
};
format!("timm/mobileone_{}.apple_in1k", name)
}
fn config(&self) -> mobileone::Config {
match self {
Self::S0 => mobileone::Config::s0(),
Self::S1 => mobileone::Config::s1(),
Self::S2 => mobileone::Config::s2(),
Self::S3 => mobileone::Config::s3(),
Self::S4 => mobileone::Config::s4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -0,0 +1,580 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
target_bandwidths: Vec<f64>,
sampling_rate: usize,
audio_channels: usize,
normalize: bool,
chunk_length_s: Option<usize>,
overlap: Option<usize>,
hidden_size: usize,
num_filters: usize,
num_residual_layers: usize,
upsampling_ratios: Vec<usize>,
norm_type: NormType,
kernel_size: usize,
last_kernel_size: usize,
residual_kernel_size: usize,
dilation_growth_rate: usize,
use_causal_conv: bool,
pad_mode: &'static str,
compress: usize,
num_lstm_layers: usize,
trim_right_ratio: f64,
codebook_size: usize,
codebook_dim: Option<usize>,
use_conv_shortcut: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
pad_mode: "reflect",
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
pub fn musicgen_small() -> Self {
Self {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
pad_mode: "reflect",
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
}
}
fn codebook_dim(&self) -> usize {
self.codebook_dim.unwrap_or(self.codebook_size)
}
fn frame_rate(&self) -> usize {
let hop_length: usize = self.upsampling_ratios.iter().product();
(self.sampling_rate + hop_length - 1) / hop_length
}
fn num_quantizers(&self) -> usize {
let num = 1000f64
* self
.target_bandwidths
.last()
.expect("empty target_bandwidths");
(num as usize) / (self.frame_rate() * 10)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Debug)]
struct EncodecEuclideanCodebook {
inited: Tensor,
cluster_size: Tensor,
embed: Tensor,
embed_avg: Tensor,
}
impl EncodecEuclideanCodebook {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inited = vb.get(1, "inited")?;
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, "embed")?;
let embed_avg = vb.get(e_shape, "embed_avg")?;
Ok(Self {
inited,
cluster_size,
embed,
embed_avg,
})
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.embed.embedding(embed_ind)?;
Ok(quantize)
}
}
#[derive(Debug)]
struct EncodecVectorQuantization {
codebook: EncodecEuclideanCodebook,
}
impl EncodecVectorQuantization {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
Ok(Self { codebook })
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.codebook.decode(embed_ind)?;
let quantize = quantize.transpose(1, 2)?;
Ok(quantize)
}
}
#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
layers: Vec<EncodecVectorQuantization>,
}
impl EncodecResidualVectorQuantizer {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("layers");
let layers = (0..cfg.num_quantizers())
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
if codes.dim(0)? != self.layers.len() {
candle::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
)
}
for (i, layer) in self.layers.iter().enumerate() {
let quantized = layer.decode(&codes.i(i)?)?;
quantized_out = quantized.broadcast_add(&quantized_out)?;
}
Ok(quantized_out)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
layers: Vec<candle_nn::LSTM>,
}
impl EncodecLSTM {
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("lstm");
let mut layers = vec![];
for layer_idx in 0..cfg.num_lstm_layers {
let config = candle_nn::LSTMConfig {
layer_idx,
..Default::default()
};
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
layers.push(lstm)
}
Ok(Self { layers })
}
}
impl Module for EncodecLSTM {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
let mut xs = xs.clone();
for layer in self.layers.iter() {
let states = layer.seq(&xs)?;
xs = layer.states_to_tensor(&states)?;
}
Ok(xs)
}
}
#[derive(Debug)]
struct EncodecConvTranspose1d {
weight_g: Tensor,
weight_v: Tensor,
bias: Tensor,
}
impl EncodecConvTranspose1d {
fn load(
in_c: usize,
out_c: usize,
k: usize,
_stride: usize,
vb: VarBuilder,
_cfg: &Config,
) -> Result<Self> {
let vb = &vb.pp("conv");
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
let bias = vb.get(out_c, "bias")?;
Ok(Self {
weight_g,
weight_v,
bias,
})
}
}
impl Module for EncodecConvTranspose1d {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
#[derive(Debug)]
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
fn load(
in_c: usize,
out_c: usize,
kernel_size: usize,
stride: usize,
vb: VarBuilder,
cfg: &Config,
) -> Result<Self> {
let conv = match cfg.norm_type {
NormType::WeightNorm => conv1d_weight_norm(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Debug)]
struct EncodecResnetBlock {
block_conv1: EncodecConv1d,
block_conv2: EncodecConv1d,
shortcut: Option<EncodecConv1d>,
}
impl EncodecResnetBlock {
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
candle::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();
let block_conv1 =
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
layer.inc();
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
Some(conv)
} else {
None
};
Ok(Self {
block_conv1,
block_conv2,
shortcut,
})
}
}
impl Module for EncodecResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs.clone();
let xs = xs.elu(1.)?;
let xs = self.block_conv1.forward(&xs)?;
let xs = xs.elu(1.)?;
let xs = self.block_conv2.forward(&xs)?;
let xs = match &self.shortcut {
None => (xs + residual)?,
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
};
Ok(xs)
}
}
struct Layer<'a> {
vb: VarBuilder<'a>,
cnt: usize,
}
impl<'a> Layer<'a> {
fn new(vb: VarBuilder<'a>) -> Self {
Self { vb, cnt: 0 }
}
fn inc(&mut self) {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb
}
}
#[derive(Debug)]
struct EncodecEncoder {
init_conv: EncodecConv1d,
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
final_lstm: EncodecLSTM,
final_conv: EncodecConv1d,
}
impl EncodecEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let init_conv = EncodecConv1d::load(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
layer.next(),
cfg,
)?;
let mut sampling_layers = vec![];
let mut scaling = 1;
for &ratio in cfg.upsampling_ratios.iter().rev() {
let current_scale = scaling * cfg.num_filters;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
layer.inc(); // ELU
let conv1d = EncodecConv1d::load(
current_scale,
current_scale * 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
sampling_layers,
final_conv,
final_lstm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?;
for (resnets, conv) in self.sampling_layers.iter() {
for resnet in resnets.iter() {
xs = xs.apply(resnet)?;
}
xs = xs.elu(1.0)?.apply(conv)?;
}
xs.apply(&self.final_lstm)?
.elu(1.0)?
.apply(&self.final_conv)
}
}
#[derive(Debug)]
struct EncodecDecoder {
init_conv: EncodecConv1d,
init_lstm: EncodecLSTM,
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
final_conv: EncodecConv1d,
}
impl EncodecDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::load(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
layer.inc(); // ELU
let conv1d = EncodecConvTranspose1d::load(
current_scale,
current_scale / 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
sampling_layers.push((conv1d, resnets));
scaling /= 2;
}
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters,
cfg.audio_channels,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
init_lstm,
sampling_layers,
final_conv,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
for (conv, resnets) in self.sampling_layers.iter() {
xs = xs.elu(1.)?.apply(conv)?;
for resnet in resnets.iter() {
xs = xs.apply(resnet)?
}
}
xs.elu(1.)?.apply(&self.final_conv)
}
}
#[derive(Debug)]
pub struct EncodecModel {
encoder: EncodecEncoder,
decoder: EncodecDecoder,
quantizer: EncodecResidualVectorQuantizer,
}
impl EncodecModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}

View File

@ -10,7 +10,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
mod encodec_model;
mod musicgen_model; mod musicgen_model;
mod nn;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

View File

@ -1,9 +1,10 @@
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D}; use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{ use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module, embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder, VarBuilder,
}; };
use candle_transformers::models::{encodec, t5}; use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)] #[derive(Debug)]
pub struct MusicgenForConditionalGeneration { pub struct MusicgenForConditionalGeneration {
pub text_encoder: t5::T5EncoderModel, pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: encodec::Model, pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM, pub decoder: MusicgenForCausalLM,
cfg: GenConfig, cfg: GenConfig,
} }
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
pub struct GenConfig { pub struct GenConfig {
musicgen: Config, musicgen: Config,
t5: t5::Config, t5: t5::Config,
encodec: encodec::Config, encodec: crate::encodec_model::Config,
} }
impl GenConfig { impl GenConfig {
pub fn small() -> Self { pub fn small() -> Self {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
let encodec = encodec::Config {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: encodec::NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
// This should be Reflect and not Replicate but Reflect does not work yet.
pad_mode: encodec::PadMode::Replicate,
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
};
Self { Self {
musicgen: Config::musicgen_small(), musicgen: Config::musicgen_small(),
t5: t5::Config::musicgen_small(), t5: t5::Config::musicgen_small(),
encodec, encodec: encodec_model::Config::musicgen_small(),
} }
} }
} }
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?; let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self { Ok(Self {
text_encoder, text_encoder,

View File

@ -0,0 +1,20 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}

View File

@ -1,39 +1,10 @@
## Using ONNX models in Candle ## Using ONNX models in Candle
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle. This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small sequeezenet variant.
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf). You can run the example with the following command:
You can run the examples with following commands:
```bash ```bash
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
Use the `--which` flag to specify explicitly which network to use, i.e.
```bash
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.21s
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 3, 224, 224; f32]
unicycle, monocycle : 83.23%
ballplayer, baseball player : 3.68%
bearskin, busby, shako : 1.54%
military uniform : 0.78%
cowboy hat, ten-gallon hat : 0.76%
```
```bash
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.20s
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 224, 224, 3; f32]
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
mountain bike, all-terrain bike, off-roader : 0.60%
unicycle, monocycle : 0.17%
crash helmet : 0.02%
alp : 0.02%
``` ```

View File

@ -8,7 +8,6 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer}; use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer; use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
@ -19,7 +18,6 @@ use tokenizers::Tokenizer;
enum Model { enum Model {
MixFormer(MixFormer), MixFormer(MixFormer),
Phi(Phi),
Quantized(QMixFormer), Quantized(QMixFormer),
} }
@ -86,7 +84,6 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model { let logits = match &mut self.model {
Model::MixFormer(m) => m.forward(&input)?, Model::MixFormer(m) => m.forward(&input)?,
Model::Phi(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?, Model::Quantized(m) => m.forward(&input)?,
}; };
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
@ -120,7 +117,7 @@ impl TextGeneration {
} }
} }
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel { enum WhichModel {
#[value(name = "1")] #[value(name = "1")]
V1, V1,
@ -128,8 +125,6 @@ enum WhichModel {
V1_5, V1_5,
#[value(name = "2")] #[value(name = "2")]
V2, V2,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2, PuffinPhiV2,
PhiHermes, PhiHermes,
} }
@ -174,7 +169,7 @@ struct Args {
#[arg(long)] #[arg(long)]
model_id: Option<String>, model_id: Option<String>,
#[arg(long, default_value = "2")] #[arg(long, default_value = "1.5")]
model: WhichModel, model: WhichModel,
#[arg(long)] #[arg(long)]
@ -235,7 +230,7 @@ fn main() -> Result<()> {
match args.model { match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(), WhichModel::V2 => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string() "lmz/candle-quantized-phi".to_string()
} }
@ -250,9 +245,8 @@ fn main() -> Result<()> {
"main".to_string() "main".to_string()
} else { } else {
match args.model { match args.model {
WhichModel::V1 => "refs/pr/8".to_string(), WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/73".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string() "main".to_string()
} }
@ -264,9 +258,7 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer { let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file), Some(file) => std::path::PathBuf::from(file),
None => match args.model { None => match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => { WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
repo.get("tokenizer.json")?
}
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")? repo.get("tokenizer-puffin-phi-v2.json")?
} }
@ -279,14 +271,14 @@ fn main() -> Result<()> {
match args.model { match args.model {
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
} }
} else { } else {
match args.model { match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors( WhichModel::V2 => candle_examples::hub_load_safetensors(
&repo, &repo,
"model.safetensors.index.json", "model.safetensors.index.json",
)?, )?,
@ -300,44 +292,28 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = || match args.model { let config = match args.model {
WhichModel::V1 => Config::v1(), WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(), WhichModel::V1_5 => Config::v1_5(),
WhichModel::V2 | WhichModel::V2Old => Config::v2(), WhichModel::V2 => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
}; };
let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized {
let model = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
let config = config();
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&filenames[0],
&device,
)?;
let model = match args.model { let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?, _ => QMixFormer::new(&config, vb)?,
}; };
Model::Quantized(model) (Model::Quantized(model), Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
match args.model { let model = match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
let config_filename = repo.get("config.json")?; _ => MixFormer::new(&config, vb)?,
let config = std::fs::read_to_string(config_filename)?; };
let config: PhiConfig = serde_json::from_str(&config)?; (Model::MixFormer(model), device)
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V2Old => {
let config = config();
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
}
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
}
}; };
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
@ -417,10 +393,6 @@ fn mmlu<P: AsRef<std::path::Path>>(
m.clear_kv_cache(); m.clear_kv_cache();
m.forward(&input)? m.forward(&input)?
} }
Model::Phi(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
Model::Quantized(m) => { Model::Quantized(m) => {
m.clear_kv_cache(); m.clear_kv_cache();
m.forward(&input)? m.forward(&input)?

View File

@ -132,8 +132,7 @@ impl T5ModelBuilder {
} }
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> { pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let device = Device::Cpu; let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
} }

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file}; use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor; use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
@ -212,14 +212,6 @@ struct Args {
#[arg(long)] #[arg(long)]
verbose_prompt: bool, verbose_prompt: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty. /// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)] #[arg(long, default_value_t = 1.1)]
repeat_penalty: f32, repeat_penalty: f32,
@ -369,7 +361,6 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?; let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?; let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) { let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => { Some("gguf") => {
@ -378,7 +369,7 @@ fn main() -> anyhow::Result<()> {
for (_, tensor) in model.tensor_infos.iter() { for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count(); let elem_count = tensor.shape.elem_count();
total_size_in_bytes += total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
} }
println!( println!(
"loaded {:?} tensors ({}) in {:.2}s", "loaded {:?} tensors ({}) in {:.2}s",
@ -386,16 +377,15 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes), &format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(), start.elapsed().as_secs_f32(),
); );
ModelWeights::from_gguf(model, &mut file, &device)? ModelWeights::from_gguf(model, &mut file)?
} }
Some("ggml" | "bin") | Some(_) | None => { Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file, &device) let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
.map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0; let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() { for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count(); let elem_count = tensor.shape().elem_count();
total_size_in_bytes += total_size_in_bytes +=
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size(); elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
} }
println!( println!(
"loaded {:?} tensors ({}) in {:.2}s", "loaded {:?} tensors ({}) in {:.2}s",
@ -495,20 +485,11 @@ fn main() -> anyhow::Result<()> {
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p); let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now(); let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt { let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?; let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)? logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
}; };
let prompt_dt = start_prompt_processing.elapsed(); let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token); all_tokens.push(next_token);
@ -526,7 +507,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now(); let start_post_prompt = std::time::Instant::now();
let mut sampled = 0; let mut sampled = 0;
for index in 0..to_sample { for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?; let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {

View File

@ -1,281 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum WhichModel {
#[value(name = "0.5b")]
W0_5b,
#[value(name = "1.8b")]
W1_8b,
#[value(name = "4b")]
W4b,
#[value(name = "7b")]
W7b,
#[value(name = "14b")]
W14b,
#[value(name = "72b")]
W72b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "0.5b")]
model: WhichModel,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
};
format!("Qwen/Qwen1.5-{size}")
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -411,7 +411,7 @@ impl DDPG<'_> {
pub fn actions(&mut self, state: &Tensor) -> Result<f32> { pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self let actions = self
.actor .actor
.forward(&state.detach().unsqueeze(0)?)? .forward(&state.detach()?.unsqueeze(0)?)?
.squeeze(0)?; .squeeze(0)?;
let actions = if self.train { let actions = if self.train {
(actions + self.ou_noise.sample()?)? (actions + self.ou_noise.sample()?)?

View File

@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
loop { loop {
let action = { let action = {
let action_probs: Vec<f32> = let action_probs: Vec<f32> =
softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)? softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
.squeeze(0)? .squeeze(0)?
.to_vec1()?; .to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64 weighted_sample(action_probs, &mut rng)? as i64
@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)? let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)? .to_dtype(DType::F32)?
.detach(); .detach()?;
let actions_mask = { let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect(); let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
.unwrap() .unwrap()
}) })
.collect(); .collect();
Tensor::stack(&actions_mask, 0)?.detach() Tensor::stack(&actions_mask, 0)?.detach()?
}; };
let states = { let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect(); let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
Tensor::stack(&states, 0)?.detach() Tensor::stack(&states, 0)?.detach()?
}; };
let log_probs = actions_mask let log_probs = actions_mask

View File

@ -236,15 +236,16 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b(); let config = Config::replit_code_v1_5_3b();
let model = if args.quantized { let (model, device) = if args.quantized {
let vb = let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?; let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
Model::Q(Q::new(&config, vb.pp("transformer"))?) (model, Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
Model::M(M::new(&config, vb.pp("transformer"))?) let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
}; };
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,22 +0,0 @@
# candle-repvgg
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
This candle implementation uses a pre-trained RepVGG network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 61.70%
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
unicycle, monocycle : 4.88%
crash helmet : 0.15%
moped : 0.04%
```

View File

@ -1,111 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::repvgg;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
A0,
A1,
A2,
B0,
B1,
B2,
B3,
B1G4,
B2G4,
B3G4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::A0 => "a0",
Self::A1 => "a1",
Self::A2 => "a2",
Self::B0 => "b0",
Self::B1 => "b1",
Self::B2 => "b2",
Self::B3 => "b3",
Self::B1G4 => "b1g4",
Self::B2G4 => "b2g4",
Self::B3G4 => "b3g4",
};
format!("timm/repvgg_{}.rvgg_in1k", name)
}
fn config(&self) -> repvgg::Config {
match self {
Self::A0 => repvgg::Config::a0(),
Self::A1 => repvgg::Config::a1(),
Self::A2 => repvgg::Config::a2(),
Self::B0 => repvgg::Config::b0(),
Self::B1 => repvgg::Config::b1(),
Self::B2 => repvgg::Config::b2(),
Self::B3 => repvgg::Config::b3(),
Self::B1G4 => repvgg::Config::b1g4(),
Self::B2G4 => repvgg::Config::b2g4(),
Self::B3G4 => repvgg::Config::b3g4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::A0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -45,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model_file = match args.model { let model_file = match args.model {

View File

@ -1,17 +0,0 @@
## candle-rwkv
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 and v6 versions and can be used with
Eagle 7B([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ(2) = 2.
The smallest composite is ϕ(3) = 3.
The smallest perfect number is ϕ(5) = 5.
The smallest perfect square is ϕ(4) = 4.
The smallest perfect cube is ϕ(6) = 6.
```

View File

@ -1,330 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
const EOS_TOKEN_ID: u32 = 261;
enum Model {
M5(M5),
Q5(Q5),
M6(M6),
Q6(Q6),
}
impl Model {
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
match self {
Self::M5(m) => m.forward(xs, state),
Self::Q5(m) => m.forward(xs, state),
Self::M6(m) => m.forward(xs, state),
Self::Q6(m) => m.forward(xs, state),
}
}
}
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
let mut tokens = self.tokenizer.encode(prompt)?;
let mut generated_tokens = 0usize;
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[[t]], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
print!("{}", self.tokenizer.decode(&[t])?)
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == EOS_TOKEN_ID || next_token == 0 {
break;
}
print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?;
let input = Tensor::new(&[[next_token]], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Eagle7b,
World1b5,
World3b,
World6_1b6,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Eagle7b => "RWKV/v5-Eagle-7B-HF",
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
Self::World3b => "RWKV/rwkv-5-world-3b",
Self::World6_1b6 => "paperfun/rwkv",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Eagle7b => "refs/pr/1",
Self::World1b5 | Self::World3b => "refs/pr/2",
Self::World6_1b6 => "main",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "world1b5")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
quantized: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-rwkv".to_string())
.get("rwkv_vocab_v20230424.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![match args.which {
Which::World1b5 => api
.model("lmz/candle-rwkv".to_string())
.get("world1b5-q4k.gguf")?,
Which::World3b => api
.model("lmz/candle-rwkv".to_string())
.get("world3b-q4k.gguf")?,
Which::Eagle7b => api
.model("lmz/candle-rwkv".to_string())
.get("eagle7b-q4k.gguf")?,
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
}]
} else {
vec![match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => {
repo.get("model.safetensors")?
}
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
}]
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::new(tokenizer)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let model = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
}
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
}
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,28 +0,0 @@
# candle-segformer
- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
## How to run the example
If you want you can use the example images from this [pull request][pr], download them and supply the path to the image as an argument to the example.
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```
Example output for classification:
```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

View File

@ -1,752 +0,0 @@
[
{
"index": 1,
"color": "#787878",
"label": "wall"
},
{
"index": 2,
"color": "#B47878",
"label": "building;edifice"
},
{
"index": 3,
"color": "#06E6E6",
"label": "sky"
},
{
"index": 4,
"color": "#503232",
"label": "floor;flooring"
},
{
"index": 5,
"color": "#04C803",
"label": "tree"
},
{
"index": 6,
"color": "#787850",
"label": "ceiling"
},
{
"index": 7,
"color": "#8C8C8C",
"label": "road;route"
},
{
"index": 8,
"color": "#CC05FF",
"label": "bed"
},
{
"index": 9,
"color": "#E6E6E6",
"label": "windowpane;window"
},
{
"index": 10,
"color": "#04FA07",
"label": "grass"
},
{
"index": 11,
"color": "#E005FF",
"label": "cabinet"
},
{
"index": 12,
"color": "#EBFF07",
"label": "sidewalk;pavement"
},
{
"index": 13,
"color": "#96053D",
"label": "person;individual;someone;somebody;mortal;soul"
},
{
"index": 14,
"color": "#787846",
"label": "earth;ground"
},
{
"index": 15,
"color": "#08FF33",
"label": "door;double;door"
},
{
"index": 16,
"color": "#FF0652",
"label": "table"
},
{
"index": 17,
"color": "#8FFF8C",
"label": "mountain;mount"
},
{
"index": 18,
"color": "#CCFF04",
"label": "plant;flora;plant;life"
},
{
"index": 19,
"color": "#FF3307",
"label": "curtain;drape;drapery;mantle;pall"
},
{
"index": 20,
"color": "#CC4603",
"label": "chair"
},
{
"index": 21,
"color": "#0066C8",
"label": "car;auto;automobile;machine;motorcar"
},
{
"index": 22,
"color": "#3DE6FA",
"label": "water"
},
{
"index": 23,
"color": "#FF0633",
"label": "painting;picture"
},
{
"index": 24,
"color": "#0B66FF",
"label": "sofa;couch;lounge"
},
{
"index": 25,
"color": "#FF0747",
"label": "shelf"
},
{
"index": 26,
"color": "#FF09E0",
"label": "house"
},
{
"index": 27,
"color": "#0907E6",
"label": "sea"
},
{
"index": 28,
"color": "#DCDCDC",
"label": "mirror"
},
{
"index": 29,
"color": "#FF095C",
"label": "rug;carpet;carpeting"
},
{
"index": 30,
"color": "#7009FF",
"label": "field"
},
{
"index": 31,
"color": "#08FFD6",
"label": "armchair"
},
{
"index": 32,
"color": "#07FFE0",
"label": "seat"
},
{
"index": 33,
"color": "#FFB806",
"label": "fence;fencing"
},
{
"index": 34,
"color": "#0AFF47",
"label": "desk"
},
{
"index": 35,
"color": "#FF290A",
"label": "rock;stone"
},
{
"index": 36,
"color": "#07FFFF",
"label": "wardrobe;closet;press"
},
{
"index": 37,
"color": "#E0FF08",
"label": "lamp"
},
{
"index": 38,
"color": "#6608FF",
"label": "bathtub;bathing;tub;bath;tub"
},
{
"index": 39,
"color": "#FF3D06",
"label": "railing;rail"
},
{
"index": 40,
"color": "#FFC207",
"label": "cushion"
},
{
"index": 41,
"color": "#FF7A08",
"label": "base;pedestal;stand"
},
{
"index": 42,
"color": "#00FF14",
"label": "box"
},
{
"index": 43,
"color": "#FF0829",
"label": "column;pillar"
},
{
"index": 44,
"color": "#FF0599",
"label": "signboard;sign"
},
{
"index": 45,
"color": "#0633FF",
"label": "chest;of;drawers;chest;bureau;dresser"
},
{
"index": 46,
"color": "#EB0CFF",
"label": "counter"
},
{
"index": 47,
"color": "#A09614",
"label": "sand"
},
{
"index": 48,
"color": "#00A3FF",
"label": "sink"
},
{
"index": 49,
"color": "#8C8C8C",
"label": "skyscraper"
},
{
"index": 50,
"color": "#FA0A0F",
"label": "fireplace;hearth;open;fireplace"
},
{
"index": 51,
"color": "#14FF00",
"label": "refrigerator;icebox"
},
{
"index": 52,
"color": "#1FFF00",
"label": "grandstand;covered;stand"
},
{
"index": 53,
"color": "#FF1F00",
"label": "path"
},
{
"index": 54,
"color": "#FFE000",
"label": "stairs;steps"
},
{
"index": 55,
"color": "#99FF00",
"label": "runway"
},
{
"index": 56,
"color": "#0000FF",
"label": "case;display;case;showcase;vitrine"
},
{
"index": 57,
"color": "#FF4700",
"label": "pool;table;billiard;table;snooker;table"
},
{
"index": 58,
"color": "#00EBFF",
"label": "pillow"
},
{
"index": 59,
"color": "#00ADFF",
"label": "screen;door;screen"
},
{
"index": 60,
"color": "#1F00FF",
"label": "stairway;staircase"
},
{
"index": 61,
"color": "#0BC8C8",
"label": "river"
},
{
"index": 62,
"color": "#FF5200",
"label": "bridge;span"
},
{
"index": 63,
"color": "#00FFF5",
"label": "bookcase"
},
{
"index": 64,
"color": "#003DFF",
"label": "blind;screen"
},
{
"index": 65,
"color": "#00FF70",
"label": "coffee;table;cocktail;table"
},
{
"index": 66,
"color": "#00FF85",
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
},
{
"index": 67,
"color": "#FF0000",
"label": "flower"
},
{
"index": 68,
"color": "#FFA300",
"label": "book"
},
{
"index": 69,
"color": "#FF6600",
"label": "hill"
},
{
"index": 70,
"color": "#C2FF00",
"label": "bench"
},
{
"index": 71,
"color": "#008FFF",
"label": "countertop"
},
{
"index": 72,
"color": "#33FF00",
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
},
{
"index": 73,
"color": "#0052FF",
"label": "palm;palm;tree"
},
{
"index": 74,
"color": "#00FF29",
"label": "kitchen;island"
},
{
"index": 75,
"color": "#00FFAD",
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
},
{
"index": 76,
"color": "#0A00FF",
"label": "swivel;chair"
},
{
"index": 77,
"color": "#ADFF00",
"label": "boat"
},
{
"index": 78,
"color": "#00FF99",
"label": "bar"
},
{
"index": 79,
"color": "#FF5C00",
"label": "arcade;machine"
},
{
"index": 80,
"color": "#FF00FF",
"label": "hovel;hut;hutch;shack;shanty"
},
{
"index": 81,
"color": "#FF00F5",
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
},
{
"index": 82,
"color": "#FF0066",
"label": "towel"
},
{
"index": 83,
"color": "#FFAD00",
"label": "light;light;source"
},
{
"index": 84,
"color": "#FF0014",
"label": "truck;motortruck"
},
{
"index": 85,
"color": "#FFB8B8",
"label": "tower"
},
{
"index": 86,
"color": "#001FFF",
"label": "chandelier;pendant;pendent"
},
{
"index": 87,
"color": "#00FF3D",
"label": "awning;sunshade;sunblind"
},
{
"index": 88,
"color": "#0047FF",
"label": "streetlight;street;lamp"
},
{
"index": 89,
"color": "#FF00CC",
"label": "booth;cubicle;stall;kiosk"
},
{
"index": 90,
"color": "#00FFC2",
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
},
{
"index": 91,
"color": "#00FF52",
"label": "airplane;aeroplane;plane"
},
{
"index": 92,
"color": "#000AFF",
"label": "dirt;track"
},
{
"index": 93,
"color": "#0070FF",
"label": "apparel;wearing;apparel;dress;clothes"
},
{
"index": 94,
"color": "#3300FF",
"label": "pole"
},
{
"index": 95,
"color": "#00C2FF",
"label": "land;ground;soil"
},
{
"index": 96,
"color": "#007AFF",
"label": "bannister;banister;balustrade;balusters;handrail"
},
{
"index": 97,
"color": "#00FFA3",
"label": "escalator;moving;staircase;moving;stairway"
},
{
"index": 98,
"color": "#FF9900",
"label": "ottoman;pouf;pouffe;puff;hassock"
},
{
"index": 99,
"color": "#00FF0A",
"label": "bottle"
},
{
"index": 100,
"color": "#FF7000",
"label": "buffet;counter;sideboard"
},
{
"index": 101,
"color": "#8FFF00",
"label": "poster;posting;placard;notice;bill;card"
},
{
"index": 102,
"color": "#5200FF",
"label": "stage"
},
{
"index": 103,
"color": "#A3FF00",
"label": "van"
},
{
"index": 104,
"color": "#FFEB00",
"label": "ship"
},
{
"index": 105,
"color": "#08B8AA",
"label": "fountain"
},
{
"index": 106,
"color": "#8500FF",
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
},
{
"index": 107,
"color": "#00FF5C",
"label": "canopy"
},
{
"index": 108,
"color": "#B800FF",
"label": "washer;automatic;washer;washing;machine"
},
{
"index": 109,
"color": "#FF001F",
"label": "plaything;toy"
},
{
"index": 110,
"color": "#00B8FF",
"label": "swimming;pool;swimming;bath;natatorium"
},
{
"index": 111,
"color": "#00D6FF",
"label": "stool"
},
{
"index": 112,
"color": "#FF0070",
"label": "barrel;cask"
},
{
"index": 113,
"color": "#5CFF00",
"label": "basket;handbasket"
},
{
"index": 114,
"color": "#00E0FF",
"label": "waterfall;falls"
},
{
"index": 115,
"color": "#70E0FF",
"label": "tent;collapsible;shelter"
},
{
"index": 116,
"color": "#46B8A0",
"label": "bag"
},
{
"index": 117,
"color": "#A300FF",
"label": "minibike;motorbike"
},
{
"index": 118,
"color": "#9900FF",
"label": "cradle"
},
{
"index": 119,
"color": "#47FF00",
"label": "oven"
},
{
"index": 120,
"color": "#FF00A3",
"label": "ball"
},
{
"index": 121,
"color": "#FFCC00",
"label": "food;solid;food"
},
{
"index": 122,
"color": "#FF008F",
"label": "step;stair"
},
{
"index": 123,
"color": "#00FFEB",
"label": "tank;storage;tank"
},
{
"index": 124,
"color": "#85FF00",
"label": "trade;name;brand;name;brand;marque"
},
{
"index": 125,
"color": "#FF00EB",
"label": "microwave;microwave;oven"
},
{
"index": 126,
"color": "#F500FF",
"label": "pot;flowerpot"
},
{
"index": 127,
"color": "#FF007A",
"label": "animal;animate;being;beast;brute;creature;fauna"
},
{
"index": 128,
"color": "#FFF500",
"label": "bicycle;bike;wheel;cycle"
},
{
"index": 129,
"color": "#0ABED4",
"label": "lake"
},
{
"index": 130,
"color": "#D6FF00",
"label": "dishwasher;dish;washer;dishwashing;machine"
},
{
"index": 131,
"color": "#00CCFF",
"label": "screen;silver;screen;projection;screen"
},
{
"index": 132,
"color": "#1400FF",
"label": "blanket;cover"
},
{
"index": 133,
"color": "#FFFF00",
"label": "sculpture"
},
{
"index": 134,
"color": "#0099FF",
"label": "hood;exhaust;hood"
},
{
"index": 135,
"color": "#0029FF",
"label": "sconce"
},
{
"index": 136,
"color": "#00FFCC",
"label": "vase"
},
{
"index": 137,
"color": "#2900FF",
"label": "traffic;light;traffic;signal;stoplight"
},
{
"index": 138,
"color": "#29FF00",
"label": "tray"
},
{
"index": 139,
"color": "#AD00FF",
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
},
{
"index": 140,
"color": "#00F5FF",
"label": "fan"
},
{
"index": 141,
"color": "#4700FF",
"label": "pier;wharf;wharfage;dock"
},
{
"index": 142,
"color": "#7A00FF",
"label": "crt;screen"
},
{
"index": 143,
"color": "#00FFB8",
"label": "plate"
},
{
"index": 144,
"color": "#005CFF",
"label": "monitor;monitoring;device"
},
{
"index": 145,
"color": "#B8FF00",
"label": "bulletin;board;notice;board"
},
{
"index": 146,
"color": "#0085FF",
"label": "shower"
},
{
"index": 147,
"color": "#FFD600",
"label": "radiator"
},
{
"index": 148,
"color": "#19C2C2",
"label": "glass;drinking;glass"
},
{
"index": 149,
"color": "#66FF00",
"label": "clock"
},
{
"index": 150,
"color": "#5C00FF",
"label": "flag"
}
]

View File

@ -1,155 +0,0 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;
#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
#[arg(long, help = "use cpu")]
cpu: bool,
#[command(subcommand)]
command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
)]
model_name: String,
#[arg(
long,
help = "path to the label file in json format",
default_value = "candle-examples/examples/segformer/assets/labels.json"
)]
label_path: PathBuf,
#[arg(long, help = "path to for the output mask image")]
output_path: PathBuf,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Args, Debug)]
struct ClassificationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "paolinox/segformer-finetuned-food101"
)]
model_name: String,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Subcommand, Debug)]
enum Commands {
Segment(SegmentationArgs),
Classify(ClassificationArgs),
}
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
println!("loading model {} via huggingface hub", model_name);
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name.clone());
let model_file = api.get("model.safetensors")?;
println!("model {} downloaded and loaded", model_name);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
let config = std::fs::read_to_string(api.get("config.json")?)?;
let config: Config = serde_json::from_str(&config)?;
println!("{:?}", config);
Ok((vb, config))
}
#[derive(Debug, serde::Deserialize)]
struct LabelItem {
index: u32,
color: String,
}
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
let label_file = std::fs::read_to_string(&args.label_path)?;
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
let label_colors: HashMap<u32, Rgb<u8>> = label_items
.iter()
.map(|x| {
(x.index - 1, {
let color = x.color.trim_start_matches('#');
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
Rgb([r, g, b])
})
})
.collect();
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = label_items.len();
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
let segmentations = model.forward(&image)?;
// generate a mask image
let mask = &segmentations.squeeze(0)?.argmax(0)?;
let (h, w) = mask.dims2()?;
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
let mask = mask
.iter()
.flat_map(|x| label_colors[x].data())
.collect::<Vec<u8>>();
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
// resize
let mask = image::DynamicImage::from(mask);
let mask = mask.resize_to_fill(
w as u32 * 4,
h as u32 * 4,
image::imageops::FilterType::CatmullRom,
);
mask.save(args.output_path.clone())?;
println!("mask image saved to {:?}", args.output_path);
Ok(())
}
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = 7;
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
let classification = model.forward(&image)?;
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
let classification = classification.squeeze(0)?;
println!(
"classification logits {:?}",
classification.to_vec1::<f32>()?
);
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
let label_id = format!("{}", label_id);
println!("label: {}", config.id2label[&label_id]);
Ok(())
}
pub fn main() -> anyhow::Result<()> {
let args = CliArgs::parse();
let device = candle_examples::device(args.cpu)?;
if let Commands::Segment(args) = args.command {
segmentation_task(args, &device)?
} else if let Commands::Classify(args) = args.command {
classification_task(args, &device)?
}
Ok(())
}

View File

@ -57,7 +57,7 @@ The downside is some long compilation time. You can set the
`/home/user/.candle` to ensures that the compilation artifacts are properly `/home/user/.candle` to ensures that the compilation artifacts are properly
cached. cached.
Enabling flash-attention requires both a feature flag, `--features flash-attn` Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`. and using the command line flag `--use-flash-attn`.
Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs

Some files were not shown because too many files have changed in this diff Show More