mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Compare commits
54 Commits
ivarflakst
...
vocos
Author | SHA1 | Date | |
---|---|---|---|
3f3730b657 | |||
058a910d0e | |||
26fe162ab5 | |||
121a71e01f | |||
2d5f2a728d | |||
68f7655895 | |||
b60064780d | |||
14010a8498 | |||
0de0795220 | |||
c1b418586c | |||
ad73e93da2 | |||
13c67226e6 | |||
d0aa197b07 | |||
274bf11633 | |||
1e26d539d9 | |||
74497e6bf7 | |||
8ab384e63d | |||
27ffd644a9 | |||
bf20cc854c | |||
42ce593ec6 | |||
67589791d2 | |||
1c8d61f051 | |||
90447bc993 | |||
40ce16001b | |||
5657e596cd | |||
0dee8ea19b | |||
9cadd4e644 | |||
020a979de2 | |||
cdc3823d8f | |||
e5eb9602d0 | |||
b75e8945bc | |||
a90fc5ca5a | |||
adfae2460a | |||
678f64dd27 | |||
b545f54a19 | |||
1ba11f22d6 | |||
982722019b | |||
a83ca2ece0 | |||
153c940a9c | |||
50be8a98ba | |||
58cc896e69 | |||
5cdd84e0f6 | |||
a510ddec4e | |||
d32abbce53 | |||
dfab45e1c8 | |||
96bc704d17 | |||
a52d407ae6 | |||
9e824ec810 | |||
beadb1b434 | |||
6d83d42efb | |||
b6afb46601 | |||
73d79e6092 | |||
b1879f17f6 | |||
4f79f5df8a |
74
.github/workflows/ci_cuda.yaml
vendored
74
.github/workflows/ci_cuda.yaml
vendored
@ -5,49 +5,15 @@ on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
start-runner:
|
||||
name: Start self-hosted EC2 runner
|
||||
runs-on: ubuntu-latest
|
||||
# Don't run on forks, they won't have access to secrets anyway.
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
EC2_AMI_ID: ami-03cfed9ea28f4b002
|
||||
EC2_INSTANCE_TYPE: g5.xlarge
|
||||
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
|
||||
EC2_SECURITY_GROUP: sg-030175c435ac141d6
|
||||
outputs:
|
||||
label: ${{ steps.start-ec2-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Start EC2 runner
|
||||
id: start-ec2-runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: start
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
ec2-image-id: ${{ env.EC2_AMI_ID }}
|
||||
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
|
||||
subnet-id: ${{ env.EC2_SUBNET_ID }}
|
||||
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
|
||||
aws-resource-tags: > # optional, requires additional permissions
|
||||
[
|
||||
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
|
||||
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
|
||||
]
|
||||
|
||||
test-cuda:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
needs: start-runner # required to start the main job when the runner is ready
|
||||
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
container:
|
||||
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
options: --gpus 0
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
@ -58,32 +24,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
- name: Install dependencies
|
||||
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
|
||||
- name: Install Rust Stable
|
||||
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||
uses: actions-rust-lang/setup-rust-toolchain@v1
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
|
||||
- name: Test (cuda)
|
||||
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
|
||||
stop-runner:
|
||||
name: Stop self-hosted EC2 runner
|
||||
needs:
|
||||
- start-runner
|
||||
- test-cuda
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AWS_REGION: us-east-1
|
||||
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
|
||||
steps:
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
- name: Stop EC2 runner
|
||||
uses: philschmid/philschmid-ec2-github-runner@main
|
||||
with:
|
||||
mode: stop
|
||||
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
|
||||
label: ${{ needs.start-runner.outputs.label }}
|
||||
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
|
||||
run: cargo test --features cuda
|
||||
|
18
Cargo.toml
18
Cargo.toml
@ -19,7 +19,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.3.3"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -31,14 +31,14 @@ license = "MIT OR Apache-2.0"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core" }
|
||||
candle-datasets = { path = "./candle-datasets" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn" }
|
||||
candle-kernels = { path = "./candle-kernels" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
||||
candle-nn = { path = "./candle-nn" }
|
||||
candle-onnx = { path = "./candle-onnx" }
|
||||
candle-transformers = { path = "./candle-transformers" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.4.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
|
21
README.md
21
README.md
@ -65,8 +65,9 @@ We also provide a some command line based examples using state of the art models
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets.
|
||||
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
||||
pre-trained on 1T tokens of English and code datasets. Also supports
|
||||
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
|
||||
- [Mamba](./candle-examples/examples/mamba/): an inference only
|
||||
implementation of the Mamba state space model.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
better performance than all publicly available 13b models as of 2023-09-28.
|
||||
@ -74,6 +75,9 @@ We also provide a some command line based examples using state of the art models
|
||||
experts 8x7b general LLM with better performance than a Llama 2 70B model with
|
||||
much faster inference.
|
||||
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
|
||||
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
|
||||
- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
|
||||
performance.
|
||||
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
|
||||
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
|
||||
(English/Chinese) general LLMs with 6b and 34b parameters.
|
||||
@ -111,9 +115,10 @@ We also provide a some command line based examples using state of the art models
|
||||
evaluation, segmentation).
|
||||
- [VGG](./candle-examples/examples/vgg/),
|
||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
|
||||
dedicated submodels for hand-writing and printed recognition.
|
||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||
model, generates the translated text from the input text.
|
||||
|
||||
@ -184,13 +189,15 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Falcon.
|
||||
- StarCoder.
|
||||
- Phi 1, 1.5, and 2.
|
||||
- Minimal Mamba
|
||||
- Mamba, Minimal Mamba
|
||||
- Mistral 7b v0.1.
|
||||
- Mixtral 8x7b v0.1.
|
||||
- StableLM-3B-4E1T.
|
||||
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
||||
- Replit-code-v1.5-3B.
|
||||
- Bert.
|
||||
- Yi-6B and Yi-34B.
|
||||
- Qwen1.5.
|
||||
- RWKV.
|
||||
- Quantized LLMs.
|
||||
- Llama 7b, 13b, 70b, as well as the chat and code variants.
|
||||
- Mistral 7b, and 7b instruct.
|
||||
@ -206,8 +213,10 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Wurstchen v2.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- TrOCR.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
||||
ConvNeXTv2.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
|
@ -1,11 +1,9 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
//benchmarks::affine::benches,
|
||||
//benchmarks::matmul::benches,
|
||||
//benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
//benchmarks::where_cond::benches
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::where_cond::benches
|
||||
);
|
||||
|
@ -1,7 +1,6 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod reduce;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
@ -1,239 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Storage, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use half::{bf16, f16};
|
||||
use std::ops::Deref;
|
||||
use std::time::Instant;
|
||||
|
||||
fn run_sum(a: &Tensor) {
|
||||
a.sum(2).unwrap();
|
||||
}
|
||||
fn run_arg_min(a: &Tensor) {
|
||||
a.argmin(2).unwrap();
|
||||
}
|
||||
|
||||
// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
|
||||
fn softmax(a: &Tensor) -> candle_core::Result<()> {
|
||||
use candle_core::{backend::BackendStorage, DType};
|
||||
let (storage, layout) = a.storage_and_layout();
|
||||
|
||||
let device = a.device();
|
||||
|
||||
if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
let name = match a.dtype() {
|
||||
DType::F32 => "softmax_f32",
|
||||
DType::F16 => "softmax_f16",
|
||||
DType::BF16 => "softmax_bf16",
|
||||
dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
||||
};
|
||||
|
||||
let n = layout.stride().len();
|
||||
if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
|
||||
candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
|
||||
}
|
||||
|
||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
|
||||
candle_metal_kernels::call_last_softmax(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
let (lo, up) = (-1000.0f32, 1000.0f32);
|
||||
for device in handler.devices {
|
||||
run_softmax(c, &device, (lo, up));
|
||||
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
|
||||
run_reduce(c, &device, (lo, up), false);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), false);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_reduce(c, &device, (lo, up), true);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), true);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
|
||||
if !device.is_metal() {
|
||||
return;
|
||||
}
|
||||
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => "softmax_f32",
|
||||
DType::F16 => "softmax_f16",
|
||||
DType::BF16 => "softmax_bf16",
|
||||
_ => "softmax",
|
||||
};
|
||||
softmax(&a).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
softmax(black_box(&a)).unwrap();
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"reduce_f32_strided"
|
||||
} else {
|
||||
"reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"reduce_f16_strided"
|
||||
} else {
|
||||
"reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"reduce_bf16_strided"
|
||||
} else {
|
||||
"reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "reduce",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_sum(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"arg_reduce_f32_strided"
|
||||
} else {
|
||||
"arg_reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"arg_reduce_f16_strided"
|
||||
} else {
|
||||
"arg_reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"arg_reduce_bf16_strided"
|
||||
} else {
|
||||
"arg_reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_arg_min(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -196,7 +196,7 @@ fn run_ls(
|
||||
}
|
||||
}
|
||||
Format::Pth => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
|
||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
for tensor_info in tensors.iter() {
|
||||
println!(
|
||||
|
@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -175,7 +175,7 @@ impl Tensor {
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach()? };
|
||||
let grad = if do_not_detach { grad } else { grad.detach() };
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
@ -589,6 +589,13 @@ impl Tensor {
|
||||
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
|
||||
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Silu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
||||
let sigmoid_arg = (*node / arg)?;
|
||||
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||
}
|
||||
Op::Elu(arg, alpha) => {
|
||||
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
@ -1149,6 +1149,55 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
inp_l: &Layout,
|
||||
k: &CudaSlice<T>,
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_in_k, c_out, l_k)
|
||||
// Input shape: (b_size, c_in, l_in)
|
||||
let p = &self.0;
|
||||
let l_out = p.l_out();
|
||||
let dst_el = p.c_out * l_out * p.b_size;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
|
||||
let ds = if dims.len() == 3 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el,
|
||||
l_out,
|
||||
p.stride,
|
||||
p.padding,
|
||||
p.output_padding,
|
||||
p.dilation,
|
||||
&ds,
|
||||
inp,
|
||||
k,
|
||||
&out,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -1810,12 +1859,15 @@ impl BackendStorage for CudaStorage {
|
||||
|
||||
fn conv_transpose1d(
|
||||
&self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &crate::conv::ParamsConvTranspose1D,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
let device = self.device().clone();
|
||||
let slice =
|
||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
|
@ -489,7 +489,6 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let device = self.device.clone();
|
||||
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
@ -503,69 +502,13 @@ impl BackendStorage for MetalStorage {
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
if layout.is_contiguous() {
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
|
||||
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
|
||||
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
|
||||
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
|
||||
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
|
||||
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
|
||||
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
|
||||
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
|
||||
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
|
||||
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
|
||||
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
|
||||
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
|
||||
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
|
||||
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
|
||||
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
|
||||
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
|
||||
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
|
||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
|
||||
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
|
||||
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
|
||||
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
|
||||
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
|
||||
(k, dtype) => {
|
||||
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.shape().elem_count(),
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
return Ok(Self::new(buffer, device, self.dtype));
|
||||
}
|
||||
|
||||
for &dim_idx in sum_dims.iter() {
|
||||
dims.push(src_dims[dim_idx]);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||
@ -597,7 +540,7 @@ impl BackendStorage for MetalStorage {
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
|
||||
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
|
||||
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
@ -736,6 +679,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
@ -753,6 +697,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
@ -787,6 +732,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("usilu", DType::F32) => strided::silu::FLOAT,
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
@ -802,6 +748,7 @@ impl BackendStorage for MetalStorage {
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
|
@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_exp_inplace(y: &mut [f32]) {
|
||||
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_exp_inplace(y: &mut [f64]) {
|
||||
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vs_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = -v
|
||||
}
|
||||
vd_exp_inplace(ys);
|
||||
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
|
||||
*y = v / (1.0 + *y)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_op {
|
||||
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
|
||||
#[inline]
|
||||
|
@ -61,6 +61,7 @@ pub enum UnaryOp {
|
||||
GeluErf,
|
||||
Erf,
|
||||
Relu,
|
||||
Silu,
|
||||
Tanh,
|
||||
Floor,
|
||||
Ceil,
|
||||
@ -390,6 +391,7 @@ pub(crate) struct Gelu;
|
||||
pub(crate) struct GeluErf;
|
||||
pub(crate) struct Erf;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Silu;
|
||||
pub(crate) struct Tanh;
|
||||
pub(crate) struct Floor;
|
||||
pub(crate) struct Ceil;
|
||||
@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
|
||||
}
|
||||
}
|
||||
|
||||
/// Silu operation
|
||||
impl UnaryOpT for Silu {
|
||||
const NAME: &'static str = "silu";
|
||||
const V: Self = Silu;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v / (bf16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v / (f16::ONE + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v / (1.0 + (-v).exp())
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn i64(_: i64) -> i64 {
|
||||
0
|
||||
}
|
||||
const KERNEL: &'static str = "usilu";
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::mkl::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::mkl::vd_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F32_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
|
||||
crate::accelerate::vs_silu(xs, ys)
|
||||
}
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
const F64_VEC: bool = true;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
#[inline(always)]
|
||||
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
|
||||
crate::accelerate::vd_silu(xs, ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl UnaryOpT for Abs {
|
||||
const NAME: &'static str = "abs";
|
||||
const KERNEL: &'static str = "uabs";
|
||||
|
@ -217,6 +217,13 @@ impl Object {
|
||||
let args = args.remove(1);
|
||||
(callable, args)
|
||||
}
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
|
||||
let mut args = args.tuple()?;
|
||||
args.remove(0).reduce()?
|
||||
}
|
||||
_ => (callable, args),
|
||||
};
|
||||
match callable {
|
||||
@ -227,13 +234,11 @@ impl Object {
|
||||
_ => return Ok(None),
|
||||
};
|
||||
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||
let mut path = dir_name.to_path_buf();
|
||||
path.push(file_path);
|
||||
Ok(Some(TensorInfo {
|
||||
name,
|
||||
dtype,
|
||||
layout,
|
||||
path: path.to_string_lossy().into_owned(),
|
||||
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
|
||||
storage_size,
|
||||
}))
|
||||
}
|
||||
@ -345,8 +350,10 @@ impl Stack {
|
||||
module_name,
|
||||
class_name,
|
||||
} => {
|
||||
if module_name == "collections" && class_name == "OrderedDict" {
|
||||
// TODO: have a separate ordered dict.
|
||||
if module_name == "collections"
|
||||
&& (class_name == "OrderedDict" || class_name == "defaultdict")
|
||||
{
|
||||
// TODO: have a separate ordered dict and a separate default dict.
|
||||
Some(Object::Dict(vec![]))
|
||||
} else {
|
||||
None
|
||||
@ -627,9 +634,16 @@ pub struct TensorInfo {
|
||||
pub storage_size: usize,
|
||||
}
|
||||
|
||||
/// Read the tensor info from a .pth file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `file` - The path to the .pth file.
|
||||
/// * `verbose` - Whether to print debug information.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
|
||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
file: P,
|
||||
verbose: bool,
|
||||
key: Option<&str>,
|
||||
) -> Result<Vec<TensorInfo>> {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let zip_reader = std::io::BufReader::new(file);
|
||||
@ -651,8 +665,9 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
stack.read_loop(&mut reader)?;
|
||||
let obj = stack.finalize()?;
|
||||
if VERBOSE || verbose {
|
||||
println!("{obj:?}");
|
||||
println!("{obj:#?}");
|
||||
}
|
||||
|
||||
let obj = match obj {
|
||||
Object::Build { callable, args } => match *callable {
|
||||
Object::Reduce { callable, args: _ } => match *callable {
|
||||
@ -666,6 +681,24 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
||||
},
|
||||
obj => obj,
|
||||
};
|
||||
|
||||
// If key is provided, then we need to extract the state_dict from the object.
|
||||
let obj = if let Some(key) = key {
|
||||
if let Object::Dict(key_values) = obj {
|
||||
key_values
|
||||
.into_iter()
|
||||
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
|
||||
.map(|(_, v)| v)
|
||||
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
|
||||
} else {
|
||||
obj
|
||||
}
|
||||
} else {
|
||||
obj
|
||||
};
|
||||
|
||||
// If the object is a dict, then we can extract the tensor info from it.
|
||||
// NOTE: We are assuming that the `obj` is state_dict by this stage.
|
||||
if let Object::Dict(key_values) = obj {
|
||||
for (name, value) in key_values.into_iter() {
|
||||
match value.into_tensor_info(name, &dir_name) {
|
||||
@ -688,8 +721,8 @@ pub struct PthTensors {
|
||||
}
|
||||
|
||||
impl PthTensors {
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
|
||||
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
|
||||
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
|
||||
let tensor_infos = tensor_infos
|
||||
.into_iter()
|
||||
.map(|ti| (ti.name.to_string(), ti))
|
||||
@ -712,10 +745,12 @@ impl PthTensors {
|
||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
|
||||
let rank = tensor_info.layout.shape().rank();
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case.
|
||||
if !tensor_info.layout.is_contiguous() {
|
||||
// case and when the tensor is fortran contiguous.
|
||||
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
@ -733,13 +768,33 @@ impl PthTensors {
|
||||
tensor_info.dtype,
|
||||
&mut reader,
|
||||
)?;
|
||||
Ok(Some(tensor))
|
||||
|
||||
if rank > 1 && is_fortran_contiguous {
|
||||
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
||||
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
||||
let tensor = tensor.reshape(shape_reversed)?;
|
||||
|
||||
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
||||
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
||||
let tensor = tensor.permute(dim_indeces_reversed)?;
|
||||
Ok(Some(tensor))
|
||||
} else {
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path)?;
|
||||
/// Read all the tensors from a PyTorch pth file with a given key.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||
path: P,
|
||||
key: Option<&str>,
|
||||
) -> Result<Vec<(String, Tensor)>> {
|
||||
let pth = PthTensors::new(path, key)?;
|
||||
let tensor_names = pth.tensor_infos.keys();
|
||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||
for name in tensor_names {
|
||||
@ -749,3 +804,11 @@ pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tenso
|
||||
}
|
||||
Ok(tensors)
|
||||
}
|
||||
|
||||
/// Read all the tensors from a PyTorch pth file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the pth file.
|
||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||
read_all_with_key(path, None)
|
||||
}
|
||||
|
43
candle-core/src/quantized/dummy_metal.rs
Normal file
43
candle-core/src/quantized/dummy_metal.rs
Normal file
@ -0,0 +1,43 @@
|
||||
#![allow(unused)]
|
||||
use super::GgmlDType;
|
||||
use crate::{Error, MetalDevice, MetalStorage, Result};
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
_self_shape: &crate::Shape,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, crate::Shape)> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
@ -233,6 +233,7 @@ pub struct Content {
|
||||
pub hparams: HParams,
|
||||
pub vocab: Vocab,
|
||||
pub tensors: HashMap<String, super::QTensor>,
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
@ -252,11 +253,13 @@ impl Content {
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
tensors.insert(name, tensor);
|
||||
}
|
||||
let device = device.clone();
|
||||
Ok(Self {
|
||||
magic,
|
||||
hparams,
|
||||
vocab,
|
||||
tensors,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result};
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
|
||||
use metal::Buffer;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -10,20 +11,26 @@ pub struct QMetalStorage {
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = device.allocate_zeros(size)?;
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
pub fn device(&self) -> &MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
||||
Self {
|
||||
device,
|
||||
buffer,
|
||||
dtype,
|
||||
}
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
@ -130,6 +137,59 @@ impl QMetalStorage {
|
||||
self.buffer = buffer;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.buffer.length() as usize
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
self_shape: &Shape,
|
||||
storage: &MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
use crate::MetalError;
|
||||
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let (n, k) = self_shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (dst_shape[0], dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||
@ -151,3 +211,24 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||
fn from(value: GgmlDType) -> Self {
|
||||
match value {
|
||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,16 +1,19 @@
|
||||
#[cfg(feature = "metal")]
|
||||
use crate::{backend::BackendStorage, DType};
|
||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
mod dummy_metal;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal;
|
||||
#[cfg(not(feature = "metal"))]
|
||||
mod metal {
|
||||
pub use super::dummy_metal::*;
|
||||
}
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
@ -32,19 +35,9 @@ impl Device {
|
||||
let storage = dtype.cpu_zeros(elem_count);
|
||||
Ok(QStorage::Cpu(storage))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = metal.allocate_zeros(size)?;
|
||||
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
||||
buffer,
|
||||
metal.clone(),
|
||||
dtype,
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal feature not activated");
|
||||
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||
Ok(QStorage::Metal(storage))
|
||||
}
|
||||
Device::Cuda(_cuda) => {
|
||||
crate::bail!("Cuda ggml quantization not supported");
|
||||
@ -55,7 +48,6 @@ impl Device {
|
||||
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
#[cfg(feature = "metal")]
|
||||
Metal(metal::QMetalStorage),
|
||||
}
|
||||
|
||||
@ -63,7 +55,6 @@ impl QStorage {
|
||||
fn block_size(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
@ -71,16 +62,21 @@ impl QStorage {
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn device(&self) -> Device {
|
||||
match self {
|
||||
QStorage::Cpu(_storage) => Device::Cpu,
|
||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
||||
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,7 +85,6 @@ impl QStorage {
|
||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
@ -99,7 +94,6 @@ impl QStorage {
|
||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
@ -112,7 +106,6 @@ impl QStorage {
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(_storage) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
@ -336,6 +329,10 @@ impl QTensor {
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.shape.rank()
|
||||
}
|
||||
@ -427,8 +424,7 @@ impl crate::CustomOp1 for QTensor {
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cpu(storage) => storage,
|
||||
#[cfg(feature = "metal")]
|
||||
_ => crate::bail!("Invalid storage"),
|
||||
QStorage::Metal(_) => crate::bail!("Invalid storage"),
|
||||
};
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
@ -437,79 +433,16 @@ impl crate::CustomOp1 for QTensor {
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
use crate::MetalError;
|
||||
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let (n, k) = self.shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (dst_shape[0], dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let (buffer, dtype) = match &self.storage {
|
||||
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Metal(metal) => metal,
|
||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||
};
|
||||
let command_buffer = device.command_buffer()?;
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||
fn from(value: GgmlDType) -> Self {
|
||||
match value {
|
||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||
}
|
||||
self_storage.fwd(&self.shape, storage, layout)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -508,6 +508,7 @@ impl Tensor {
|
||||
unary_op!(gelu_erf, GeluErf);
|
||||
unary_op!(erf, Erf);
|
||||
unary_op!(relu, Relu);
|
||||
unary_op!(silu, Silu);
|
||||
unary_op!(ceil, Ceil);
|
||||
unary_op!(floor, Floor);
|
||||
unary_op!(round, Round);
|
||||
@ -804,6 +805,35 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Roll the tensor input along the given dimension.
|
||||
/// Elements that are shifted beyond the last position are re-introduced at the first position.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.roll(1, 0)?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.roll(-1, 0)?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
|
||||
where
|
||||
D: Dim + Clone,
|
||||
{
|
||||
let dim = dim.to_index(self.shape(), "roll")?;
|
||||
let dim_size = self.dim(dim)?;
|
||||
let shift = shift.rem_euclid(dim_size as i32) as usize;
|
||||
if shift == 0 {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let a = self.narrow(dim, 0, dim_size - shift)?;
|
||||
let b = self.narrow(dim, dim_size - shift, shift)?;
|
||||
Tensor::cat(&[&b, &a], dim)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||
/// input dimensions.
|
||||
///
|
||||
@ -1853,9 +1883,9 @@ impl Tensor {
|
||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||
///
|
||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
pub fn detach(&self) -> Tensor {
|
||||
if self.op.is_none() && !self.is_variable {
|
||||
Ok(self.clone())
|
||||
self.clone()
|
||||
} else {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -1866,7 +1896,7 @@ impl Tensor {
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,6 +107,10 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn as_detached_tensor(&self) -> Tensor {
|
||||
self.0.detach()
|
||||
}
|
||||
|
||||
pub fn as_tensor(&self) -> &Tensor {
|
||||
&self.0
|
||||
}
|
||||
|
@ -50,17 +50,15 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
if dev.is_cpu() {
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
}
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
BIN
candle-core/tests/fortran_tensor_3d.pth
Normal file
BIN
candle-core/tests/fortran_tensor_3d.pth
Normal file
Binary file not shown.
@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.Silu()
|
||||
let y = x.silu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
37
candle-core/tests/pth.py
Normal file
37
candle-core/tests/pth.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
|
||||
# Write a trivial tensor to a pt file
|
||||
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
|
||||
o = OrderedDict()
|
||||
o["test"] = a
|
||||
|
||||
# Write a trivial tensor to a pt file
|
||||
torch.save(o, "test.pt")
|
||||
|
||||
############################################################################################################
|
||||
# Write a trivial tensor to a pt file with a key
|
||||
torch.save({"model_state_dict": o}, "test_with_key.pt")
|
||||
|
||||
############################################################################################################
|
||||
# Create a tensor with fortran contiguous memory layout
|
||||
import numpy as np
|
||||
|
||||
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
|
||||
# For example, creating a 2x3x4 array
|
||||
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
|
||||
|
||||
# Verify the memory order
|
||||
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
|
||||
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
|
||||
|
||||
# Step 2: Convert the NumPy array to a PyTorch tensor
|
||||
tensor_fortran = torch.from_numpy(array_fortran)
|
||||
|
||||
# Verify the tensor layout
|
||||
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
|
||||
|
||||
# Step 3: Save the PyTorch tensor to a .pth file
|
||||
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
|
||||
|
||||
print("3D Tensor saved with Fortran layout.")
|
31
candle-core/tests/pth_tests.rs
Normal file
31
candle-core/tests/pth_tests.rs
Normal file
@ -0,0 +1,31 @@
|
||||
/// Regression test for pth files not loading on Windows.
|
||||
#[test]
|
||||
fn test_pth() {
|
||||
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
|
||||
tensors.get("test").unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pth_with_key() {
|
||||
let tensors =
|
||||
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
|
||||
.unwrap();
|
||||
tensors.get("test").unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pth_fortran_congiguous() {
|
||||
let tensors =
|
||||
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
|
||||
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
|
||||
|
||||
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
|
||||
|
||||
assert_eq!(
|
||||
tensor.to_vec3::<i64>().unwrap(),
|
||||
[
|
||||
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
|
||||
]
|
||||
);
|
||||
}
|
@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
|
||||
[
|
||||
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
|
||||
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
|
||||
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
|
||||
|
BIN
candle-core/tests/test.pt
Normal file
BIN
candle-core/tests/test.pt
Normal file
Binary file not shown.
BIN
candle-core/tests/test_with_key.pt
Normal file
BIN
candle-core/tests/test_with_key.pt
Normal file
Binary file not shown.
@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
|
||||
csv = "1.3.0"
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
hf-hub = { workspace = true, features = ["tokio"] }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
@ -30,7 +30,9 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
symphonia = { version = "0.5.3", features = ["all"] }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
cpal= { version = "0.15.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -43,7 +45,6 @@ rusttype = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
|
||||
@ -61,6 +62,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
@ -77,3 +79,7 @@ required-features = ["onnx"]
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "whisper-microphone"
|
||||
required-features = ["microphone"]
|
||||
|
237
candle-examples/examples/chatglm/main.rs
Normal file
237
candle-examples/examples/chatglm/main.rs
Normal file
@ -0,0 +1,237 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::chatglm::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
if tokens.is_empty() {
|
||||
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the endoftext token"),
|
||||
};
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "THUDM/chatglm3-6b".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("lmz/candle-chatglm".to_string())
|
||||
.get("chatglm-tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::glm3_6b();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
23
candle-examples/examples/convnext/README.md
Normal file
23
candle-examples/examples/convnext/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# candle-convnext
|
||||
|
||||
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
|
||||
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
|
||||
|
||||
This candle implementation uses a pre-trained ConvNeXt network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 84.09%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
|
||||
maillot : 0.74%
|
||||
crash helmet : 0.54%
|
||||
unicycle, monocycle : 0.44%
|
||||
|
||||
```
|
126
candle-examples/examples/convnext/main.rs
Normal file
126
candle-examples/examples/convnext/main.rs
Normal file
@ -0,0 +1,126 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::convnext;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
Atto,
|
||||
Femto,
|
||||
Pico,
|
||||
Nano,
|
||||
Tiny,
|
||||
Small,
|
||||
Base,
|
||||
Large,
|
||||
AttoV2,
|
||||
FemtoV2,
|
||||
PicoV2,
|
||||
NanoV2,
|
||||
TinyV2,
|
||||
BaseV2,
|
||||
LargeV2,
|
||||
XLarge,
|
||||
Huge,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::Atto => "convnext_atto.d2_in1k",
|
||||
Self::Femto => "convnext_femto.d1_in1k",
|
||||
Self::Pico => "convnext_pico.d1_in1k",
|
||||
Self::Nano => "convnext_nano.d1h_in1k",
|
||||
Self::Tiny => "convnext_tiny.fb_in1k",
|
||||
Self::Small => "convnext_small.fb_in1k",
|
||||
Self::Base => "convnext_base.fb_in1k",
|
||||
Self::Large => "convnext_large.fb_in1k",
|
||||
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
|
||||
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
|
||||
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
|
||||
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
|
||||
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
|
||||
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
|
||||
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
|
||||
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
|
||||
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
|
||||
};
|
||||
|
||||
format!("timm/{name}")
|
||||
}
|
||||
|
||||
fn config(&self) -> convnext::Config {
|
||||
match self {
|
||||
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
|
||||
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
|
||||
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
|
||||
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
|
||||
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
|
||||
Self::Small => convnext::Config::small(),
|
||||
Self::Base | Self::BaseV2 => convnext::Config::base(),
|
||||
Self::Large | Self::LargeV2 => convnext::Config::large(),
|
||||
Self::XLarge => convnext::Config::xlarge(),
|
||||
Self::Huge => convnext::Config::huge(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::Tiny)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -2,6 +2,9 @@
|
||||
|
||||
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
|
||||
|
||||
Compared to the mamba example, this version can handle training but is much
|
||||
slower.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
|
17
candle-examples/examples/mamba/README.md
Normal file
17
candle-examples/examples/mamba/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
# candle-mamba: Mamba implementation
|
||||
|
||||
Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
|
||||
the transformer architecture. It leverages State Space Models (SSMs) with the
|
||||
goal of being computationally efficient on long sequences. The implementation is
|
||||
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
|
||||
|
||||
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
|
||||
|
||||
Compared to the mamba-minimal example, this version is far more efficient but
|
||||
would only work for inference.
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
|
||||
```
|
||||
|
299
candle-examples/examples/mamba/main.rs
Normal file
299
candle-examples/examples/mamba/main.rs
Normal file
@ -0,0 +1,299 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mamba::{Config, Model, State};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
config: Config,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
config: Config,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
config,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the </s> token"),
|
||||
};
|
||||
let mut state = State::new(1, &self.config, &self.device)?;
|
||||
let mut next_logits = None;
|
||||
for &t in tokens.iter() {
|
||||
let input = Tensor::new(&[t], &self.device)?;
|
||||
let logits = self.model.forward(&input, &mut state)?;
|
||||
next_logits = Some(logits);
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for _ in 0..sample_len {
|
||||
let logits = match next_logits.as_ref() {
|
||||
Some(logits) => logits,
|
||||
None => anyhow::bail!("cannot work on an empty prompt"),
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let input = Tensor::new(&[next_token], &self.device)?;
|
||||
next_logits = Some(self.model.forward(&input, &mut state)?)
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
||||
enum Which {
|
||||
Mamba130m,
|
||||
Mamba370m,
|
||||
Mamba790m,
|
||||
Mamba1_4b,
|
||||
Mamba2_8b,
|
||||
Mamba2_8bSlimPj,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
|
||||
}
|
||||
}
|
||||
|
||||
fn revision(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mamba130m
|
||||
| Self::Mamba370m
|
||||
| Self::Mamba790m
|
||||
| Self::Mamba1_4b
|
||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
||||
Self::Mamba2_8b => "refs/pr/4",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "mamba130m")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id
|
||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
||||
RepoType::Model,
|
||||
args.revision
|
||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("EleutherAI/gpt-neox-20b".to_string())
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
config,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,10 +1,39 @@
|
||||
## Using ONNX models in Candle
|
||||
|
||||
This example demonstrates how to run ONNX based models in Candle, the model
|
||||
being used here is a small sequeezenet variant.
|
||||
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle.
|
||||
|
||||
You can run the example with the following command:
|
||||
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf).
|
||||
|
||||
You can run the examples with following commands:
|
||||
|
||||
```bash
|
||||
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
```
|
||||
|
||||
Use the `--which` flag to specify explicitly which network to use, i.e.
|
||||
|
||||
```bash
|
||||
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
Finished release [optimized] target(s) in 0.21s
|
||||
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
unicycle, monocycle : 83.23%
|
||||
ballplayer, baseball player : 3.68%
|
||||
bearskin, busby, shako : 1.54%
|
||||
military uniform : 0.78%
|
||||
cowboy hat, ten-gallon hat : 0.76%
|
||||
```
|
||||
|
||||
```bash
|
||||
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
Finished release [optimized] target(s) in 0.20s
|
||||
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
|
||||
loaded image Tensor[dims 224, 224, 3; f32]
|
||||
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
|
||||
mountain bike, all-terrain bike, off-roader : 0.60%
|
||||
unicycle, monocycle : 0.17%
|
||||
crash helmet : 0.02%
|
||||
alp : 0.02%
|
||||
```
|
||||
|
281
candle-examples/examples/qwen/main.rs
Normal file
281
candle-examples/examples/qwen/main.rs
Normal file
@ -0,0 +1,281 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::qwen2::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
|
||||
enum WhichModel {
|
||||
#[value(name = "0.5b")]
|
||||
W0_5b,
|
||||
#[value(name = "1.8b")]
|
||||
W1_8b,
|
||||
#[value(name = "4b")]
|
||||
W4b,
|
||||
#[value(name = "7b")]
|
||||
W7b,
|
||||
#[value(name = "14b")]
|
||||
W14b,
|
||||
#[value(name = "72b")]
|
||||
W72b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
#[arg(long, default_value = "0.5b")]
|
||||
model: WhichModel,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => {
|
||||
let size = match args.model {
|
||||
WhichModel::W0_5b => "0.5B",
|
||||
WhichModel::W1_8b => "1.8B",
|
||||
WhichModel::W4b => "4B",
|
||||
WhichModel::W7b => "7B",
|
||||
WhichModel::W14b => "14B",
|
||||
WhichModel::W72b => "72B",
|
||||
};
|
||||
format!("Qwen/Qwen1.5-{size}")
|
||||
}
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.model {
|
||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config_file = repo.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -411,7 +411,7 @@ impl DDPG<'_> {
|
||||
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
|
||||
let actions = self
|
||||
.actor
|
||||
.forward(&state.detach()?.unsqueeze(0)?)?
|
||||
.forward(&state.detach().unsqueeze(0)?)?
|
||||
.squeeze(0)?;
|
||||
let actions = if self.train {
|
||||
(actions + self.ou_noise.sample()?)?
|
||||
|
@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
|
||||
loop {
|
||||
let action = {
|
||||
let action_probs: Vec<f32> =
|
||||
softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
|
||||
softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
|
||||
.squeeze(0)?
|
||||
.to_vec1()?;
|
||||
weighted_sample(action_probs, &mut rng)? as i64
|
||||
@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
|
||||
|
||||
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
.detach()?;
|
||||
.detach();
|
||||
|
||||
let actions_mask = {
|
||||
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
|
||||
@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
Tensor::stack(&actions_mask, 0)?.detach()?
|
||||
Tensor::stack(&actions_mask, 0)?.detach()
|
||||
};
|
||||
|
||||
let states = {
|
||||
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
|
||||
Tensor::stack(&states, 0)?.detach()?
|
||||
Tensor::stack(&states, 0)?.detach()
|
||||
};
|
||||
|
||||
let log_probs = actions_mask
|
||||
|
17
candle-examples/examples/rwkv/README.md
Normal file
17
candle-examples/examples/rwkv/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
## candle-rwkv
|
||||
|
||||
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
|
||||
with performance on par with transformer architectures. Several variants are
|
||||
available, candle implements the v5 version and can be used with Eagle 7B([blog
|
||||
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
|
||||
|
||||
```bash
|
||||
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
|
||||
avx: true, neon: false, simd128: false, f16c: true
|
||||
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
|
||||
The smallest prime is ϕ(2) = 2.
|
||||
The smallest composite is ϕ(3) = 3.
|
||||
The smallest perfect number is ϕ(5) = 5.
|
||||
The smallest perfect square is ϕ(4) = 4.
|
||||
The smallest perfect cube is ϕ(6) = 6.
|
||||
```
|
265
candle-examples/examples/rwkv/main.rs
Normal file
265
candle-examples/examples/rwkv/main.rs
Normal file
@ -0,0 +1,265 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
config: Config,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
config: Config,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
config,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
let mut tokens = self.tokenizer.encode(prompt)?;
|
||||
let mut generated_tokens = 0usize;
|
||||
let mut state = State::new(1, &self.config, &self.device)?;
|
||||
let mut next_logits = None;
|
||||
for &t in tokens.iter() {
|
||||
let input = Tensor::new(&[[t]], &self.device)?;
|
||||
let logits = self.model.forward(&input, &mut state)?;
|
||||
next_logits = Some(logits);
|
||||
print!("{}", self.tokenizer.decode(&[t])?)
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for _ in 0..sample_len {
|
||||
let logits = match next_logits.as_ref() {
|
||||
Some(logits) => logits,
|
||||
None => anyhow::bail!("cannot work on an empty prompt"),
|
||||
};
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
print!("{}", self.tokenizer.decode(&[next_token])?);
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let input = Tensor::new(&[[next_token]], &self.device)?;
|
||||
next_logits = Some(self.model.forward(&input, &mut state)?)
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
||||
enum Which {
|
||||
Eagle7b,
|
||||
World1b5,
|
||||
World3b,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
|
||||
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
|
||||
Self::World3b => "RWKV/rwkv-5-world-3b",
|
||||
}
|
||||
}
|
||||
|
||||
fn revision(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Eagle7b => "refs/pr/1",
|
||||
Self::World1b5 | Self::World3b => "refs/pr/2",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "world1b5")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id
|
||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
||||
RepoType::Model,
|
||||
args.revision
|
||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
||||
));
|
||||
let tokenizer = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("rwkv_vocab_v20230424.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::new(tokenizer)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
config,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -8,6 +8,13 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
|
||||
Note that this model is gated so you will have to request access on the Hub in
|
||||
order to be able to use it.
|
||||
|
||||
Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
|
||||
|
||||
StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
|
||||
Candle, so to run it you can download a somewhat compatible
|
||||
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
|
||||
and pass it via the --tokenizer-file argument.
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
|
||||
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
|
||||
@ -122,6 +122,16 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||
enum Which {
|
||||
V1Orig,
|
||||
V1,
|
||||
V1Zephyr,
|
||||
V2,
|
||||
V2Zephyr,
|
||||
Code,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -152,15 +162,18 @@ struct Args {
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
#[arg(long, short = 'n', default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
|
||||
model_id: String,
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long, default_value = "v2")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
@ -207,33 +220,80 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => match args.which {
|
||||
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
|
||||
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
|
||||
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
|
||||
Which::Code => "stabilityai/stable-code-3b".to_string(),
|
||||
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
|
||||
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
None => match args.which {
|
||||
Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
|
||||
repo.get("tokenizer.json")?
|
||||
}
|
||||
Which::V2 | Which::V2Zephyr => api
|
||||
.model("lmz/candle-stablelm".to_string())
|
||||
.get("tokenizer-gpt4.json")?,
|
||||
},
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
vec![repo.get("model-q4k.gguf")?]
|
||||
} else {
|
||||
None => match (args.which, args.quantized) {
|
||||
(Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
|
||||
(Which::V2, true) => {
|
||||
let gguf = api
|
||||
.model("lmz/candle-stablelm".to_string())
|
||||
.get("stablelm-2-1_6b-q4k.gguf")?;
|
||||
vec![gguf]
|
||||
}
|
||||
(Which::V2Zephyr, true) => {
|
||||
let gguf = api
|
||||
.model("lmz/candle-stablelm".to_string())
|
||||
.get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
|
||||
vec![gguf]
|
||||
}
|
||||
(Which::V1Zephyr | Which::Code, true) => {
|
||||
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
|
||||
}
|
||||
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
}
|
||||
(Which::Code, false) => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||
let config = match args.which {
|
||||
Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
|
||||
Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: Config = serde_json::from_str(&config)?;
|
||||
config.set_use_flash_attn(args.use_flash_attn);
|
||||
config
|
||||
}
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
|
BIN
candle-examples/examples/trocr/assets/noto.png
Normal file
BIN
candle-examples/examples/trocr/assets/noto.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.5 KiB |
@ -10,15 +10,36 @@ use clap::{Parser, ValueEnum};
|
||||
use candle::{DType, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::trocr;
|
||||
use candle_transformers::models::{trocr, vit};
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
mod image_processor;
|
||||
|
||||
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
Base,
|
||||
Large,
|
||||
#[value(name = "base")]
|
||||
BaseHandwritten,
|
||||
#[value(name = "large")]
|
||||
LargeHandwritten,
|
||||
BasePrinted,
|
||||
LargePrinted,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn repo_and_branch_name(&self) -> (&str, &str) {
|
||||
match self {
|
||||
Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
|
||||
Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
|
||||
Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
|
||||
Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
struct Config {
|
||||
encoder: vit::Config,
|
||||
decoder: trocr::TrOCRConfig,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -34,63 +55,64 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Text to be translated
|
||||
/// The image file to be processed.
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Tokenization config.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
use hf_hub::api::sync::Api;
|
||||
let args = Args::parse();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
|
||||
let tokenizer_dec = {
|
||||
let tokenizer = Api::new()?
|
||||
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||
.get("tokenizer.json")?;
|
||||
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
let mut tokenizer_dec = {
|
||||
let tokenizer_file = match args.tokenizer {
|
||||
None => api
|
||||
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
|
||||
.get("tokenizer.json")?,
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
|
||||
TokenOutputStream::new(tokenizer)
|
||||
};
|
||||
|
||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vb = {
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => match args.which {
|
||||
Which::Base => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"microsoft/trocr-base-handwritten".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/3".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
Which::Large => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"microsoft/trocr-large-handwritten".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/6".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
},
|
||||
None => {
|
||||
let (repo, branch) = args.which.repo_and_branch_name();
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
branch.to_string(),
|
||||
))
|
||||
.get("model.safetensors")?
|
||||
}
|
||||
};
|
||||
println!("model: {:?}", model);
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
|
||||
};
|
||||
|
||||
let encoder_config = match args.which {
|
||||
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
|
||||
Which::Large => {
|
||||
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
|
||||
}
|
||||
let (encoder_config, decoder_config) = {
|
||||
let (repo, branch) = args.which.repo_and_branch_name();
|
||||
let config_filename = api
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
branch.to_string(),
|
||||
))
|
||||
.get("config.json")?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
(config.encoder, config.decoder)
|
||||
};
|
||||
|
||||
let decoder_config = trocr::TrOCRConfig::default();
|
||||
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
|
||||
|
||||
let config = image_processor::ProcessorConfig::default();
|
||||
let processor = image_processor::ViTImageProcessor::new(&config);
|
||||
let processor_config = image_processor::ProcessorConfig::default();
|
||||
let processor = image_processor::ViTImageProcessor::new(&processor_config);
|
||||
|
||||
let image = vec![args.image.as_str()];
|
||||
let image = processor.preprocess(image)?;
|
||||
|
@ -5,12 +5,27 @@ transcribe image text. See the associated [model
|
||||
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
|
||||
the model itself.
|
||||
|
||||
Supported models include:
|
||||
|
||||
- `--which base`: small handwritten OCR model.
|
||||
- `--which large`: large handwritten OCR model.
|
||||
- `--which base-printed`: small printed OCR model.
|
||||
- `--which large-printed`: large printed OCR model.
|
||||
|
||||
## Running an example
|
||||
|
||||
```bash
|
||||
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
|
||||
cargo run --example trocr --release -- --image candle-examples/examples/trocr/assets/trocr.png
|
||||
cargo run --example trocr --release -- --which large --image candle-examples/examples/trocr/assets/trocr.png
|
||||
cargo run --example trocr --release -- --which base-printed --image candle-examples/examples/trocr/assets/noto.png
|
||||
cargo run --example trocr --release -- --which large-printed --image candle-examples/examples/trocr/assets/noto.png
|
||||
```
|
||||
|
||||
### Outputs
|
||||
|
||||
```
|
||||
<s> industry , Mr. Brown commented icily . " Let us have a</s>
|
||||
industry , Mr. Brown commented icily . " Let us have a
|
||||
industry , " Mr. Brown commented icily . " Let us have a
|
||||
THE QUICK BROWN FOR JUMPS OVER THE LAY DOG
|
||||
THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG
|
||||
```
|
||||
|
673
candle-examples/examples/whisper-microphone/main.rs
Normal file
673
candle-examples/examples/whisper-microphone/main.rs
Normal file
@ -0,0 +1,673 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, IndexOp, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use std::iter;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
|
||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub enum Model {
|
||||
Normal(m::model::Whisper),
|
||||
Quantized(m::quantized_model::Whisper),
|
||||
}
|
||||
|
||||
// Maybe we should use some traits rather than doing the dispatch for all these.
|
||||
impl Model {
|
||||
pub fn config(&self) -> &Config {
|
||||
match self {
|
||||
Self::Normal(m) => &m.config,
|
||||
Self::Quantized(m) => &m.config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.encoder.forward(x, flush),
|
||||
Self::Quantized(m) => m.encoder.forward(x, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: &Tensor,
|
||||
flush: bool,
|
||||
) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.forward(x, xa, flush),
|
||||
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.final_linear(x),
|
||||
Self::Quantized(m) => m.decoder.final_linear(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecodingResult {
|
||||
tokens: Vec<u32>,
|
||||
text: String,
|
||||
avg_logprob: f64,
|
||||
no_speech_prob: f64,
|
||||
temperature: f64,
|
||||
compression_ratio: f64,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct Segment {
|
||||
start: f64,
|
||||
duration: f64,
|
||||
dr: DecodingResult,
|
||||
}
|
||||
|
||||
struct Decoder {
|
||||
model: Model,
|
||||
rng: rand::rngs::StdRng,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
verbose: bool,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
transcribe_token: u32,
|
||||
translate_token: u32,
|
||||
eot_token: u32,
|
||||
no_speech_token: u32,
|
||||
no_timestamps_token: u32,
|
||||
language_token: Option<u32>,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
device: &Device,
|
||||
language_token: Option<u32>,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
verbose: bool,
|
||||
) -> Result<Self> {
|
||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||
// Suppress the notimestamps token when in timestamps mode.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
|
||||
.map(|i| {
|
||||
if model.config().suppress_tokens.contains(&i)
|
||||
|| timestamps && i == no_timestamps_token
|
||||
{
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0f32
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||
let no_speech_token = m::NO_SPEECH_TOKENS
|
||||
.iter()
|
||||
.find_map(|token| token_id(&tokenizer, token).ok());
|
||||
let no_speech_token = match no_speech_token {
|
||||
None => anyhow::bail!("unable to find any non-speech token"),
|
||||
Some(n) => n,
|
||||
};
|
||||
Ok(Self {
|
||||
model,
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
tokenizer,
|
||||
task,
|
||||
timestamps,
|
||||
verbose,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
translate_token,
|
||||
eot_token,
|
||||
no_speech_token,
|
||||
language_token,
|
||||
no_timestamps_token,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder_forward(mel, true)?;
|
||||
if self.verbose {
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
}
|
||||
let sample_len = model.config().max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![self.sot_token];
|
||||
if let Some(language_token) = self.language_token {
|
||||
tokens.push(language_token);
|
||||
}
|
||||
match self.task {
|
||||
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
|
||||
Some(Task::Translate) => tokens.push(self.translate_token),
|
||||
}
|
||||
if !self.timestamps {
|
||||
tokens.push(self.no_timestamps_token);
|
||||
}
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
no_speech_prob = softmax(&logits, 0)?
|
||||
.i(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
let (_, seq_len, _) = ys.dims3()?;
|
||||
let logits = model
|
||||
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
||||
// ApplyTimestampRules, i.e.:
|
||||
// - Timestamps come in pairs, except before EOT.
|
||||
// - Timestamps should be non-decreasing.
|
||||
// - If the sum of the probabilities of timestamps is higher than any other tokens,
|
||||
// only consider timestamps when sampling.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
|
||||
let logits = logits.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
} else {
|
||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||
logits_v
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.i(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
}
|
||||
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
|
||||
let avg_logprob = sum_logprob / tokens.len() as f64;
|
||||
|
||||
Ok(DecodingResult {
|
||||
tokens,
|
||||
text,
|
||||
avg_logprob,
|
||||
no_speech_prob,
|
||||
temperature: t,
|
||||
compression_ratio: f64::NAN,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
|
||||
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
|
||||
let dr: Result<DecodingResult> = self.decode(segment, t);
|
||||
if i == m::TEMPERATURES.len() - 1 {
|
||||
return dr;
|
||||
}
|
||||
// On errors, we try again with a different temperature.
|
||||
match dr {
|
||||
Ok(dr) => {
|
||||
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
|
||||
return Ok(dr);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!("Error running at {t}: {err}")
|
||||
}
|
||||
}
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn run(&mut self, mel: &Tensor, times: Option<(f64, f64)>) -> Result<Vec<Segment>> {
|
||||
let (_, _, content_frames) = mel.dims3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
while seek < content_frames {
|
||||
let start = std::time::Instant::now();
|
||||
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||
seek += segment_size;
|
||||
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
|
||||
println!("no speech detected, skipping {seek} {dr:?}");
|
||||
continue;
|
||||
}
|
||||
let segment = Segment {
|
||||
start: time_offset,
|
||||
duration: segment_duration,
|
||||
dr,
|
||||
};
|
||||
if self.timestamps {
|
||||
println!(
|
||||
"{:.1}s -- {:.1}s",
|
||||
segment.start,
|
||||
segment.start + segment.duration,
|
||||
);
|
||||
let mut tokens_to_decode = vec![];
|
||||
let mut prev_timestamp_s = 0f32;
|
||||
for &token in segment.dr.tokens.iter() {
|
||||
if token == self.sot_token || token == self.eot_token {
|
||||
continue;
|
||||
}
|
||||
// The no_timestamp_token is the last before the timestamp ones.
|
||||
if token > self.no_timestamps_token {
|
||||
let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
|
||||
if !tokens_to_decode.is_empty() {
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(&tokens_to_decode, true)
|
||||
.map_err(E::msg)?;
|
||||
println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
|
||||
tokens_to_decode.clear()
|
||||
}
|
||||
prev_timestamp_s = timestamp_s;
|
||||
} else {
|
||||
tokens_to_decode.push(token)
|
||||
}
|
||||
}
|
||||
if !tokens_to_decode.is_empty() {
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(&tokens_to_decode, true)
|
||||
.map_err(E::msg)?;
|
||||
if !text.is_empty() {
|
||||
println!(" {:.1}s-...: {}", prev_timestamp_s, text);
|
||||
}
|
||||
tokens_to_decode.clear()
|
||||
}
|
||||
} else {
|
||||
match times {
|
||||
Some((start, end)) => {
|
||||
println!("{:.1}s -- {:.1}s: {}", start, end, segment.dr.text)
|
||||
}
|
||||
None => {
|
||||
println!(
|
||||
"{:.1}s -- {:.1}s: {}",
|
||||
segment.start,
|
||||
segment.start + segment.duration,
|
||||
segment.dr.text,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
if self.verbose {
|
||||
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
|
||||
}
|
||||
segments.push(segment)
|
||||
}
|
||||
Ok(segments)
|
||||
}
|
||||
|
||||
fn set_language_token(&mut self, language_token: Option<u32>) {
|
||||
self.language_token = language_token;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn reset_kv_cache(&mut self) {
|
||||
match &mut self.model {
|
||||
Model::Normal(m) => m.reset_kv_cache(),
|
||||
Model::Quantized(m) => m.reset_kv_cache(),
|
||||
}
|
||||
}
|
||||
|
||||
fn model(&mut self) -> &mut Model {
|
||||
&mut self.model
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
|
||||
match tokenizer.token_to_id(token) {
|
||||
None => candle::bail!("no token-id for {token}"),
|
||||
Some(id) => Ok(id),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Task {
|
||||
Transcribe,
|
||||
Translate,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
|
||||
enum WhichModel {
|
||||
Tiny,
|
||||
#[value(name = "tiny.en")]
|
||||
TinyEn,
|
||||
Base,
|
||||
#[value(name = "base.en")]
|
||||
BaseEn,
|
||||
Small,
|
||||
#[value(name = "small.en")]
|
||||
SmallEn,
|
||||
Medium,
|
||||
#[value(name = "medium.en")]
|
||||
MediumEn,
|
||||
Large,
|
||||
LargeV2,
|
||||
LargeV3,
|
||||
#[value(name = "distil-medium.en")]
|
||||
DistilMediumEn,
|
||||
#[value(name = "distil-large-v2")]
|
||||
DistilLargeV2,
|
||||
}
|
||||
|
||||
impl WhichModel {
|
||||
fn is_multilingual(&self) -> bool {
|
||||
match self {
|
||||
Self::Tiny
|
||||
| Self::Base
|
||||
| Self::Small
|
||||
| Self::Medium
|
||||
| Self::Large
|
||||
| Self::LargeV2
|
||||
| Self::LargeV3
|
||||
| Self::DistilLargeV2 => true,
|
||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||
match self {
|
||||
Self::Tiny => ("openai/whisper-tiny", "main"),
|
||||
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
|
||||
Self::Base => ("openai/whisper-base", "refs/pr/22"),
|
||||
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
|
||||
Self::Small => ("openai/whisper-small", "main"),
|
||||
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
|
||||
Self::Medium => ("openai/whisper-medium", "main"),
|
||||
Self::MediumEn => ("openai/whisper-medium.en", "main"),
|
||||
Self::Large => ("openai/whisper-large", "refs/pr/36"),
|
||||
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
|
||||
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
|
||||
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
|
||||
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// The model to use, check out available models:
|
||||
/// https://huggingface.co/models?search=whisper
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The model to be used, can be tiny, small, medium.
|
||||
#[arg(long, default_value = "tiny.en")]
|
||||
model: WhichModel,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Language.
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
|
||||
/// Task, when no task is specified, the input tokens contain only the sot token which can
|
||||
/// improve things when in no-timestamp mode.
|
||||
#[arg(long)]
|
||||
task: Option<Task>,
|
||||
|
||||
/// Timestamps mode, this is not fully implemented yet.
|
||||
#[arg(long)]
|
||||
timestamps: bool,
|
||||
|
||||
/// Print the full DecodingResult structure rather than just the text.
|
||||
#[arg(long)]
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (default_model, default_revision) = if args.quantized {
|
||||
("lmz/candle-whisper", "main")
|
||||
} else {
|
||||
args.model.model_and_revision()
|
||||
};
|
||||
let default_model = default_model.to_string();
|
||||
let default_revision = default_revision.to_string();
|
||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let (config, tokenizer, model) = if args.quantized {
|
||||
let ext = match args.model {
|
||||
WhichModel::TinyEn => "tiny-en",
|
||||
WhichModel::Tiny => "tiny",
|
||||
_ => unimplemented!("no quantized support for {:?}", args.model),
|
||||
};
|
||||
(
|
||||
repo.get(&format!("config-{ext}.json"))?,
|
||||
repo.get(&format!("tokenizer-{ext}.json"))?,
|
||||
repo.get(&format!("model-{ext}-q80.gguf"))?,
|
||||
)
|
||||
} else {
|
||||
let config = repo.get("config.json")?;
|
||||
let tokenizer = repo.get("tokenizer.json")?;
|
||||
let model = repo.get("model.safetensors")?;
|
||||
(config, tokenizer, model)
|
||||
};
|
||||
(config, tokenizer, model)
|
||||
};
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let model = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&weights_filename,
|
||||
&device,
|
||||
)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config.clone())?)
|
||||
} else {
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
||||
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
|
||||
};
|
||||
let language_token = None;
|
||||
let mut dc = Decoder::new(
|
||||
model,
|
||||
tokenizer.clone(),
|
||||
args.seed,
|
||||
&device,
|
||||
language_token,
|
||||
args.task,
|
||||
args.timestamps,
|
||||
args.verbose,
|
||||
)?;
|
||||
|
||||
let mel_bytes = match config.num_mel_bins {
|
||||
80 => include_bytes!("../whisper/melfilters.bytes").as_slice(),
|
||||
128 => include_bytes!("../whisper/melfilters128.bytes").as_slice(),
|
||||
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
|
||||
};
|
||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||
|
||||
// Set up the input device and stream with the default input config.
|
||||
let host = cpal::default_host();
|
||||
let _device = "default";
|
||||
let _device = if _device == "default" {
|
||||
host.default_input_device()
|
||||
} else {
|
||||
host.input_devices()?
|
||||
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
|
||||
}
|
||||
.expect("failed to find input device");
|
||||
|
||||
let _config = _device
|
||||
.default_input_config()
|
||||
.expect("Failed to get default input config");
|
||||
|
||||
let channel_count = _config.channels() as usize;
|
||||
|
||||
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
|
||||
let audio_ring_buffer_2 = audio_ring_buffer.clone();
|
||||
|
||||
std::thread::spawn(move || loop {
|
||||
let data = record_audio(&_device, &_config, 300).unwrap();
|
||||
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
|
||||
let max_len = data.len() * 16;
|
||||
let data_len = data.len();
|
||||
let len = audio_ring_buffer.lock().unwrap().len();
|
||||
if len > max_len {
|
||||
let mut data = audio_ring_buffer.lock().unwrap();
|
||||
let new_data = data[data_len..].to_vec();
|
||||
*data = new_data;
|
||||
}
|
||||
});
|
||||
|
||||
// loop to process the audio data forever (until the user stops the program)
|
||||
println!("Transcribing audio...");
|
||||
for (i, _) in iter::repeat(()).enumerate() {
|
||||
std::thread::sleep(std::time::Duration::from_millis(1000));
|
||||
let data = audio_ring_buffer_2.lock().unwrap().clone();
|
||||
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
|
||||
.iter()
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(
|
||||
mel,
|
||||
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
// on the first iteration, we detect the language and set the language token.
|
||||
if i == 0 {
|
||||
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
|
||||
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
|
||||
(false, None) => None,
|
||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
||||
Ok(token_id) => Some(token_id),
|
||||
Err(_) => anyhow::bail!("language {language} is not supported"),
|
||||
},
|
||||
(false, Some(_)) => {
|
||||
anyhow::bail!("a language cannot be set for non-multilingual models")
|
||||
}
|
||||
};
|
||||
println!("language_token: {:?}", language_token);
|
||||
dc.set_language_token(language_token);
|
||||
}
|
||||
dc.run(
|
||||
&mel,
|
||||
Some((
|
||||
i as f64,
|
||||
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
|
||||
)),
|
||||
)?;
|
||||
dc.reset_kv_cache();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn record_audio(
|
||||
device: &cpal::Device,
|
||||
config: &cpal::SupportedStreamConfig,
|
||||
milliseconds: u64,
|
||||
) -> Result<Vec<i16>> {
|
||||
let writer = Arc::new(Mutex::new(Vec::new()));
|
||||
let writer_2 = writer.clone();
|
||||
let stream = device.build_input_stream(
|
||||
&config.config(),
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let processed = data
|
||||
.iter()
|
||||
.map(|v| (v * 32768.0) as i16)
|
||||
.collect::<Vec<i16>>();
|
||||
writer_2.lock().unwrap().extend_from_slice(&processed);
|
||||
},
|
||||
move |err| {
|
||||
eprintln!("an error occurred on stream: {}", err);
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
stream.play()?;
|
||||
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
|
||||
drop(stream);
|
||||
let data = writer.lock().unwrap().clone();
|
||||
let step = 3;
|
||||
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
|
||||
Ok(data)
|
||||
}
|
137
candle-examples/examples/whisper-microphone/multilingual.rs
Normal file
137
candle-examples/examples/whisper-microphone/multilingual.rs
Normal file
@ -0,0 +1,137 @@
|
||||
use crate::{token_id, Model};
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_transformers::models::whisper::{self as m};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const LANGUAGES: [(&str, &str); 99] = [
|
||||
("en", "english"),
|
||||
("zh", "chinese"),
|
||||
("de", "german"),
|
||||
("es", "spanish"),
|
||||
("ru", "russian"),
|
||||
("ko", "korean"),
|
||||
("fr", "french"),
|
||||
("ja", "japanese"),
|
||||
("pt", "portuguese"),
|
||||
("tr", "turkish"),
|
||||
("pl", "polish"),
|
||||
("ca", "catalan"),
|
||||
("nl", "dutch"),
|
||||
("ar", "arabic"),
|
||||
("sv", "swedish"),
|
||||
("it", "italian"),
|
||||
("id", "indonesian"),
|
||||
("hi", "hindi"),
|
||||
("fi", "finnish"),
|
||||
("vi", "vietnamese"),
|
||||
("he", "hebrew"),
|
||||
("uk", "ukrainian"),
|
||||
("el", "greek"),
|
||||
("ms", "malay"),
|
||||
("cs", "czech"),
|
||||
("ro", "romanian"),
|
||||
("da", "danish"),
|
||||
("hu", "hungarian"),
|
||||
("ta", "tamil"),
|
||||
("no", "norwegian"),
|
||||
("th", "thai"),
|
||||
("ur", "urdu"),
|
||||
("hr", "croatian"),
|
||||
("bg", "bulgarian"),
|
||||
("lt", "lithuanian"),
|
||||
("la", "latin"),
|
||||
("mi", "maori"),
|
||||
("ml", "malayalam"),
|
||||
("cy", "welsh"),
|
||||
("sk", "slovak"),
|
||||
("te", "telugu"),
|
||||
("fa", "persian"),
|
||||
("lv", "latvian"),
|
||||
("bn", "bengali"),
|
||||
("sr", "serbian"),
|
||||
("az", "azerbaijani"),
|
||||
("sl", "slovenian"),
|
||||
("kn", "kannada"),
|
||||
("et", "estonian"),
|
||||
("mk", "macedonian"),
|
||||
("br", "breton"),
|
||||
("eu", "basque"),
|
||||
("is", "icelandic"),
|
||||
("hy", "armenian"),
|
||||
("ne", "nepali"),
|
||||
("mn", "mongolian"),
|
||||
("bs", "bosnian"),
|
||||
("kk", "kazakh"),
|
||||
("sq", "albanian"),
|
||||
("sw", "swahili"),
|
||||
("gl", "galician"),
|
||||
("mr", "marathi"),
|
||||
("pa", "punjabi"),
|
||||
("si", "sinhala"),
|
||||
("km", "khmer"),
|
||||
("sn", "shona"),
|
||||
("yo", "yoruba"),
|
||||
("so", "somali"),
|
||||
("af", "afrikaans"),
|
||||
("oc", "occitan"),
|
||||
("ka", "georgian"),
|
||||
("be", "belarusian"),
|
||||
("tg", "tajik"),
|
||||
("sd", "sindhi"),
|
||||
("gu", "gujarati"),
|
||||
("am", "amharic"),
|
||||
("yi", "yiddish"),
|
||||
("lo", "lao"),
|
||||
("uz", "uzbek"),
|
||||
("fo", "faroese"),
|
||||
("ht", "haitian creole"),
|
||||
("ps", "pashto"),
|
||||
("tk", "turkmen"),
|
||||
("nn", "nynorsk"),
|
||||
("mt", "maltese"),
|
||||
("sa", "sanskrit"),
|
||||
("lb", "luxembourgish"),
|
||||
("my", "myanmar"),
|
||||
("bo", "tibetan"),
|
||||
("tl", "tagalog"),
|
||||
("mg", "malagasy"),
|
||||
("as", "assamese"),
|
||||
("tt", "tatar"),
|
||||
("haw", "hawaiian"),
|
||||
("ln", "lingala"),
|
||||
("ha", "hausa"),
|
||||
("ba", "bashkir"),
|
||||
("jw", "javanese"),
|
||||
("su", "sundanese"),
|
||||
];
|
||||
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
|
||||
let (_bsize, _, seq_len) = mel.dims3()?;
|
||||
let mel = mel.narrow(
|
||||
2,
|
||||
0,
|
||||
usize::min(seq_len, model.config().max_source_positions),
|
||||
)?;
|
||||
let device = mel.device();
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder_forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
|
||||
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for ((_, language), p) in probs.iter().take(5) {
|
||||
println!("{language}: {p}")
|
||||
}
|
||||
let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
|
||||
Ok(language)
|
||||
}
|
@ -18,6 +18,8 @@ use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
mod pcm_decode;
|
||||
|
||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
||||
|
||||
pub enum Model {
|
||||
@ -535,17 +537,10 @@ fn main() -> Result<()> {
|
||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||
|
||||
let mut input = std::fs::File::open(input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
println!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||
let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
|
||||
if sample_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||
}
|
||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||
.iter()
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
println!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
|
74
candle-examples/examples/whisper/pcm_decode.rs
Normal file
74
candle-examples/examples/whisper/pcm_decode.rs
Normal file
@ -0,0 +1,74 @@
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
|
||||
use symphonia::core::conv::FromSample;
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||
// Open the media source.
|
||||
let src = std::fs::File::open(path)?;
|
||||
|
||||
// Create the media source stream.
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
|
||||
// Create a probe hint using the file's extension. [Optional]
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
|
||||
// Use the default options for metadata and format readers.
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
|
||||
// Probe the media source.
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
// Get the instantiated format reader.
|
||||
let mut format = probed.format;
|
||||
|
||||
// Find the first audio track with a known (decodeable) codec.
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
|
||||
// Use the default options for the decoder.
|
||||
let dec_opts: DecoderOptions = Default::default();
|
||||
|
||||
// Create a decoder for the track.
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &dec_opts)
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
// The decode loop.
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
// Consume any new metadata that has been read since the last packet.
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
|
||||
// If the packet does not belong to the selected track, skip over it.
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
@ -104,6 +104,7 @@ impl TextGeneration {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
let t = t.replace("<|im_end|>", "\n");
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ fn detect(
|
||||
xs: &Tensor,
|
||||
image_height: usize,
|
||||
classes: usize,
|
||||
anchors: &Vec<(usize, usize)>,
|
||||
anchors: &[(usize, usize)],
|
||||
) -> Result<Tensor> {
|
||||
let (bsize, _channels, height, _width) = xs.dims4()?;
|
||||
let stride = image_height / height;
|
||||
|
@ -40,7 +40,7 @@ impl TokenOutputStream {
|
||||
};
|
||||
self.tokens.push(token);
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
self.prev_index = self.current_index;
|
||||
self.current_index = self.tokens.len();
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.3.3"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.3.3"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -71,7 +71,6 @@ __device__ void im2col1d(
|
||||
}
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 3;
|
||||
const size_t b_in = src_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
|
||||
@ -120,7 +119,6 @@ __device__ void im2col(
|
||||
}
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
const size_t b_in = src_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t h_in = src_dims[2];
|
||||
const size_t w_in = src_dims[3];
|
||||
@ -225,6 +223,60 @@ __device__ void conv2d(
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
// Naive implementation of conv_transpose1d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv_transpose1d(
|
||||
const size_t src_numel,
|
||||
const size_t l_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t out_padding,
|
||||
const size_t dilation,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, l_in)
|
||||
// k: (c_in, c_out, l_k)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 3;
|
||||
const size_t *k_dims = info + 6;
|
||||
const size_t *k_s = info + 9;
|
||||
const size_t l_k = k_dims[2];
|
||||
const size_t c_out = k_dims[1];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
if (dst_i >= src_dims[0] * c_out * l_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (l_out * c_out);
|
||||
const size_t dst_c_idx = (dst_i / l_out) % c_out;
|
||||
// NCL layout.
|
||||
const size_t out_x = dst_i % l_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (int k_x = 0; k_x < (int)l_k; ++k_x) {
|
||||
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
|
||||
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
|
||||
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
||||
continue;
|
||||
}
|
||||
int inp_x = inp_x_stride / stride;
|
||||
if (inp_x >= l_in) continue;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2];
|
||||
const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
// Naive implementation of conv_transpose2d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv_transpose2d(
|
||||
@ -507,6 +559,22 @@ extern "C" __global__ void FN_NAME( \
|
||||
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
|
||||
} \
|
||||
|
||||
#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t l_out, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t out_padding, \
|
||||
const size_t dilation, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv_transpose1d<TYPENAME, TYPEACC>(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
@ -568,6 +636,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
|
||||
CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16)
|
||||
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
|
||||
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
@ -579,6 +648,7 @@ IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CONV1D_OP(__half, float, conv1d_f16)
|
||||
CONV2D_OP(__half, float, conv2d_f16)
|
||||
CONVT1D_OP(__half, float, conv_transpose1d_f16)
|
||||
CONVT2D_OP(__half, float, conv_transpose2d_f16)
|
||||
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
@ -597,6 +667,11 @@ CONV2D_OP(double, double, conv2d_f64)
|
||||
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
|
||||
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
|
||||
|
||||
CONVT1D_OP(float, float, conv_transpose1d_f32)
|
||||
CONVT1D_OP(double, double, conv_transpose1d_f64)
|
||||
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
|
||||
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
|
||||
|
||||
CONVT2D_OP(float, float, conv_transpose2d_f32)
|
||||
CONVT2D_OP(double, double, conv_transpose2d_f64)
|
||||
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
|
||||
|
@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
|
||||
return maxg(x, zero);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu_fwd(T x) {
|
||||
return x / (static_cast<T>(1) + expg(-x));
|
||||
}
|
||||
|
||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
|
||||
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
|
||||
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||
#endif
|
||||
|
||||
@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
||||
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
|
||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||
#endif
|
||||
|
||||
@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
|
||||
UNARY_OP(double, urelu_f64, relu_fwd(x))
|
||||
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
|
||||
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
|
||||
UNARY_OP(float, usilu_f32, silu_fwd(x))
|
||||
UNARY_OP(double, usilu_f64, silu_fwd(x))
|
||||
UNARY_OP1(float, upowf_f32, powg(x, param))
|
||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.3.3"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
@ -183,7 +183,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip
|
||||
tanh, recip, silu
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -623,8 +623,7 @@ pub fn call_reduce_strided(
|
||||
strides,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output,
|
||||
out_length
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
@ -1365,13 +1364,12 @@ pub fn call_gemm(
|
||||
// TODO byte_stride_d
|
||||
let byte_stride_d = 0;
|
||||
|
||||
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
|
||||
for i in 0..b {
|
||||
buffer.push((i * byte_stride_a) as u64);
|
||||
buffer.push((i * byte_stride_b) as u64);
|
||||
buffer.push((i * byte_stride_c) as u64);
|
||||
buffer.push((i * byte_stride_d) as u64);
|
||||
}
|
||||
let buffer: Vec<u64> = vec![
|
||||
byte_stride_a as _,
|
||||
byte_stride_b as _,
|
||||
byte_stride_c as _,
|
||||
byte_stride_d as _,
|
||||
];
|
||||
encoder.set_bytes(
|
||||
10,
|
||||
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||
|
Binary file not shown.
@ -1,18 +1,16 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_limits>
|
||||
using namespace metal;
|
||||
|
||||
// TODO: Load multiple values per thread to improve memory bandwidth utilization
|
||||
// static constant constexpr uint VALUES_PER_THREAD = 1;
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant const size_t &num_dims,
|
||||
constant const size_t *dims,
|
||||
constant const size_t *strides
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
@ -21,637 +19,288 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template <typename V>
|
||||
struct Indexed {
|
||||
uint i;
|
||||
V val;
|
||||
typedef V type;
|
||||
|
||||
constexpr Indexed<V>() thread = default;
|
||||
constexpr Indexed<V>() threadgroup = default;
|
||||
constexpr Indexed<V>() device = default;
|
||||
constexpr Indexed<V>() constant = default;
|
||||
|
||||
constexpr Indexed<V>(uint _i, V _val) : i(_i), val(_val) {}
|
||||
|
||||
template <typename U, typename = typename enable_if<is_convertible_v<U, V>>::type>
|
||||
constexpr Indexed<V>(uint _i, U _val) : i(_i), val(static_cast<U>(_val)) {}
|
||||
|
||||
template <typename U>
|
||||
constexpr Indexed<V>(const thread Indexed<U> &iv): Indexed<V>(iv.i, iv.val) {}
|
||||
|
||||
template <typename U>
|
||||
constexpr Indexed<V>(const threadgroup Indexed<V> &iv): Indexed<V>(iv.i, iv.val) {}
|
||||
|
||||
Indexed<V> operator=(const thread Indexed<V> &iv) thread {
|
||||
this->i = iv.i;
|
||||
this->val = iv.val;
|
||||
return *this;
|
||||
}
|
||||
Indexed<V> operator=(const thread Indexed<V> &iv) threadgroup {
|
||||
this->i = iv.i;
|
||||
this->val = iv.val;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename V>
|
||||
constexpr METAL_FUNC bool operator<(Indexed<V> lhs, Indexed<V> rhs) {
|
||||
return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i);
|
||||
}
|
||||
|
||||
template<typename V>
|
||||
constexpr METAL_FUNC bool operator>(Indexed<V> lhs, Indexed<V> rhs) {
|
||||
return lhs.val > rhs.val || (lhs.val == rhs.val && lhs.i > rhs.i);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct _numeric_limits_impl<Indexed<T>> {
|
||||
static constexpr Indexed<T> lowest() {
|
||||
return Indexed<T>(0, numeric_limits<T>::lowest());
|
||||
}
|
||||
|
||||
static constexpr Indexed<T> max() {
|
||||
return Indexed<T>(0, numeric_limits<T>::max());
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
// Metal does not have simd_shuffle_down for bfloat16
|
||||
// TODO: Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat
|
||||
bfloat simd_shuffle_down(bfloat value, ushort delta) {
|
||||
return static_cast<bfloat>(__metal_simd_shuffle_down(static_cast<float>(value), delta));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename V>
|
||||
Indexed<V> simd_shuffle_down(Indexed<V> iv, ushort delta) {
|
||||
return Indexed<V>(
|
||||
simd_shuffle_down(iv.i, delta),
|
||||
simd_shuffle_down(iv.val, delta)
|
||||
);
|
||||
}
|
||||
|
||||
#define impl_reduction_op_helper(name, op, init_val, __result_type__) \
|
||||
template<typename T, typename R = __result_type__> \
|
||||
struct name { \
|
||||
static constexpr T init() { \
|
||||
return init_val; \
|
||||
} \
|
||||
METAL_FUNC R operator()(T a, T b) { \
|
||||
return op; \
|
||||
} \
|
||||
METAL_FUNC R operator()(thread const T& a, thread const T& b) const { \
|
||||
return op; \
|
||||
} \
|
||||
METAL_FUNC R operator()(threadgroup const T& a, threadgroup const T& b) const { \
|
||||
return op; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define impl_reduction_op(name, op, init_val) \
|
||||
impl_reduction_op_helper(name, op, init_val, T);
|
||||
|
||||
#define impl_arg_reduction_op(name, op, init_val) \
|
||||
impl_reduction_op_helper(name, op, init_val, tuple<bool, Indexed<T>>);
|
||||
|
||||
impl_reduction_op(Sum, a + b, 0);
|
||||
impl_reduction_op(Mul, a * b, 1);
|
||||
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
|
||||
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::lowest());
|
||||
#undef impl_reduction_op
|
||||
|
||||
// These are used when loading elements from global memory into shared memory.
|
||||
// They let us use the same code for both indexed and non-indexed types.
|
||||
template<typename Op, typename T, typename U>
|
||||
METAL_FUNC T apply_operator(Op op, size_t _idx, T a, U b) {
|
||||
return op(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
template<typename Op, typename T, typename U>
|
||||
METAL_FUNC Indexed<T> apply_operator(Op op, size_t idx, Indexed<T> a, U b) {
|
||||
return op(a, Indexed<T>(idx, b));
|
||||
}
|
||||
|
||||
// Load elements from global memory into shared memory.
|
||||
// Handles both indexed and non-indexed types by using apply_operator.
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED = false
|
||||
>
|
||||
METAL_FUNC R load_from_global(
|
||||
R value,
|
||||
constant size_t &num_elements,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
const ushort offset,
|
||||
threadgroup R shared[BLOCKSIZE],
|
||||
const ushort tid
|
||||
) {
|
||||
ReductionOp op;
|
||||
|
||||
size_t stop_idx = offset + el_to_sum_per_block;
|
||||
size_t idx = offset + tid;
|
||||
|
||||
while (idx < stop_idx) {
|
||||
if (STRIDED) {
|
||||
idx = get_strided_index(idx, num_dims, dims, strides);
|
||||
}
|
||||
value = apply_operator(op, idx, value, src[idx]);
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
|
||||
// Convenience function for when we don't need to sum over multiple dimensions.
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE
|
||||
>
|
||||
METAL_FUNC R load_from_global(
|
||||
R value,
|
||||
constant size_t &num_elements,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
const size_t offset,
|
||||
threadgroup R shared[BLOCKSIZE],
|
||||
const ushort tid
|
||||
) {
|
||||
return load_from_global<T, R, ReductionOp, BLOCKSIZE, false>(
|
||||
value,
|
||||
num_elements,
|
||||
// Dummy values for num_dims, dims, and strides
|
||||
num_elements,
|
||||
nullptr,
|
||||
nullptr,
|
||||
// end dummy values
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
}
|
||||
|
||||
// Since we are using simd_shuffle_down with a BLOCKSIZE guard we don't need any barriers.
|
||||
template<typename ReductionOp, ushort BLOCKSIZE, typename T>
|
||||
METAL_FUNC T simdgroup_reduce(T value) {
|
||||
ReductionOp op;
|
||||
if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16));
|
||||
if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value, 8));
|
||||
if (BLOCKSIZE >= 8) value = op(value, simd_shuffle_down(value, 4));
|
||||
if (BLOCKSIZE >= 4) value = op(value, simd_shuffle_down(value, 2));
|
||||
if (BLOCKSIZE >= 2) value = op(value, simd_shuffle_down(value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
template<
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
typename T
|
||||
>
|
||||
METAL_FUNC T threadgroup_reduce(
|
||||
threadgroup T shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]]
|
||||
) {
|
||||
ReductionOp op;
|
||||
|
||||
// Fully unrolled reduction loop from BLOCKSIZE down to 64.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint s = BLOCKSIZE / 2; s >= 64; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared[tid] = op(shared[tid], shared[tid + s]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid < 32) {
|
||||
// Last shared memory reduce can be done without tid < s check.
|
||||
if (BLOCKSIZE >= 64) {
|
||||
shared[tid] = op(shared[tid], shared[tid + 32]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
// Remaining 32 threads can be reduced with simdgroup_reduce.
|
||||
shared[tid] = simdgroup_reduce<ReductionOp, BLOCKSIZE>(shared[tid]);
|
||||
}
|
||||
|
||||
return shared[tid];
|
||||
}
|
||||
|
||||
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED = false
|
||||
>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device R *dst,
|
||||
constant size_t &num_elements,
|
||||
threadgroup T shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to correct value for reduction operation
|
||||
shared[tid] = ReductionOp::init();
|
||||
|
||||
// Calcluate offset for the threadgroup of current thread
|
||||
ushort offset = dst_id * el_to_sum_per_block;
|
||||
R initial = ReductionOp::init();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, R, ReductionOp, BLOCKSIZE, STRIDED>(
|
||||
initial,
|
||||
num_elements,
|
||||
num_dims,
|
||||
dims,
|
||||
strides,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Complete reduction
|
||||
R value = threadgroup_reduce<ReductionOp, BLOCKSIZE>(shared, tid);
|
||||
|
||||
if (tid == 0) dst[dst_id] = value;
|
||||
}
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define reduce_case(OP, T, R, N) \
|
||||
case N: { \
|
||||
threadgroup R shared[N]; \
|
||||
reduce<T, R, OP<R>, N, STRIDED>( \
|
||||
num_dims, \
|
||||
dims, \
|
||||
strides, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
num_elements, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define impl_reduce(OP, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
constant size_t *dims = {}; \
|
||||
constant size_t *strides = {}; \
|
||||
const bool STRIDED = false; \
|
||||
switch (block_dim) { \
|
||||
reduce_case(OP, T, T, 2048); \
|
||||
reduce_case(OP, T, T, 1024); \
|
||||
reduce_case(OP, T, T, 512); \
|
||||
reduce_case(OP, T, T, 256); \
|
||||
reduce_case(OP, T, T, 128); \
|
||||
reduce_case(OP, T, T, 64); \
|
||||
reduce_case(OP, T, T, 32); \
|
||||
reduce_case(OP, T, T, 16); \
|
||||
reduce_case(OP, T, T, 8); \
|
||||
reduce_case(OP, T, T, 4); \
|
||||
reduce_case(OP, T, T, 2); \
|
||||
reduce_case(OP, T, T, 1); \
|
||||
} \
|
||||
} \
|
||||
kernel void NAME##_strided( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
const bool STRIDED = true; \
|
||||
switch (block_dim) { \
|
||||
reduce_case(OP, T, T, 2048); \
|
||||
reduce_case(OP, T, T, 1024); \
|
||||
reduce_case(OP, T, T, 512); \
|
||||
reduce_case(OP, T, T, 256); \
|
||||
reduce_case(OP, T, T, 128); \
|
||||
reduce_case(OP, T, T, 64); \
|
||||
reduce_case(OP, T, T, 32); \
|
||||
reduce_case(OP, T, T, 16); \
|
||||
reduce_case(OP, T, T, 8); \
|
||||
reduce_case(OP, T, T, 4); \
|
||||
reduce_case(OP, T, T, 2); \
|
||||
reduce_case(OP, T, T, 1); \
|
||||
} \
|
||||
}
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED
|
||||
>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device uint *dst,
|
||||
constant size_t &num_elements,
|
||||
threadgroup Indexed<T> shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to correct value for reduction operation
|
||||
shared[tid] = ReductionOp::init();
|
||||
|
||||
// Calcluate offset for the threadgroup of current thread
|
||||
ushort offset = dst_id * el_to_sum_per_block;
|
||||
Indexed<T> initial = ReductionOp::init();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, Indexed<T>, ReductionOp, BLOCKSIZE, STRIDED>(
|
||||
initial,
|
||||
num_elements,
|
||||
num_dims,
|
||||
dims,
|
||||
strides,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
// Complete reduction
|
||||
Indexed<T> value = threadgroup_reduce<ReductionOp, BLOCKSIZE, Indexed<T>>(shared, tid);
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
|
||||
// Return index of reduce result
|
||||
if (tid == 0) dst[dst_id] = value.i;
|
||||
}
|
||||
|
||||
#define arg_reduce_case(OP, T, N) \
|
||||
case N: { \
|
||||
threadgroup Indexed<T> shared[N]; \
|
||||
reduce<T, OP<Indexed<T>>, N, STRIDED>( \
|
||||
num_dims, \
|
||||
dims, \
|
||||
strides, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
num_elements, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define impl_arg_reduce(OP, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
constant size_t *dims = {}; \
|
||||
constant size_t *strides = {}; \
|
||||
const bool STRIDED = false; \
|
||||
switch (block_dim) { \
|
||||
arg_reduce_case(OP, T, 2048); \
|
||||
arg_reduce_case(OP, T, 1024); \
|
||||
arg_reduce_case(OP, T, 512); \
|
||||
arg_reduce_case(OP, T, 256); \
|
||||
arg_reduce_case(OP, T, 128); \
|
||||
arg_reduce_case(OP, T, 64); \
|
||||
arg_reduce_case(OP, T, 32); \
|
||||
arg_reduce_case(OP, T, 16); \
|
||||
arg_reduce_case(OP, T, 8); \
|
||||
arg_reduce_case(OP, T, 4); \
|
||||
arg_reduce_case(OP, T, 2); \
|
||||
arg_reduce_case(OP, T, 1); \
|
||||
} \
|
||||
} \
|
||||
kernel void NAME##_strided( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
const bool STRIDED = true; \
|
||||
switch (block_dim) { \
|
||||
arg_reduce_case(OP, T, 2048); \
|
||||
arg_reduce_case(OP, T, 1024); \
|
||||
arg_reduce_case(OP, T, 512); \
|
||||
arg_reduce_case(OP, T, 256); \
|
||||
arg_reduce_case(OP, T, 128); \
|
||||
arg_reduce_case(OP, T, 64); \
|
||||
arg_reduce_case(OP, T, 32); \
|
||||
arg_reduce_case(OP, T, 16); \
|
||||
arg_reduce_case(OP, T, 8); \
|
||||
arg_reduce_case(OP, T, 4); \
|
||||
arg_reduce_case(OP, T, 2); \
|
||||
arg_reduce_case(OP, T, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename ACC = float,
|
||||
ushort BLOCKSIZE
|
||||
>
|
||||
METAL_FUNC void softmax(
|
||||
constant size_t &src_numel,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
device T *dst,
|
||||
threadgroup ACC shared[BLOCKSIZE],
|
||||
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to lowest value
|
||||
shared[tid] = numeric_limits<ACC>::lowest();
|
||||
|
||||
// Calcluate offset for the threadgroup of current thread
|
||||
size_t offset = dst_id * el_to_sum_per_block;
|
||||
ACC initial = numeric_limits<ACC>::lowest();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, ACC, Max<ACC>, BLOCKSIZE>(
|
||||
initial,
|
||||
src_numel,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Reduce shared memory to find max value
|
||||
threadgroup_reduce<Max<ACC>, BLOCKSIZE>(shared, tid);
|
||||
ACC max_result = shared[0];
|
||||
|
||||
// Ensure all threads have max_result = shared[0] before we set shared[0] = 0.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
shared[tid] = 0;
|
||||
|
||||
// Calculate softmax values
|
||||
size_t stop_idx = min(offset + el_to_sum_per_block, src_numel);
|
||||
size_t idx = offset + tid;
|
||||
while (idx < stop_idx) {
|
||||
const ACC val = exp(ACC(src[idx]) - max_result);
|
||||
dst[idx] = T(val);
|
||||
shared[tid] += val;
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
threadgroup_reduce<Sum<ACC>, BLOCKSIZE>(shared, tid);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
const T inv_acc = T(1.0/shared[0]);
|
||||
idx = offset + tid;
|
||||
while (idx < stop_idx) {
|
||||
dst[idx] *= inv_acc;
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
}
|
||||
|
||||
#define softmax_case(T, ACC, N) \
|
||||
case N: { \
|
||||
threadgroup ACC shared[N]; \
|
||||
softmax<T, ACC, N>( \
|
||||
src_numel, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define impl_softmax(NAME, T, ACC) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
switch (block_dim) { \
|
||||
softmax_case(T, ACC, 2048); \
|
||||
softmax_case(T, ACC, 1024); \
|
||||
softmax_case(T, ACC, 512); \
|
||||
softmax_case(T, ACC, 256); \
|
||||
softmax_case(T, ACC, 128); \
|
||||
softmax_case(T, ACC, 64); \
|
||||
softmax_case(T, ACC, 32); \
|
||||
softmax_case(T, ACC, 16); \
|
||||
softmax_case(T, ACC, 8); \
|
||||
softmax_case(T, ACC, 4); \
|
||||
softmax_case(T, ACC, 2); \
|
||||
softmax_case(T, ACC, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
impl_reduce(Sum, fast_sum_f32, float)
|
||||
impl_reduce(Sum, fast_sum_u32, uint)
|
||||
impl_reduce(Sum, fast_sum_f16, half)
|
||||
impl_reduce(Sum, fast_sum_u8, uint8_t)
|
||||
|
||||
impl_reduce(Mul, fast_mul_f32, float)
|
||||
impl_reduce(Mul, fast_mul_u32, uint)
|
||||
impl_reduce(Mul, fast_mul_f16, half)
|
||||
impl_reduce(Mul, fast_mul_u8, uint8_t)
|
||||
|
||||
impl_reduce(Max, fast_max_f32, float)
|
||||
impl_reduce(Max, fast_max_u32, uint)
|
||||
impl_reduce(Max, fast_max_f16, half)
|
||||
impl_reduce(Max, fast_max_u8, uint8_t)
|
||||
|
||||
impl_reduce(Min, fast_min_f32, float)
|
||||
impl_reduce(Min, fast_min_u32, uint)
|
||||
impl_reduce(Min, fast_min_f16, half)
|
||||
impl_reduce(Min, fast_min_u8, uint8_t)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_f32, float)
|
||||
impl_arg_reduce(Min, fast_argmin_f16, half)
|
||||
impl_arg_reduce(Min, fast_argmin_u32, uint)
|
||||
impl_arg_reduce(Min, fast_argmin_u8, uint8_t)
|
||||
|
||||
impl_arg_reduce(Max, fast_argmax_f32, float)
|
||||
impl_arg_reduce(Max, fast_argmax_f16, half)
|
||||
impl_arg_reduce(Max, fast_argmax_u32, uint)
|
||||
impl_arg_reduce(Max, fast_argmax_u8, uint8_t)
|
||||
|
||||
impl_softmax(softmax_f32, float, float)
|
||||
impl_softmax(softmax_f16, half, float)
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
impl_reduce(Sum, fast_sum_i64, int64_t)
|
||||
impl_reduce(Mul, fast_mul_i64, int64_t)
|
||||
impl_reduce(Min, fast_min_i64, int64_t)
|
||||
impl_reduce(Max, fast_max_i64, int64_t)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_i64, int64_t)
|
||||
impl_arg_reduce(Max, fast_argmax_i64, int64_t)
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
impl_reduce(Sum, fast_sum_bf16, bfloat)
|
||||
impl_reduce(Mul, fast_mul_bf16, bfloat)
|
||||
impl_reduce(Max, fast_max_bf16, bfloat)
|
||||
impl_reduce(Min, fast_min_bf16, bfloat)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_bf16, bfloat)
|
||||
impl_arg_reduce(Max, fast_argmax_bf16, bfloat)
|
||||
|
||||
impl_softmax(softmax_bf16, bfloat, float)
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
||||
|
@ -1,346 +0,0 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_f32, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32, uint, 0)
|
||||
ARGMAX(fast_argmax_u8, uint8_t, 0)
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_i64, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
@ -231,6 +231,25 @@ fn gelu_f32() {
|
||||
assert_eq!(approx(results, 3), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silu_f16() {
|
||||
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
|
||||
let results = run(&v, unary::contiguous::silu::HALF);
|
||||
assert_eq!(approx_f16(results, 2), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn silu_f32() {
|
||||
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
|
||||
let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
|
||||
let results = run(&v, unary::contiguous::silu::FLOAT);
|
||||
assert_eq!(approx(results, 3), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_add_f32() {
|
||||
let left = vec![1.0f32, 2.0, 3.0];
|
||||
@ -622,7 +641,7 @@ fn cos_f16() {
|
||||
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
|
||||
}
|
||||
|
||||
fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> {
|
||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
@ -630,10 +649,10 @@ fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Ve
|
||||
let input = new_buffer(&device, v);
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||
let dims = vec![v.len()];
|
||||
let strides = vec![1];
|
||||
match call_reduce_strided(
|
||||
call_reduce_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
@ -644,13 +663,8 @@ fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Ve
|
||||
&input,
|
||||
0,
|
||||
&output,
|
||||
) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
println!("Error: {}", e);
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
@ -682,114 +696,22 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
const fn create_array<const N: usize>() -> [f32; N] {
|
||||
let mut array: [f32; N] = [0.0; N];
|
||||
let mut i = 1;
|
||||
while i <= N {
|
||||
array[i - 1] = i as f32;
|
||||
i += 1;
|
||||
}
|
||||
array
|
||||
}
|
||||
|
||||
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
|
||||
let mut sum = 0;
|
||||
let mut results: [f32; D] = [0.0; D];
|
||||
let mut i = 1;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
sum += i;
|
||||
i += 1;
|
||||
if i > j * N / D {
|
||||
results[j - 1] = sum as f32;
|
||||
j += 1;
|
||||
sum = 0;
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
|
||||
let mut max = 0.0;
|
||||
let mut max_index: u32 = 0;
|
||||
let mut results: [u32; D] = [0; D];
|
||||
let mut i = 0;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
if i >= (j * N / D) {
|
||||
results[j - 1] = max_index;
|
||||
max = 0.0;
|
||||
max_index = 0;
|
||||
j += 1;
|
||||
}
|
||||
if i == N {
|
||||
break;
|
||||
}
|
||||
if arr[i] > max {
|
||||
max = arr[i];
|
||||
max_index = i as u32;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn reduce_sum_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results = run_reduce(&v, D, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), correct_sum::<N, D>());
|
||||
}
|
||||
|
||||
fn reduce_argmax_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
|
||||
assert_eq!(results, correct_argmax::<N, D>(v));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_sum() {
|
||||
reduce_sum_case::<6, 1>();
|
||||
reduce_sum_case::<10, 1>();
|
||||
reduce_sum_case::<64, 1>();
|
||||
reduce_sum_case::<128, 1>();
|
||||
reduce_sum_case::<256, 1>();
|
||||
reduce_sum_case::<512, 1>();
|
||||
reduce_sum_case::<1024, 1>();
|
||||
reduce_sum_case::<2048, 1>();
|
||||
reduce_sum_case::<4096, 1>();
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 1;
|
||||
|
||||
reduce_sum_case::<6, 2>();
|
||||
reduce_sum_case::<10, 2>();
|
||||
reduce_sum_case::<64, 2>();
|
||||
reduce_sum_case::<128, 2>();
|
||||
reduce_sum_case::<256, 2>();
|
||||
reduce_sum_case::<512, 2>();
|
||||
reduce_sum_case::<1024, 2>();
|
||||
reduce_sum_case::<2048, 2>();
|
||||
reduce_sum_case::<4096, 2>();
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![21.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_argmax() {
|
||||
reduce_argmax_case::<6, 1>();
|
||||
reduce_argmax_case::<10, 1>();
|
||||
reduce_argmax_case::<64, 1>();
|
||||
reduce_argmax_case::<128, 1>();
|
||||
reduce_argmax_case::<256, 1>();
|
||||
reduce_argmax_case::<512, 1>();
|
||||
reduce_argmax_case::<1024, 1>();
|
||||
reduce_argmax_case::<2048, 1>();
|
||||
reduce_argmax_case::<4096, 1>();
|
||||
fn reduce_sum2() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 2;
|
||||
|
||||
reduce_argmax_case::<6, 2>();
|
||||
reduce_argmax_case::<10, 2>();
|
||||
reduce_argmax_case::<64, 2>();
|
||||
reduce_argmax_case::<128, 2>();
|
||||
reduce_argmax_case::<256, 2>();
|
||||
reduce_argmax_case::<512, 2>();
|
||||
reduce_argmax_case::<1024, 2>();
|
||||
reduce_argmax_case::<2048, 2>();
|
||||
reduce_argmax_case::<4096, 2>();
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
|
||||
}
|
||||
return in;
|
||||
}
|
||||
template <typename T> METAL_FUNC T silu(T in){
|
||||
return in / (static_cast<T>(1) + exp(-in));
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -108,6 +111,7 @@ UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY_OP(log)
|
||||
UNARY_OP(gelu)
|
||||
UNARY_OP(silu)
|
||||
UNARY_OP(abs)
|
||||
UNARY_OP(ceil)
|
||||
UNARY_OP(floor)
|
||||
@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(silu)
|
||||
BFLOAT_UNARY_OP(abs)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
|
@ -30,7 +30,7 @@ impl super::Module for Activation {
|
||||
Self::Relu => xs.relu(),
|
||||
Self::Relu2 => xs.relu()?.sqr(),
|
||||
Self::Relu6 => xs.clamp(0f32, 6f32),
|
||||
Self::Silu => crate::ops::silu(xs),
|
||||
Self::Silu => xs.silu(),
|
||||
Self::Sigmoid => crate::ops::sigmoid(xs),
|
||||
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
|
||||
Self::Swiglu => crate::ops::swiglu(xs),
|
||||
|
@ -262,9 +262,19 @@ impl BatchNorm {
|
||||
let target_shape = target_shape.as_slice();
|
||||
|
||||
let x = x
|
||||
.broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
|
||||
.broadcast_sub(
|
||||
&self
|
||||
.running_mean
|
||||
.as_detached_tensor()
|
||||
.reshape(target_shape)?,
|
||||
)?
|
||||
.broadcast_div(
|
||||
&(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
|
||||
&(self
|
||||
.running_var
|
||||
.as_detached_tensor()
|
||||
.reshape(target_shape)?
|
||||
+ self.eps)?
|
||||
.sqrt()?,
|
||||
)?;
|
||||
|
||||
match &self.weight_and_bias {
|
||||
|
@ -124,7 +124,7 @@ fn set_at_index<D: WithDType, I: Into<i64>>(
|
||||
value: I,
|
||||
offset: usize,
|
||||
depth: usize,
|
||||
v: &mut Vec<D>,
|
||||
v: &mut [D],
|
||||
on_value: D,
|
||||
) -> Result<()> {
|
||||
let value = value.into();
|
||||
|
@ -35,13 +35,12 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
xs.silu()
|
||||
}
|
||||
|
||||
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.chunk(2, candle::D::Minus1)?;
|
||||
crate::ops::silu(&xs[0])? * &xs[1]
|
||||
&xs[0].silu()? * &xs[1]
|
||||
}
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
|
@ -412,7 +412,16 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
|
||||
/// Initializes a `VarBuilder` using a custom backend.
|
||||
///
|
||||
/// It is preferred to use one of the more specific constructors. This
|
||||
/// constructor is provided to allow downstream users to define their own
|
||||
/// backends.
|
||||
pub fn from_backend(
|
||||
backend: Box<dyn SimpleBackend + 'a>,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
) -> Self {
|
||||
let data = TensorData {
|
||||
backend,
|
||||
dtype,
|
||||
@ -427,13 +436,13 @@ impl<'a> VarBuilder<'a> {
|
||||
|
||||
/// Initializes a `VarBuilder` that uses zeros for any tensor.
|
||||
pub fn zeros(dtype: DType, dev: &Device) -> Self {
|
||||
Self::new(Box::new(Zeros), dtype, dev.clone())
|
||||
Self::from_backend(Box::new(Zeros), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
|
||||
/// returned if no tensor is available under the requested path or on shape mismatches.
|
||||
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
|
||||
Self::new(Box::new(ts), dtype, dev.clone())
|
||||
Self::from_backend(Box::new(ts), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
|
||||
@ -443,7 +452,7 @@ impl<'a> VarBuilder<'a> {
|
||||
/// Note that it is possible to load the tensor values after model creation using the `load`
|
||||
/// method on `varmap`, this can be used to start model training from an existing checkpoint.
|
||||
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
|
||||
Self::new(Box::new(varmap.clone()), dtype, dev.clone())
|
||||
Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
|
||||
@ -458,25 +467,25 @@ impl<'a> VarBuilder<'a> {
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
|
||||
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
||||
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let npz = candle::npy::NpzTensors::new(p)?;
|
||||
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
|
||||
Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
|
||||
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let pth = candle::pickle::PthTensors::new(p)?;
|
||||
Ok(Self::new(Box::new(pth), dtype, dev.clone()))
|
||||
let pth = candle::pickle::PthTensors::new(p, None)?;
|
||||
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.3.3"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.4.0" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -766,6 +766,16 @@ pub fn simple_eval(
|
||||
let output = input.cumsum(axis as usize)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
|
||||
"Flatten" => {
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
|
||||
let input = get(&node.input[0])?;
|
||||
let first_part: usize = input.shape().dims().iter().take(axis).product();
|
||||
let end_index = input.shape().dims().iter().product::<usize>();
|
||||
let new_shape = (first_part, end_index / first_part);
|
||||
let output = input.reshape(new_shape)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{Device, Result, Tensor};
|
||||
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use std::collections::HashMap;
|
||||
|
||||
const INPUT_X: &str = "x";
|
||||
@ -677,6 +677,134 @@ fn test_dropout_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Flatten"
|
||||
#[test]
|
||||
fn test_flatten_operation() -> Result<()> {
|
||||
let mut att_axis = AttributeProto {
|
||||
name: "axis".to_string(),
|
||||
ref_attr_name: "axis".to_string(),
|
||||
i: 0,
|
||||
doc_string: "axis".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Flatten".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_axis.clone()],
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: INPUT_Y.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
|
||||
],
|
||||
&[2, 2, 2],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]);
|
||||
|
||||
att_axis.i = 1;
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Flatten".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_axis.clone()],
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: INPUT_Y.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Below are ops that are implemented but not tested yet
|
||||
|
||||
// "MaxPool"
|
||||
|
@ -88,23 +88,27 @@ class QTensor:
|
||||
Dequantizes the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def ggml_dtype(self) -> str:
|
||||
"""
|
||||
Gets the tensors quantized dtype.
|
||||
"""
|
||||
pass
|
||||
|
||||
def matmul_t(self, lhs: Tensor) -> Tensor:
|
||||
"""
|
||||
Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
"""
|
||||
Gets the rank of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int]:
|
||||
"""
|
||||
@ -119,178 +123,213 @@ class Tensor:
|
||||
|
||||
def __init__(self, data: _ArrayLike):
|
||||
pass
|
||||
|
||||
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
|
||||
"""
|
||||
Return a slice of a tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Subtract a scalar from a tensor or one tensor from another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Divide a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def abs(self) -> Tensor:
|
||||
"""
|
||||
Performs the `abs` operation on the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def argmax_keepdim(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Returns the indices of the maximum value(s) across the selected dimension.
|
||||
"""
|
||||
pass
|
||||
|
||||
def argmin_keepdim(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Returns the indices of the minimum value(s) across the selected dimension.
|
||||
"""
|
||||
pass
|
||||
|
||||
def broadcast_add(self, rhs: Tensor) -> Tensor:
|
||||
"""
|
||||
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def broadcast_as(self, *shape: Shape) -> Tensor:
|
||||
"""
|
||||
Broadcasts the tensor to the given shape.
|
||||
"""
|
||||
pass
|
||||
|
||||
def broadcast_div(self, rhs: Tensor) -> Tensor:
|
||||
"""
|
||||
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def broadcast_left(self, *shape: Shape) -> Tensor:
|
||||
"""
|
||||
Broadcasts the tensor to the given shape, adding new dimensions on the left.
|
||||
"""
|
||||
pass
|
||||
|
||||
def broadcast_mul(self, rhs: Tensor) -> Tensor:
|
||||
"""
|
||||
Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def broadcast_sub(self, rhs: Tensor) -> Tensor:
|
||||
"""
|
||||
Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def contiguous(self) -> Tensor:
|
||||
"""
|
||||
Makes the tensor contiguous in memory.
|
||||
"""
|
||||
pass
|
||||
|
||||
def copy(self) -> Tensor:
|
||||
"""
|
||||
Returns a copy of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def cos(self) -> Tensor:
|
||||
"""
|
||||
Performs the `cos` operation on the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
"""
|
||||
Detach the tensor from the computation graph.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def device(self) -> Device:
|
||||
"""
|
||||
Gets the tensor's device.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
"""
|
||||
Gets the tensor's dtype.
|
||||
"""
|
||||
pass
|
||||
|
||||
def exp(self) -> Tensor:
|
||||
"""
|
||||
Performs the `exp` operation on the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def flatten_all(self) -> Tensor:
|
||||
"""
|
||||
Flattens the tensor into a 1D tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def flatten_from(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
|
||||
"""
|
||||
pass
|
||||
|
||||
def flatten_to(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
|
||||
"""
|
||||
pass
|
||||
|
||||
def get(self, index: int) -> Tensor:
|
||||
"""
|
||||
Gets the value at the specified index.
|
||||
"""
|
||||
pass
|
||||
|
||||
def index_select(self, rhs: Tensor, dim: int) -> Tensor:
|
||||
"""
|
||||
Select values for the input tensor at the target indexes across the specified dimension.
|
||||
@ -302,161 +341,192 @@ class Tensor:
|
||||
tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_contiguous(self) -> bool:
|
||||
"""
|
||||
Returns true if the tensor is contiguous in C order.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_fortran_contiguous(self) -> bool:
|
||||
"""
|
||||
Returns true if the tensor is contiguous in Fortran order.
|
||||
"""
|
||||
pass
|
||||
|
||||
def log(self) -> Tensor:
|
||||
"""
|
||||
Performs the `log` operation on the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def matmul(self, rhs: Tensor) -> Tensor:
|
||||
"""
|
||||
Performs a matrix multiplication between the two tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
def max_keepdim(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Gathers the maximum value across the selected dimension.
|
||||
"""
|
||||
pass
|
||||
|
||||
def mean_all(self) -> Tensor:
|
||||
"""
|
||||
Returns the mean of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def min_keepdim(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Gathers the minimum value across the selected dimension.
|
||||
"""
|
||||
pass
|
||||
|
||||
def narrow(self, dim: int, start: int, len: int) -> Tensor:
|
||||
"""
|
||||
Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
ranges from `start` to `start + len`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def nelement(self) -> int:
|
||||
"""
|
||||
Gets the tensor's element count.
|
||||
"""
|
||||
pass
|
||||
|
||||
def powf(self, p: float) -> Tensor:
|
||||
"""
|
||||
Performs the `pow` operation on the tensor with the given exponent.
|
||||
"""
|
||||
pass
|
||||
|
||||
def quantize(self, quantized_dtype: str) -> QTensor:
|
||||
"""
|
||||
Quantize the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
"""
|
||||
Gets the tensor's rank.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recip(self) -> Tensor:
|
||||
"""
|
||||
Get the `recip` of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def reshape(self, *shape: Shape) -> Tensor:
|
||||
"""
|
||||
Reshapes the tensor to the given shape.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int]:
|
||||
"""
|
||||
Gets the tensor's shape.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sin(self) -> Tensor:
|
||||
"""
|
||||
Performs the `sin` operation on the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sqr(self) -> Tensor:
|
||||
"""
|
||||
Squares the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sqrt(self) -> Tensor:
|
||||
"""
|
||||
Calculates the square root of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def squeeze(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Creates a new tensor with the specified dimension removed if its size was one.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def stride(self) -> Tuple[int]:
|
||||
"""
|
||||
Gets the tensor's strides.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sum_all(self) -> Tensor:
|
||||
"""
|
||||
Returns the sum of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
|
||||
"""
|
||||
Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions.
|
||||
"""
|
||||
pass
|
||||
|
||||
def t(self) -> Tensor:
|
||||
"""
|
||||
Transposes the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to(self, *args, **kwargs) -> Tensor:
|
||||
"""
|
||||
Performs Tensor dtype and/or device conversion.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_device(self, device: Union[str, Device]) -> Tensor:
|
||||
"""
|
||||
Move the tensor to a new device.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
|
||||
"""
|
||||
Convert the tensor to a new dtype.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_torch(self) -> torch.Tensor:
|
||||
"""
|
||||
Converts candle's tensor to pytorch's tensor
|
||||
"""
|
||||
pass
|
||||
|
||||
def transpose(self, dim1: int, dim2: int) -> Tensor:
|
||||
"""
|
||||
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
|
||||
"""
|
||||
pass
|
||||
|
||||
def unsqueeze(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Creates a new tensor with a dimension of size one inserted at the specified position.
|
||||
"""
|
||||
pass
|
||||
|
||||
def values(self) -> _ArrayLike:
|
||||
"""
|
||||
Gets the tensor's data as a Python scalar or array-like object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
|
||||
"""
|
||||
Returns a tensor with the same shape as the input tensor, the values are taken from
|
||||
|
@ -57,12 +57,10 @@ class Sequential(Module):
|
||||
_modules: Dict[str, Module] # type: ignore[assignment]
|
||||
|
||||
@overload
|
||||
def __init__(self, *args: Module) -> None:
|
||||
...
|
||||
def __init__(self, *args: Module) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(self, arg: "OrderedDict[str, Module]") -> None:
|
||||
...
|
||||
def __init__(self, arg: "OrderedDict[str, Module]") -> None: ...
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__()
|
||||
|
@ -204,12 +204,10 @@ class Module:
|
||||
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
|
||||
|
||||
@overload
|
||||
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
|
||||
...
|
||||
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
|
||||
|
||||
@overload
|
||||
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
|
||||
...
|
||||
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ...
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
r"""Returns a dictionary containing references to the whole state of the module.
|
||||
@ -586,12 +584,10 @@ class Module:
|
||||
self: T,
|
||||
device: str = ...,
|
||||
dtype: Optional[Union[DType, str]] = ...,
|
||||
) -> T:
|
||||
...
|
||||
) -> T: ...
|
||||
|
||||
@overload
|
||||
def to(self: T, dtype: Union[DType, str]) -> T:
|
||||
...
|
||||
def to(self: T, dtype: Union[DType, str]) -> T: ...
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r"""Moves and/or casts the parameters and buffers.
|
||||
|
@ -14,6 +14,7 @@ class LayerNorm(Module):
|
||||
math::
|
||||
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
||||
"""
|
||||
|
||||
__constants__ = ["normalized_shape", "eps"]
|
||||
normalized_shape: Tuple[int, ...]
|
||||
eps: float
|
||||
|
@ -11,59 +11,69 @@ class ONNXModel:
|
||||
|
||||
def __init__(self, path: str):
|
||||
pass
|
||||
|
||||
@property
|
||||
def doc_string(self) -> str:
|
||||
"""
|
||||
The doc string of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
"""
|
||||
The domain of the operator set of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
def initializers(self) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Get the weights of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
|
||||
"""
|
||||
The inputs of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def ir_version(self) -> int:
|
||||
"""
|
||||
The version of the IR this model targets.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def model_version(self) -> int:
|
||||
"""
|
||||
The version of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
|
||||
"""
|
||||
The outputs of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def producer_name(self) -> str:
|
||||
"""
|
||||
The producer of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def producer_version(self) -> str:
|
||||
"""
|
||||
The version of the producer of the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Run the model on the given inputs.
|
||||
@ -81,6 +91,7 @@ class ONNXTensorDescription:
|
||||
The data type of the tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[Union[int, str, Any]]:
|
||||
"""
|
||||
|
@ -938,8 +938,8 @@ impl PyTensor {
|
||||
|
||||
/// Detach the tensor from the computation graph.
|
||||
/// &RETURNS&: Tensor
|
||||
fn detach(&self) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
|
||||
fn detach(&self) -> Self {
|
||||
PyTensor(self.0.detach())
|
||||
}
|
||||
|
||||
/// Returns a copy of the tensor.
|
||||
|
@ -189,7 +189,6 @@ def do_black(content, is_pyi):
|
||||
line_length=119,
|
||||
is_pyi=is_pyi,
|
||||
string_normalization=True,
|
||||
experimental_string_processing=False,
|
||||
)
|
||||
try:
|
||||
return black.format_file_contents(content, fast=True, mode=mode)
|
||||
|
@ -23,7 +23,6 @@ serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_plain = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
593
candle-transformers/src/models/chatglm.rs
Normal file
593
candle-transformers/src/models/chatglm.rs
Normal file
@ -0,0 +1,593 @@
|
||||
use crate::models::with_tracing::Linear;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub num_layers: usize,
|
||||
pub padded_vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub ffn_hidden_size: usize,
|
||||
pub kv_channels: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub seq_length: usize,
|
||||
pub layernorm_epsilon: f64,
|
||||
pub rmsnorm: bool,
|
||||
pub apply_residual_connection_post_layernorm: bool,
|
||||
pub post_layer_norm: bool,
|
||||
pub add_bias_linear: bool,
|
||||
pub add_qkv_bias: bool,
|
||||
pub bias_dropout_fusion: bool,
|
||||
pub multi_query_attention: bool,
|
||||
pub multi_query_group_num: usize,
|
||||
pub apply_query_key_layer_scaling: bool,
|
||||
pub attention_softmax_in_fp32: bool,
|
||||
pub fp32_residual_connection: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn glm3_6b() -> Self {
|
||||
Self {
|
||||
num_layers: 28,
|
||||
padded_vocab_size: 65024,
|
||||
hidden_size: 4096,
|
||||
ffn_hidden_size: 13696,
|
||||
kv_channels: 128,
|
||||
num_attention_heads: 32,
|
||||
seq_length: 8192,
|
||||
layernorm_epsilon: 1e-5,
|
||||
rmsnorm: true,
|
||||
apply_residual_connection_post_layernorm: false,
|
||||
post_layer_norm: true,
|
||||
add_bias_linear: false,
|
||||
add_qkv_bias: true,
|
||||
bias_dropout_fusion: true,
|
||||
multi_query_attention: true,
|
||||
multi_query_group_num: 2,
|
||||
apply_query_key_layer_scaling: true,
|
||||
attention_softmax_in_fp32: true,
|
||||
fp32_residual_connection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
if bias {
|
||||
crate::models::with_tracing::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
cache: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let rotary_dim = cfg.kv_channels;
|
||||
let n_elem = rotary_dim / 2;
|
||||
let inv_freq: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((cfg.seq_length, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
|
||||
Ok(Self { cache })
|
||||
}
|
||||
|
||||
fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (seqlen, _b, np, _hn) = xs.dims4()?;
|
||||
let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
|
||||
let rot_dim = cache.dim(D::Minus2)? * 2;
|
||||
let (xs, xs_pass) = (
|
||||
xs.narrow(D::Minus1, 0, rot_dim)?,
|
||||
xs.narrow(D::Minus1, rot_dim, rot_dim)?,
|
||||
);
|
||||
let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
|
||||
let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
|
||||
let (xshaped0, xshaped1) = (
|
||||
xshaped.i((.., .., .., .., 0))?,
|
||||
xshaped.i((.., .., .., .., 1))?,
|
||||
);
|
||||
let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
|
||||
let xs_out = Tensor::stack(
|
||||
&[
|
||||
(xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
|
||||
(xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let xs_out = xs_out.flatten_from(3)?;
|
||||
Tensor::cat(&[xs_out, xs_pass], D::Minus1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CoreAttention {
|
||||
coeff: Option<f64>,
|
||||
norm_factor: f64,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl CoreAttention {
|
||||
fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
|
||||
let norm_factor = (cfg.kv_channels as f64).sqrt();
|
||||
let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
|
||||
let coeff = f64::max(1.0, layer_number as f64);
|
||||
(norm_factor * coeff, Some(coeff))
|
||||
} else {
|
||||
(norm_factor, None)
|
||||
};
|
||||
Ok(Self { coeff, norm_factor })
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
query_layer: &Tensor,
|
||||
key_layer: &Tensor,
|
||||
value_layer: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let output_size = (
|
||||
query_layer.dim(1)?, // b
|
||||
query_layer.dim(2)?, // np
|
||||
query_layer.dim(0)?, // sq
|
||||
key_layer.dim(0)?, // sk
|
||||
);
|
||||
let query_layer =
|
||||
query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
|
||||
let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
|
||||
let matmul_result = Tensor::matmul(
|
||||
&query_layer.transpose(0, 1)?,
|
||||
&key_layer.transpose(0, 1)?.transpose(1, 2)?,
|
||||
)?;
|
||||
let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
|
||||
let matmul_result = match self.coeff {
|
||||
None => matmul_result,
|
||||
Some(coeff) => (matmul_result * coeff)?,
|
||||
};
|
||||
let attention_scores = match attention_mask {
|
||||
Some(mask) => masked_fill(
|
||||
&matmul_result,
|
||||
&mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
None => matmul_result,
|
||||
};
|
||||
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
|
||||
|
||||
let output_size = (
|
||||
value_layer.dim(1)?,
|
||||
value_layer.dim(2)?,
|
||||
query_layer.dim(0)?,
|
||||
value_layer.dim(3)?,
|
||||
);
|
||||
let value_layer =
|
||||
value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
|
||||
let attention_probs =
|
||||
attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
|
||||
let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
|
||||
let context_layer = context_layer.reshape(output_size)?;
|
||||
let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
|
||||
context_layer.flatten_from(D::Minus2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SelfAttention {
|
||||
query_key_value: Linear,
|
||||
core_attention: CoreAttention,
|
||||
dense: Linear,
|
||||
multi_query_attention: bool,
|
||||
num_attention_heads_per_partition: usize,
|
||||
num_multi_query_groups_per_partition: usize,
|
||||
hidden_size_per_attention_head: usize,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl SelfAttention {
|
||||
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let projection_size = cfg.kv_channels * cfg.num_attention_heads;
|
||||
let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
|
||||
let qkv_hidden_size = if cfg.multi_query_attention {
|
||||
projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
|
||||
} else {
|
||||
3 * projection_size
|
||||
};
|
||||
let query_key_value = linear(
|
||||
cfg.hidden_size,
|
||||
qkv_hidden_size,
|
||||
cfg.add_bias_linear || cfg.add_qkv_bias,
|
||||
vb.pp("query_key_value"),
|
||||
)?;
|
||||
let core_attention = CoreAttention::new(layer_number, cfg)?;
|
||||
let dense = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size,
|
||||
cfg.add_bias_linear,
|
||||
vb.pp("dense"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
query_key_value,
|
||||
core_attention,
|
||||
dense,
|
||||
multi_query_attention: cfg.multi_query_attention,
|
||||
num_attention_heads_per_partition: cfg.num_attention_heads,
|
||||
num_multi_query_groups_per_partition: cfg.multi_query_group_num,
|
||||
hidden_size_per_attention_head: cfg.kv_channels,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
rotary_emb: &RotaryEmbedding,
|
||||
) -> Result<Tensor> {
|
||||
let mixed_x_layer = xs.apply(&self.query_key_value)?;
|
||||
if !self.multi_query_attention {
|
||||
candle::bail!("only multi_query_attention=true is supported")
|
||||
}
|
||||
let hpa = self.hidden_size_per_attention_head;
|
||||
let query_layer =
|
||||
mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
|
||||
let key_layer = mixed_x_layer.narrow(
|
||||
D::Minus1,
|
||||
self.num_attention_heads_per_partition * hpa,
|
||||
self.num_multi_query_groups_per_partition * hpa,
|
||||
)?;
|
||||
let value_layer = mixed_x_layer.narrow(
|
||||
D::Minus1,
|
||||
self.num_attention_heads_per_partition * hpa
|
||||
+ self.num_multi_query_groups_per_partition * hpa,
|
||||
self.num_multi_query_groups_per_partition * hpa,
|
||||
)?;
|
||||
let query_layer = query_layer.reshape((
|
||||
query_layer.dim(0)?,
|
||||
query_layer.dim(1)?,
|
||||
self.num_attention_heads_per_partition,
|
||||
hpa,
|
||||
))?;
|
||||
let key_layer = key_layer.reshape((
|
||||
key_layer.dim(0)?,
|
||||
key_layer.dim(1)?,
|
||||
self.num_multi_query_groups_per_partition,
|
||||
hpa,
|
||||
))?;
|
||||
let value_layer = value_layer.reshape((
|
||||
value_layer.dim(0)?,
|
||||
value_layer.dim(1)?,
|
||||
self.num_multi_query_groups_per_partition,
|
||||
hpa,
|
||||
))?;
|
||||
|
||||
// Rotary embeddings.
|
||||
let seqlen_offset = match &self.kv_cache {
|
||||
None => 0,
|
||||
Some((prev_k, _)) => prev_k.dim(0)?,
|
||||
};
|
||||
let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
|
||||
let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
|
||||
|
||||
// KV cache.
|
||||
let (key_layer, value_layer) = match &self.kv_cache {
|
||||
None => (key_layer, value_layer),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
|
||||
let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
|
||||
|
||||
// Repeat KV.
|
||||
let ratio =
|
||||
self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
|
||||
let key_layer = {
|
||||
let (d0, d1, d2, d3) = key_layer.dims4()?;
|
||||
key_layer
|
||||
.unsqueeze(D::Minus2)?
|
||||
.expand((d0, d1, d2, ratio, d3))?
|
||||
.reshape((
|
||||
d0,
|
||||
d1,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
))?
|
||||
};
|
||||
let value_layer = {
|
||||
let (d0, d1, d2, d3) = value_layer.dims4()?;
|
||||
value_layer
|
||||
.unsqueeze(D::Minus2)?
|
||||
.expand((d0, d1, d2, ratio, d3))?
|
||||
.reshape((
|
||||
d0,
|
||||
d1,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
))?
|
||||
};
|
||||
|
||||
let context_layer =
|
||||
self.core_attention
|
||||
.forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
|
||||
let output = context_layer.apply(&self.dense)?;
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct MLP {
|
||||
dense_h_to_4h: Linear,
|
||||
dense_4h_to_h: Linear,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let dense_h_to_4h = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.ffn_hidden_size * 2,
|
||||
cfg.add_bias_linear,
|
||||
vb.pp("dense_h_to_4h"),
|
||||
)?;
|
||||
let dense_4h_to_h = linear(
|
||||
cfg.ffn_hidden_size,
|
||||
cfg.hidden_size,
|
||||
cfg.add_bias_linear,
|
||||
vb.pp("dense_4h_to_h"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
dense_4h_to_h,
|
||||
dense_h_to_4h,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.dense_h_to_4h)?
|
||||
.apply(&candle_nn::Activation::Swiglu)?
|
||||
.apply(&self.dense_4h_to_h)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
input_layernorm: candle_nn::LayerNorm,
|
||||
self_attention: SelfAttention,
|
||||
post_attention_layernorm: candle_nn::LayerNorm,
|
||||
mlp: MLP,
|
||||
apply_residual_connection_post_layernorm: bool,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let input_layernorm = if cfg.rmsnorm {
|
||||
candle_nn::rms_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("input_layernorm"),
|
||||
)?
|
||||
.into_inner()
|
||||
} else {
|
||||
candle_nn::layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("input_layernorm"),
|
||||
)?
|
||||
};
|
||||
let post_attention_layernorm = if cfg.rmsnorm {
|
||||
candle_nn::rms_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?
|
||||
.into_inner()
|
||||
} else {
|
||||
candle_nn::layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?
|
||||
};
|
||||
let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
Ok(Self {
|
||||
input_layernorm,
|
||||
self_attention,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.self_attention.reset_kv_cache()
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
rotary_emb: &RotaryEmbedding,
|
||||
) -> Result<Tensor> {
|
||||
let layernorm_output = xs.apply(&self.input_layernorm)?;
|
||||
let attention_output =
|
||||
self.self_attention
|
||||
.forward(&layernorm_output, attention_mask, rotary_emb)?;
|
||||
let residual = if self.apply_residual_connection_post_layernorm {
|
||||
&layernorm_output
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let layernorm_input = (residual + attention_output)?;
|
||||
let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
|
||||
let mlp_output = layernorm_output.apply(&self.mlp)?;
|
||||
let residual = if self.apply_residual_connection_post_layernorm {
|
||||
&layernorm_output
|
||||
} else {
|
||||
&layernorm_input
|
||||
};
|
||||
mlp_output + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Transformer {
|
||||
layers: Vec<Block>,
|
||||
final_layernorm: Option<candle_nn::LayerNorm>,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
}
|
||||
|
||||
impl Transformer {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_l = vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_layers);
|
||||
for layer_index in 0..cfg.num_layers {
|
||||
let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
|
||||
layers.push(block)
|
||||
}
|
||||
let final_layernorm = if cfg.post_layer_norm {
|
||||
let ln = if cfg.rmsnorm {
|
||||
candle_nn::rms_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("final_layernorm"),
|
||||
)?
|
||||
.into_inner()
|
||||
} else {
|
||||
candle_nn::layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("final_layernorm"),
|
||||
)?
|
||||
};
|
||||
Some(ln)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
|
||||
Ok(Self {
|
||||
layers,
|
||||
final_layernorm,
|
||||
rotary_emb,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
for block in self.layers.iter_mut() {
|
||||
block.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for block in self.layers.iter_mut() {
|
||||
xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
|
||||
}
|
||||
match self.final_layernorm.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(ln) => xs.apply(ln),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Embedding {
|
||||
word_embeddings: candle_nn::Embedding,
|
||||
fp32_residual_connection: bool,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let word_embeddings = candle_nn::embedding(
|
||||
cfg.padded_vocab_size,
|
||||
cfg.hidden_size,
|
||||
vb.pp("word_embeddings"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
fp32_residual_connection: cfg.fp32_residual_connection,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Embedding {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
|
||||
if self.fp32_residual_connection {
|
||||
xs.to_dtype(candle::DType::F32)
|
||||
} else {
|
||||
xs.contiguous()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embedding: Embedding,
|
||||
encoder: Transformer,
|
||||
output_layer: Linear,
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb = vb.pp("transformer");
|
||||
let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
|
||||
let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
|
||||
let output_layer = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.padded_vocab_size,
|
||||
false,
|
||||
vb.pp("output_layer"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
embedding,
|
||||
encoder,
|
||||
output_layer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.encoder.reset_kv_cache()
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let input_embeds = xs.apply(&self.embedding)?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
|
||||
let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
|
||||
Ok(lm_logits)
|
||||
}
|
||||
}
|
339
candle-transformers/src/models/convnext.rs
Normal file
339
candle-transformers/src/models/convnext.rs
Normal file
@ -0,0 +1,339 @@
|
||||
//! ConvNeXt implementation.
|
||||
//!
|
||||
//! See "A ConvNet for the 2020s" Liu et al. 2022
|
||||
//! <https://arxiv.org/abs/2201.03545>
|
||||
//! and
|
||||
//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
|
||||
//! <https://arxiv.org/abs/2301.00808>
|
||||
|
||||
//! Original code:
|
||||
//! https://github.com/facebookresearch/ConvNeXt/
|
||||
//! https://github.com/facebookresearch/ConvNeXt-V2/
|
||||
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
|
||||
|
||||
use candle::shape::ShapeWithOneHole;
|
||||
use candle::{Result, D};
|
||||
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
blocks: [usize; 4],
|
||||
channels: [usize; 4],
|
||||
use_conv_mlp: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn atto() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 6, 2],
|
||||
channels: [40, 80, 160, 320],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn femto() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 6, 2],
|
||||
channels: [48, 96, 192, 384],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pico() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 6, 2],
|
||||
channels: [64, 128, 256, 512],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nano() -> Self {
|
||||
Self {
|
||||
blocks: [2, 2, 8, 2],
|
||||
channels: [80, 160, 320, 640],
|
||||
use_conv_mlp: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 9, 3],
|
||||
channels: [96, 192, 384, 768],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [96, 192, 384, 768],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [128, 256, 512, 1024],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn large() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [192, 384, 768, 1536],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn xlarge() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [256, 512, 1024, 2048],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn huge() -> Self {
|
||||
Self {
|
||||
blocks: [3, 3, 27, 3],
|
||||
channels: [352, 704, 1408, 2816],
|
||||
use_conv_mlp: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Layer norm for data in channels-last format.
|
||||
fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let norm = layer_norm(dim, 1e-6, vb)?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&norm)))
|
||||
}
|
||||
|
||||
// Layer norm for data in channels-first format.
|
||||
fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let norm = layer_norm(dim, 1e-6, vb)?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.permute((0, 3, 1, 2))?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Global response normalization layer
|
||||
// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
|
||||
fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let (shape, spatial_dim, channel_dim) = if channels_last {
|
||||
((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
|
||||
} else {
|
||||
((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
|
||||
};
|
||||
|
||||
let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
|
||||
let beta = vb.get(dim, "bias")?.reshape(&shape)?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let gx = xs
|
||||
.sqr()?
|
||||
.sum_keepdim(spatial_dim)?
|
||||
.mean_keepdim(spatial_dim)?
|
||||
.sqrt()?;
|
||||
|
||||
let gxmean = gx.mean_keepdim(channel_dim)?;
|
||||
let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
|
||||
let xs = xs
|
||||
.broadcast_mul(&nx)?
|
||||
.broadcast_mul(&gamma)?
|
||||
.broadcast_add(&beta)?;
|
||||
|
||||
xs + residual
|
||||
}))
|
||||
}
|
||||
|
||||
// Initial downsampling via a patchify layer.
|
||||
fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride: 4,
|
||||
..Default::default()
|
||||
};
|
||||
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
|
||||
let norm = layer_norm_cf(out_channels, vb.pp(1))?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
|
||||
}
|
||||
|
||||
// Downsampling applied after the stages.
|
||||
fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
|
||||
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
|
||||
}
|
||||
|
||||
// MLP block from the original paper with optional GRN layer (v2 models).
|
||||
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
|
||||
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
|
||||
let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
|
||||
if let Ok(g) = &grn {
|
||||
xs = xs.apply(g)?;
|
||||
}
|
||||
xs = xs.apply(&fc2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
|
||||
fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
|
||||
if let Ok(g) = &grn {
|
||||
xs = xs.apply(g)?;
|
||||
}
|
||||
xs = xs.apply(&fc2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
|
||||
fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
groups: dim,
|
||||
padding: 3,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
|
||||
let gamma = vb.get(dim, "gamma");
|
||||
|
||||
let (mlp, norm) = if use_conv_mlp {
|
||||
(
|
||||
convnext_conv_mlp(dim, vb.pp("mlp"))?,
|
||||
layer_norm_cf(dim, vb.pp("norm"))?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
convnext_mlp(dim, vb.pp("mlp"))?,
|
||||
layer_norm_cl(dim, vb.pp("norm"))?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let mut xs = xs.apply(&conv_dw)?;
|
||||
|
||||
xs = if use_conv_mlp {
|
||||
xs.apply(&norm)?.apply(&mlp)?
|
||||
} else {
|
||||
xs.permute((0, 2, 3, 1))?
|
||||
.apply(&norm)?
|
||||
.apply(&mlp)?
|
||||
.permute((0, 3, 1, 2))?
|
||||
};
|
||||
|
||||
if let Ok(g) = &gamma {
|
||||
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
|
||||
};
|
||||
|
||||
xs + residual
|
||||
}))
|
||||
}
|
||||
|
||||
// Each stage contains blocks and a downsampling layer for the previous stage.
|
||||
fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nblocks = cfg.blocks[stage_idx];
|
||||
let mut blocks = Vec::with_capacity(nblocks);
|
||||
|
||||
let dim = cfg.channels[stage_idx];
|
||||
|
||||
if stage_idx > 0 {
|
||||
blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
|
||||
}
|
||||
|
||||
for block_idx in 0..nblocks {
|
||||
blocks.push(convnext_block(
|
||||
dim,
|
||||
cfg.use_conv_mlp,
|
||||
vb.pp(format!("blocks.{block_idx}")),
|
||||
)?);
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Classification head.
|
||||
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
|
||||
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
|
||||
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
|
||||
}
|
||||
|
||||
// Build a convnext model for a given configuration.
|
||||
fn convnext_model(
|
||||
config: &Config,
|
||||
nclasses: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let head = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
|
||||
Some(head)
|
||||
}
|
||||
};
|
||||
|
||||
let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = convnext_stage(config, 0, vb.pp(0))?;
|
||||
let stage2 = convnext_stage(config, 1, vb.pp(1))?;
|
||||
let stage3 = convnext_stage(config, 2, vb.pp(2))?;
|
||||
let stage4 = convnext_stage(config, 3, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.mean(D::Minus2)?
|
||||
.mean(D::Minus1)?;
|
||||
match &head {
|
||||
None => Ok(xs),
|
||||
Some(head) => xs.apply(head),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
convnext_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
convnext_model(cfg, None, vb)
|
||||
}
|
@ -1,13 +1,12 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct LlamaConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
@ -40,6 +39,7 @@ impl LlamaConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
@ -82,7 +82,7 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
pub use_kv_cache: bool,
|
||||
@ -136,6 +136,7 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
@ -154,6 +155,7 @@ impl RmsNorm {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CausalSelfAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
@ -314,6 +316,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
c_fc1: Linear,
|
||||
c_fc2: Linear,
|
||||
@ -344,6 +347,7 @@ impl Mlp {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
rms_1: RmsNorm,
|
||||
attn: CausalSelfAttention,
|
||||
@ -383,6 +387,7 @@ impl Block {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Llama {
|
||||
wte: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
|
211
candle-transformers/src/models/mamba.rs
Normal file
211
candle-transformers/src/models/mamba.rs
Normal file
@ -0,0 +1,211 @@
|
||||
#![allow(unused)]
|
||||
/// A fast implementation of mamba for inference only.
|
||||
/// This is based on: https://github.com/LaurentMazare/mamba.rs
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{RmsNorm, VarBuilder};
|
||||
|
||||
const D_CONV: usize = 4;
|
||||
const D_STATE: usize = 16;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
d_model: usize,
|
||||
n_layer: usize,
|
||||
vocab_size: usize,
|
||||
pad_vocab_size_multiple: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn vocab_size(&self) -> usize {
|
||||
let pad = self.pad_vocab_size_multiple;
|
||||
(self.vocab_size + pad - 1) / pad * pad
|
||||
}
|
||||
|
||||
fn dt_rank(&self) -> usize {
|
||||
(self.d_model + 15) / 16
|
||||
}
|
||||
|
||||
fn d_inner(&self) -> usize {
|
||||
self.d_model * 2
|
||||
}
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
hs: Vec<Tensor>,
|
||||
prev_xs: Vec<[Tensor; D_CONV]>,
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
|
||||
let mut hs = Vec::with_capacity(cfg.n_layer);
|
||||
let mut prev_xs = Vec::with_capacity(cfg.n_layer);
|
||||
for _i in 0..cfg.n_layer {
|
||||
let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
|
||||
let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
|
||||
hs.push(h);
|
||||
prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
|
||||
}
|
||||
Ok(Self {
|
||||
hs,
|
||||
prev_xs,
|
||||
pos: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MambaBlock {
|
||||
in_proj: Linear,
|
||||
conv1d_bias: Tensor,
|
||||
conv1d_weights: [Tensor; D_CONV],
|
||||
x_proj: Linear,
|
||||
dt_proj: Linear,
|
||||
a_log: Tensor,
|
||||
d: Tensor,
|
||||
out_proj: Linear,
|
||||
dt_rank: usize,
|
||||
layer_index: usize,
|
||||
d_inner: usize,
|
||||
}
|
||||
|
||||
impl MambaBlock {
|
||||
pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let d_inner = cfg.d_inner();
|
||||
let dt_rank = cfg.dt_rank();
|
||||
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
|
||||
let x_proj = linear_no_bias(d_inner, dt_rank + D_STATE * 2, vb.pp("x_proj"))?;
|
||||
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
|
||||
let a_log = vb.get((d_inner, D_STATE), "A_log")?;
|
||||
let d = vb.get(d_inner, "D")?;
|
||||
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
|
||||
let conv1d_bias = vb.get(d_inner, "conv1d.bias")?;
|
||||
let conv1d_weight = vb.get((d_inner, 1, D_CONV), "conv1d.weight")?;
|
||||
let conv1d_weights = [
|
||||
conv1d_weight.i((.., 0, 0))?,
|
||||
conv1d_weight.i((.., 0, 1))?,
|
||||
conv1d_weight.i((.., 0, 2))?,
|
||||
conv1d_weight.i((.., 0, 3))?,
|
||||
];
|
||||
Ok(Self {
|
||||
in_proj,
|
||||
conv1d_bias,
|
||||
conv1d_weights,
|
||||
x_proj,
|
||||
dt_proj,
|
||||
a_log,
|
||||
d,
|
||||
out_proj,
|
||||
dt_rank,
|
||||
layer_index,
|
||||
d_inner,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let (b_sz, _dim) = xs.dims2()?;
|
||||
let li = self.layer_index;
|
||||
let mut xs = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
|
||||
let proj_for_silu = xs.remove(1);
|
||||
state.prev_xs[li][state.pos % D_CONV] = xs.remove(0);
|
||||
let mut proj_for_conv = self.conv1d_bias.broadcast_as((b_sz, self.d_inner))?;
|
||||
for d_c in 0..D_CONV {
|
||||
proj_for_conv = (proj_for_conv
|
||||
+ self.conv1d_weights[d_c]
|
||||
.broadcast_mul(&state.prev_xs[li][(d_c + 1 + state.pos) % D_CONV])?)?;
|
||||
}
|
||||
let proj_for_conv = candle_nn::ops::silu(&proj_for_conv)?;
|
||||
// SSM + Selection, we're doing inference here so only need the last step of
|
||||
// the sequence.
|
||||
// Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
|
||||
|
||||
let x_proj = self.x_proj.forward(&proj_for_conv)?;
|
||||
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
|
||||
let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
|
||||
let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;
|
||||
|
||||
let delta = delta.apply(&self.dt_proj)?;
|
||||
// softplus
|
||||
let delta = (delta.exp()? + 1.)?.log()?;
|
||||
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
|
||||
let d = self.d.to_dtype(candle::DType::F32)?;
|
||||
|
||||
// Selective scan part
|
||||
// Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
|
||||
let delta = delta
|
||||
.unsqueeze(D::Minus1)?
|
||||
.broadcast_as((b_sz, self.d_inner, D_STATE))?;
|
||||
let a = a.broadcast_as((b_sz, self.d_inner, D_STATE))?;
|
||||
let b = b.broadcast_as((b_sz, self.d_inner, D_STATE))?;
|
||||
let proj_for_conv_b =
|
||||
proj_for_conv
|
||||
.unsqueeze(D::Minus1)?
|
||||
.broadcast_as((b_sz, self.d_inner, D_STATE))?;
|
||||
state.hs[li] = ((&state.hs[li] * (&delta * &a)?.exp()?)? + &delta * &b * &proj_for_conv_b)?;
|
||||
let ss = (state.hs[li]
|
||||
.matmul(&c.unsqueeze(D::Minus1)?)?
|
||||
.squeeze(D::Minus1)?
|
||||
+ proj_for_conv.broadcast_mul(&d)?)?;
|
||||
|
||||
let ys = (ss * candle_nn::ops::silu(&proj_for_silu))?;
|
||||
ys.apply(&self.out_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ResidualBlock {
|
||||
mixer: MambaBlock,
|
||||
norm: RmsNorm,
|
||||
}
|
||||
|
||||
impl ResidualBlock {
|
||||
pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
|
||||
let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?;
|
||||
Ok(Self { mixer, norm })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
self.mixer.forward(&xs.apply(&self.norm)?, state)? + xs
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Model {
|
||||
embedding: candle_nn::Embedding,
|
||||
layers: Vec<ResidualBlock>,
|
||||
norm_f: RmsNorm,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
|
||||
let mut layers = Vec::with_capacity(cfg.n_layer);
|
||||
let vb_l = vb.pp("layers");
|
||||
for layer_idx in 0..cfg.n_layer {
|
||||
let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
|
||||
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
|
||||
Ok(Self {
|
||||
embedding,
|
||||
layers,
|
||||
norm_f,
|
||||
lm_head,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let _b_size = input_ids.dims1()?;
|
||||
let mut xs = self.embedding.forward(input_ids)?;
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, state)?
|
||||
}
|
||||
state.pos += 1;
|
||||
xs.apply(&self.norm_f)?.apply(&self.lm_head)
|
||||
}
|
||||
}
|
@ -8,7 +8,7 @@ use serde::Deserialize;
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
|
||||
// https://huggingface.co/microsoft/phi-1_5/blob/d38e6f954ec29b96fe2cf033937dad64e279b5d9/configuration_mixformer_sequential.py
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub(crate) vocab_size: usize,
|
||||
|
@ -2,7 +2,9 @@ pub mod bert;
|
||||
pub mod bigcode;
|
||||
pub mod blip;
|
||||
pub mod blip_text;
|
||||
pub mod chatglm;
|
||||
pub mod convmixer;
|
||||
pub mod convnext;
|
||||
pub mod dinov2;
|
||||
pub mod distilbert;
|
||||
pub mod efficientnet;
|
||||
@ -11,6 +13,7 @@ pub mod jina_bert;
|
||||
pub mod llama;
|
||||
pub mod llama2_c;
|
||||
pub mod llama2_c_weights;
|
||||
pub mod mamba;
|
||||
pub mod marian;
|
||||
pub mod mistral;
|
||||
pub mod mixformer;
|
||||
@ -28,8 +31,10 @@ pub mod quantized_mixformer;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod qwen2;
|
||||
pub mod repvgg;
|
||||
pub mod resnet;
|
||||
pub mod rwkv_v5;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
pub mod stable_lm;
|
||||
@ -37,6 +42,7 @@ pub mod t5;
|
||||
pub mod trocr;
|
||||
pub mod vgg;
|
||||
pub mod vit;
|
||||
pub mod vocos;
|
||||
pub mod whisper;
|
||||
pub mod with_tracing;
|
||||
pub mod wuerstchen;
|
||||
|
@ -16,7 +16,7 @@ struct RmsNorm {
|
||||
impl RmsNorm {
|
||||
fn new(scale: QTensor, eps: f32) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let scale = scale.dequantize(&Device::Cpu)?;
|
||||
let scale = scale.dequantize(&scale.device())?;
|
||||
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
@ -275,13 +275,17 @@ pub struct ModelWeights {
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
@ -292,11 +296,10 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tenso
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let output = ct.remove("output.weight")?;
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
@ -358,7 +361,6 @@ impl ModelWeights {
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
@ -382,10 +384,10 @@ impl ModelWeights {
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::new(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
@ -472,14 +474,14 @@ impl ModelWeights {
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
@ -487,7 +489,7 @@ impl ModelWeights {
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mask = self.mask(seq_len)?;
|
||||
let mask = self.mask(seq_len, x.device())?;
|
||||
let _enter = self.span.enter();
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
|
||||
use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, LayerNorm};
|
||||
@ -67,9 +67,14 @@ impl Attention {
|
||||
let head_dim = cfg.head_dim();
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let linear_layer = if cfg.use_qkv_bias {
|
||||
linear
|
||||
} else {
|
||||
linear_no_bias
|
||||
};
|
||||
let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
|
377
candle-transformers/src/models/qwen2.rs
Normal file
377
candle-transformers/src/models/qwen2.rs
Normal file
@ -0,0 +1,377 @@
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub sliding_window: usize,
|
||||
pub max_window_layers: usize,
|
||||
pub tie_word_embeddings: bool,
|
||||
pub rope_theta: f64,
|
||||
pub rms_norm_eps: f64,
|
||||
pub use_sliding_window: bool,
|
||||
pub hidden_act: Activation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
act_fn: Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let intermediate_sz = cfg.intermediate_size;
|
||||
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = xs.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
hidden_size: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = hidden_sz / num_heads;
|
||||
let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
hidden_size: hidden_sz,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_kv_groups;
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||
xs.unsqueeze(2)?
|
||||
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = self.repeat_kv(key_states)?.contiguous()?;
|
||||
let value_states = self.repeat_kv(value_states)?.contiguous()?;
|
||||
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&value_states)?
|
||||
};
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MLP,
|
||||
input_layernorm: RmsNorm,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
sliding_window: usize,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
sliding_window: cfg.sliding_window,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
// Sliding window mask?
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + self.sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
409
candle-transformers/src/models/rwkv_v5.rs
Normal file
409
candle-transformers/src/models/rwkv_v5.rs
Normal file
@ -0,0 +1,409 @@
|
||||
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
fn default_num_attention_heads() -> usize {
|
||||
64
|
||||
}
|
||||
|
||||
// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub attention_hidden_size: usize,
|
||||
#[serde(default = "default_num_attention_heads")]
|
||||
pub num_attention_heads: usize,
|
||||
pub head_size: usize,
|
||||
pub intermediate_size: Option<usize>,
|
||||
pub layer_norm_epsilon: f64,
|
||||
pub rescale_every: usize,
|
||||
}
|
||||
|
||||
struct StatePerLayer {
|
||||
extract_key_value: Tensor,
|
||||
linear_attention: Tensor,
|
||||
feed_forward: Tensor,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
per_layer: Vec<StatePerLayer>,
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
// Certainly a weird convention but taken from modeling_rwkv5.py
|
||||
let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
|
||||
for _layer_idx in 0..cfg.num_hidden_layers {
|
||||
let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
|
||||
let linear_attention = Tensor::zeros(
|
||||
(
|
||||
batch_size,
|
||||
num_attention_heads,
|
||||
cfg.hidden_size / num_attention_heads,
|
||||
cfg.hidden_size / num_attention_heads,
|
||||
),
|
||||
DType::F32,
|
||||
dev,
|
||||
)?;
|
||||
let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
|
||||
per_layer.push(StatePerLayer {
|
||||
extract_key_value,
|
||||
linear_attention,
|
||||
feed_forward,
|
||||
});
|
||||
}
|
||||
Ok(Self { per_layer, pos: 0 })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SelfAttention {
|
||||
key: Linear,
|
||||
receptance: Linear,
|
||||
value: Linear,
|
||||
gate: Linear,
|
||||
output: Linear,
|
||||
ln_x: candle_nn::GroupNorm,
|
||||
time_mix_key: Tensor,
|
||||
time_mix_value: Tensor,
|
||||
time_mix_receptance: Tensor,
|
||||
time_decay: Tensor,
|
||||
time_faaaa: Tensor,
|
||||
time_mix_gate: Tensor,
|
||||
layer_id: usize,
|
||||
n_attn_heads: usize,
|
||||
}
|
||||
|
||||
impl SelfAttention {
|
||||
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let attn_hidden_size = cfg.attention_hidden_size;
|
||||
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
|
||||
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
|
||||
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
|
||||
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
|
||||
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
|
||||
let ln_x = candle_nn::group_norm(
|
||||
hidden_size / cfg.head_size,
|
||||
hidden_size,
|
||||
1e-5,
|
||||
vb.pp("ln_x"),
|
||||
)?;
|
||||
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
|
||||
let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
|
||||
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
|
||||
let n_attn_heads = cfg.hidden_size / cfg.head_size;
|
||||
let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
|
||||
let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
|
||||
let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
|
||||
Ok(Self {
|
||||
key,
|
||||
value,
|
||||
receptance,
|
||||
gate,
|
||||
output,
|
||||
ln_x,
|
||||
time_mix_key,
|
||||
time_mix_value,
|
||||
time_mix_receptance,
|
||||
time_decay,
|
||||
time_faaaa,
|
||||
time_mix_gate,
|
||||
layer_id,
|
||||
n_attn_heads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let h = self.time_decay.dim(0)?;
|
||||
let (b, t, s) = xs.dims3()?;
|
||||
let s = s / h;
|
||||
let (receptance, key, value, gate) = {
|
||||
// exctract key-value
|
||||
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
|
||||
let shifted = if shifted.rank() == 2 {
|
||||
shifted.unsqueeze(1)?
|
||||
} else {
|
||||
shifted
|
||||
};
|
||||
let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
|
||||
let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
|
||||
let receptance = ((xs * &self.time_mix_receptance)?
|
||||
+ &shifted * (1.0 - &self.time_mix_receptance)?)?;
|
||||
let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
|
||||
|
||||
let key = self.key.forward(&key)?;
|
||||
let value = self.value.forward(&value)?;
|
||||
let receptance = self.receptance.forward(&receptance)?;
|
||||
let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
|
||||
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
|
||||
(receptance, key, value, gate)
|
||||
};
|
||||
// linear attention
|
||||
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
|
||||
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
|
||||
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||
|
||||
let time_decay = self
|
||||
.time_decay
|
||||
.exp()?
|
||||
.neg()?
|
||||
.exp()?
|
||||
.reshape(((), 1, 1))?
|
||||
.reshape((self.n_attn_heads, (), 1))?;
|
||||
let time_faaaa =
|
||||
self.time_faaaa
|
||||
.reshape(((), 1, 1))?
|
||||
.reshape((self.n_attn_heads, (), 1))?;
|
||||
|
||||
let mut out: Vec<Tensor> = Vec::with_capacity(t);
|
||||
for t_ in 0..t {
|
||||
//
|
||||
let rt = receptance.i((.., .., t_..t_ + 1))?;
|
||||
let kt = key.i((.., .., .., t_..t_ + 1))?;
|
||||
let vt = value.i((.., .., t_..t_ + 1))?;
|
||||
let at = kt.matmul(&vt)?;
|
||||
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
|
||||
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
|
||||
state_ = (&at + time_decay.broadcast_mul(&state_))?;
|
||||
out.push(out_)
|
||||
}
|
||||
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
|
||||
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
|
||||
let out = (out * gate)?.apply(&self.output)?;
|
||||
state.per_layer[self.layer_id].linear_attention = state_;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct FeedForward {
|
||||
time_mix_key: Tensor,
|
||||
time_mix_receptance: Tensor,
|
||||
key: Linear,
|
||||
receptance: Linear,
|
||||
value: Linear,
|
||||
layer_id: usize,
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let int_size = cfg
|
||||
.intermediate_size
|
||||
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
|
||||
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
|
||||
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
|
||||
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
|
||||
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
|
||||
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
|
||||
Ok(Self {
|
||||
key,
|
||||
receptance,
|
||||
value,
|
||||
time_mix_key,
|
||||
time_mix_receptance,
|
||||
layer_id,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let shifted = &state.per_layer[self.layer_id].feed_forward;
|
||||
let key = (xs.broadcast_mul(&self.time_mix_key)?
|
||||
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
|
||||
let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
|
||||
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
|
||||
let key = key.apply(&self.key)?.relu()?.sqr()?;
|
||||
let value = key.apply(&self.value)?;
|
||||
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
|
||||
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
|
||||
let xs = (receptance * value)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
pre_ln: Option<LayerNorm>,
|
||||
ln1: LayerNorm,
|
||||
ln2: LayerNorm,
|
||||
attention: SelfAttention,
|
||||
feed_forward: FeedForward,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
|
||||
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
|
||||
let pre_ln = if layer_id == 0 {
|
||||
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
|
||||
Some(ln)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
|
||||
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
|
||||
Ok(Self {
|
||||
pre_ln,
|
||||
ln1,
|
||||
ln2,
|
||||
attention,
|
||||
feed_forward,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let xs = match self.pre_ln.as_ref() {
|
||||
None => xs.clone(),
|
||||
Some(pre_ln) => xs.apply(pre_ln)?,
|
||||
};
|
||||
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
|
||||
let xs = (xs + attention)?;
|
||||
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
|
||||
let xs = (xs + feed_forward)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embeddings: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
ln_out: LayerNorm,
|
||||
head: Linear,
|
||||
rescale_every: usize,
|
||||
layers_are_rescaled: bool,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("rwkv");
|
||||
let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
|
||||
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_b = vb_m.pp("blocks");
|
||||
for block_index in 0..cfg.num_hidden_layers {
|
||||
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
|
||||
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
blocks,
|
||||
ln_out,
|
||||
head,
|
||||
rescale_every: cfg.rescale_every,
|
||||
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let (_b_size, _seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.embeddings)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
xs = block.forward(&xs, state)?;
|
||||
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
|
||||
xs = (xs / 2.)?
|
||||
}
|
||||
}
|
||||
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
|
||||
state.pos += 1;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
type Bytes = Vec<u8>;
|
||||
|
||||
// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
|
||||
pub struct Tokenizer {
|
||||
table: Vec<Vec<Vec<Bytes>>>,
|
||||
good: Vec<HashSet<u8>>,
|
||||
idx2token: HashMap<u32, Vec<u8>>,
|
||||
token2idx: HashMap<Vec<u8>, u32>,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let token2idx: HashMap<String, u32> =
|
||||
serde_json::from_reader(file).map_err(candle::Error::wrap)?;
|
||||
let token2idx = token2idx
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key.into_bytes(), value))
|
||||
.collect::<HashMap<_, _>>();
|
||||
let idx2token = token2idx
|
||||
.iter()
|
||||
.map(|(key, value)| (*value, key.to_vec()))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let max_idx = token2idx.values().copied().max().unwrap_or(0);
|
||||
|
||||
let mut table = vec![vec![vec![]; 256]; 256];
|
||||
let mut good = vec![HashSet::new(); 256];
|
||||
for idx in (0..(1 + max_idx)).rev() {
|
||||
let s = match idx2token.get(&idx) {
|
||||
None => continue,
|
||||
Some(s) => s,
|
||||
};
|
||||
if s.len() >= 2 {
|
||||
let (s0, s1) = (s[0], s[1]);
|
||||
table[s0 as usize][s1 as usize].push(s.to_vec());
|
||||
good[s0 as usize].insert(s1);
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
table,
|
||||
good,
|
||||
idx2token,
|
||||
token2idx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
|
||||
let mut v = Vec::new();
|
||||
for token_id in tokens.iter() {
|
||||
if let Some(token) = self.idx2token.get(token_id) {
|
||||
v.extend_from_slice(token.as_slice())
|
||||
}
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||
let bytes = self.decode_bytes(tokens);
|
||||
String::from_utf8(bytes).map_err(candle::Error::wrap)
|
||||
}
|
||||
|
||||
pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
|
||||
let mut tokens = Vec::new();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
let mut s = vec![bytes[i]];
|
||||
if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
|
||||
let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
|
||||
for table_elem in table.iter() {
|
||||
if bytes[i..].starts_with(table_elem) {
|
||||
s = table_elem.to_vec();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
i += s.len();
|
||||
let token = match self.token2idx.get(&s) {
|
||||
None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
|
||||
Some(token) => *token,
|
||||
};
|
||||
tokens.push(token)
|
||||
}
|
||||
Ok(tokens)
|
||||
}
|
||||
|
||||
pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
|
||||
self.encode_bytes(str.as_bytes())
|
||||
}
|
||||
}
|
@ -1,10 +1,11 @@
|
||||
use crate::models::with_tracing::{linear_no_bias, Linear};
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, LayerNorm, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub(crate) vocab_size: usize,
|
||||
pub(crate) intermediate_size: usize,
|
||||
@ -18,7 +19,10 @@ pub struct Config {
|
||||
pub(crate) max_position_embeddings: usize,
|
||||
pub(crate) norm_eps: f64,
|
||||
pub(crate) use_cache: bool,
|
||||
pub(crate) use_flash_attn: bool,
|
||||
#[serde(default)]
|
||||
pub(crate) use_qkv_bias: bool, // Used in StableLM-2
|
||||
#[serde(default)]
|
||||
pub(crate) use_flash_attn: bool, // Not in config.json
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -35,6 +39,7 @@ impl Config {
|
||||
rope_theta: 10_000.,
|
||||
max_position_embeddings: 4096,
|
||||
norm_eps: 1e-5,
|
||||
use_qkv_bias: false,
|
||||
use_cache: true,
|
||||
use_flash_attn,
|
||||
}
|
||||
@ -51,6 +56,10 @@ impl Config {
|
||||
pub fn num_kv_groups(&self) -> usize {
|
||||
self.num_attention_heads / self.num_key_value_heads
|
||||
}
|
||||
|
||||
pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
|
||||
self.use_flash_attn = use_flash_attn
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -179,9 +188,15 @@ impl Attention {
|
||||
let head_dim = cfg.head_dim();
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let linear_layer = if cfg.use_qkv_bias {
|
||||
linear
|
||||
} else {
|
||||
linear_no_bias
|
||||
};
|
||||
|
||||
let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
|
@ -1,15 +1,21 @@
|
||||
use crate::models::vit::{Config, Embeddings, Encoder};
|
||||
use candle::{Result, Tensor};
|
||||
use candle::{DType, Result, Tensor};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
fn default_tie_word_embeddings() -> bool {
|
||||
true
|
||||
}
|
||||
fn default_use_learned_position_embeddings() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||
pub struct TrOCRConfig {
|
||||
pub vocab_size: usize,
|
||||
pub d_model: usize,
|
||||
pub hidden_size: usize,
|
||||
pub cross_attention_hidden_size: usize,
|
||||
pub decoder_layers: usize,
|
||||
pub decoder_attention_heads: usize,
|
||||
pub decoder_ffn_dim: usize,
|
||||
@ -23,13 +29,14 @@ pub struct TrOCRConfig {
|
||||
pub decoder_layerdrop: f64,
|
||||
pub use_cache: bool,
|
||||
pub scale_embedding: bool,
|
||||
pub use_learned_position_embeddings: bool,
|
||||
pub layernorm_embedding: bool,
|
||||
pub pad_token_id: usize,
|
||||
pub bos_token_id: usize,
|
||||
pub eos_token_id: u32,
|
||||
pub num_attention_heads: usize,
|
||||
pub decoder_vocab_size: Option<usize>,
|
||||
#[serde(default = "default_use_learned_position_embeddings")]
|
||||
pub use_learned_position_embeddings: bool,
|
||||
#[serde(default = "default_tie_word_embeddings")]
|
||||
pub tie_word_embeddings: bool,
|
||||
}
|
||||
|
||||
impl Default for TrOCRConfig {
|
||||
@ -37,7 +44,7 @@ impl Default for TrOCRConfig {
|
||||
Self {
|
||||
vocab_size: 50265,
|
||||
d_model: 1024,
|
||||
hidden_size: 768,
|
||||
cross_attention_hidden_size: 768,
|
||||
decoder_layers: 12,
|
||||
decoder_attention_heads: 16,
|
||||
decoder_ffn_dim: 4096,
|
||||
@ -51,13 +58,12 @@ impl Default for TrOCRConfig {
|
||||
decoder_layerdrop: 0.0,
|
||||
use_cache: true,
|
||||
scale_embedding: false,
|
||||
use_learned_position_embeddings: true,
|
||||
layernorm_embedding: true,
|
||||
pad_token_id: 1,
|
||||
bos_token_id: 0,
|
||||
eos_token_id: 2,
|
||||
num_attention_heads: 12,
|
||||
decoder_vocab_size: Some(50265),
|
||||
use_learned_position_embeddings: true,
|
||||
tie_word_embeddings: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -78,17 +84,49 @@ impl TrOCRLearnedPositionalEmbedding {
|
||||
Ok(Self { offset, weights })
|
||||
}
|
||||
|
||||
fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
|
||||
// https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
|
||||
let embedding_dim = cfg.d_model;
|
||||
let half_dim = embedding_dim / 2;
|
||||
let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
|
||||
let dev = vb.device();
|
||||
let inv_freq: Vec<_> = (0..half_dim)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||
let t = Tensor::arange(0u32, num_positions as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((num_positions, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
|
||||
let emb = Tensor::cat(
|
||||
&[
|
||||
emb.narrow(0, 0, cfg.pad_token_id)?,
|
||||
Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
|
||||
emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
|
||||
],
|
||||
0,
|
||||
)?
|
||||
.contiguous()?;
|
||||
let emb = Embedding::new(emb, embedding_dim);
|
||||
Ok(Self {
|
||||
offset: cfg.pad_token_id + 1,
|
||||
weights: emb,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
|
||||
let mut positions = Tensor::arange(
|
||||
let positions = Tensor::arange(
|
||||
past_key_values_length,
|
||||
seq_len as u32 + past_key_values_length,
|
||||
input_ids.device(),
|
||||
)?
|
||||
.expand((b_sz, seq_len))?;
|
||||
|
||||
positions =
|
||||
let positions =
|
||||
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
|
||||
self.weights.forward(&positions)
|
||||
}
|
||||
@ -221,19 +259,17 @@ impl TrOCRDecoderLayer {
|
||||
let encoder_attn = TrOCRAttention::load(
|
||||
vb.pp("encoder_attn"),
|
||||
cfg,
|
||||
Some(cfg.hidden_size),
|
||||
Some(cfg.hidden_size),
|
||||
Some(cfg.cross_attention_hidden_size),
|
||||
Some(cfg.cross_attention_hidden_size),
|
||||
)?;
|
||||
let encoder_attn_layer_norm =
|
||||
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
|
||||
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
|
||||
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
|
||||
let activation_fn = candle_nn::Activation::Gelu;
|
||||
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
activation_fn,
|
||||
activation_fn: cfg.activation_function,
|
||||
self_attn_layer_norm,
|
||||
encoder_attn,
|
||||
encoder_attn_layer_norm,
|
||||
@ -294,7 +330,11 @@ impl TrOCRDecoder {
|
||||
let vb = vb.pp("decoder.model.decoder");
|
||||
|
||||
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
|
||||
let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
|
||||
let embed_positions = if cfg.use_learned_position_embeddings {
|
||||
TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
|
||||
} else {
|
||||
TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
|
||||
};
|
||||
let mut layers = Vec::with_capacity(cfg.decoder_layers);
|
||||
let vb_l = vb.pp("layers");
|
||||
for idx in 0..cfg.decoder_layers {
|
||||
@ -383,8 +423,15 @@ pub struct TrOCRForCausalLM {
|
||||
impl TrOCRForCausalLM {
|
||||
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
|
||||
let output_projection =
|
||||
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
|
||||
let output_projection = if decoder_cfg.tie_word_embeddings {
|
||||
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(
|
||||
decoder_cfg.d_model,
|
||||
decoder_cfg.vocab_size,
|
||||
vb.pp("decoder.output_projection"),
|
||||
)?
|
||||
};
|
||||
Ok(Self {
|
||||
decoder,
|
||||
output_projection,
|
||||
|
@ -1,10 +1,9 @@
|
||||
#![allow(unused)]
|
||||
use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
|
||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
@ -82,7 +81,7 @@ impl PatchEmbeddings {
|
||||
|
||||
impl Module for PatchEmbeddings {
|
||||
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||
let (b_size, num_channels, height, width) = pixel_values.dims4()?;
|
||||
let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
|
||||
self.projection
|
||||
.forward(pixel_values)?
|
||||
.flatten_from(2)?
|
||||
@ -123,9 +122,9 @@ impl Embeddings {
|
||||
|
||||
fn interpolate_pos_encoding(
|
||||
&self,
|
||||
embeddings: &Tensor,
|
||||
height: usize,
|
||||
width: usize,
|
||||
_embeddings: &Tensor,
|
||||
_height: usize,
|
||||
_width: usize,
|
||||
) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
@ -136,7 +135,7 @@ impl Embeddings {
|
||||
bool_masked_pos: Option<&Tensor>,
|
||||
interpolate_pos_encoding: bool,
|
||||
) -> Result<Tensor> {
|
||||
let (b_size, num_channels, height, width) = pixel_values.dims4()?;
|
||||
let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
|
||||
let embeddings = self.patch_embeddings.forward(pixel_values)?;
|
||||
let embeddings = match (bool_masked_pos, &self.mask_token) {
|
||||
(None, _) => embeddings,
|
||||
@ -392,6 +391,9 @@ impl Model {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(xs, None, false)?;
|
||||
let encoder_outputs = self.encoder.forward(&embedding_output)?;
|
||||
encoder_outputs.i((.., 0, ..))?.apply(&self.classifier)
|
||||
encoder_outputs
|
||||
.i((.., 0, ..))?
|
||||
.apply(&self.layernorm)?
|
||||
.apply(&self.classifier)
|
||||
}
|
||||
}
|
||||
|
156
candle-transformers/src/models/vocos.rs
Normal file
156
candle-transformers/src/models/vocos.rs
Normal file
@ -0,0 +1,156 @@
|
||||
#![allow(unused)]
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
|
||||
|
||||
pub struct AdaLayerNorm {
|
||||
eps: f64,
|
||||
dim: usize,
|
||||
scale: Embedding,
|
||||
shift: Embedding,
|
||||
}
|
||||
|
||||
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = x.dim(D::Minus1)?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let x = {
|
||||
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
x.broadcast_sub(&mean_x)?
|
||||
};
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
|
||||
x_normed.to_dtype(x_dtype)
|
||||
}
|
||||
|
||||
impl AdaLayerNorm {
|
||||
pub fn new(
|
||||
num_embeddings: usize,
|
||||
embedding_dim: usize,
|
||||
eps: f64,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
|
||||
let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
|
||||
Ok(Self {
|
||||
eps,
|
||||
dim: embedding_dim,
|
||||
scale,
|
||||
shift,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
|
||||
let scale = self.scale.forward(cond_embedding_id)?;
|
||||
let shift = self.shift.forward(cond_embedding_id)?;
|
||||
let xs = layer_norm(xs, self.eps)?;
|
||||
xs * scale + shift
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConvNeXtBlock {
|
||||
dwconv: Conv1d,
|
||||
pwconv1: Linear,
|
||||
pwconv2: Linear,
|
||||
gamma: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl ConvNeXtBlock {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
intermediate_dim: usize,
|
||||
layer_scale_init_value: f64,
|
||||
adanorm_num_embeddings: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let dwconv = {
|
||||
let cfg = Conv1dConfig {
|
||||
padding: 3,
|
||||
groups: dim,
|
||||
..Default::default()
|
||||
};
|
||||
conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
|
||||
};
|
||||
let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
|
||||
let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
|
||||
let gamma = if layer_scale_init_value > 0. {
|
||||
Some(vb.get(dim, "gamma")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
dwconv,
|
||||
pwconv1,
|
||||
pwconv2,
|
||||
gamma,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
|
||||
// TODO: norm
|
||||
let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
|
||||
let xs = match self.gamma.as_ref() {
|
||||
Some(gamma) => (gamma * xs)?,
|
||||
None => xs,
|
||||
};
|
||||
xs.transpose(1, 2)? + residual
|
||||
}
|
||||
}
|
||||
|
||||
struct VocosBackbone {
|
||||
embed: Conv1d,
|
||||
convnext: Vec<ConvNeXtBlock>,
|
||||
final_layer_norm: candle_nn::LayerNorm,
|
||||
}
|
||||
|
||||
impl VocosBackbone {
|
||||
pub fn new(
|
||||
input_channels: usize,
|
||||
dim: usize,
|
||||
intermediate_dim: usize,
|
||||
num_layers: dim,
|
||||
layer_scale_init_value: f64,
|
||||
adanorm_num_embeddings: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let embed = {
|
||||
let cfg = Conv1dConfig {
|
||||
padding: 3,
|
||||
..Default::default()
|
||||
};
|
||||
conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
|
||||
};
|
||||
let mut convnext = Vec::with_capacity(num_layers);
|
||||
let vb_c = vb.pp("convnext");
|
||||
for i in 0..num_layers {
|
||||
let block = ConvNeXtBlock::new(
|
||||
dim,
|
||||
intermediate_dim,
|
||||
layer_scale_init_value,
|
||||
adanorm_num_embeddings,
|
||||
vb_c.pp(i),
|
||||
)?;
|
||||
}
|
||||
let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
|
||||
Ok(Self {
|
||||
embed,
|
||||
convnext,
|
||||
final_layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.embed)?;
|
||||
// TODO: norm
|
||||
let mut xs = xs.transpose(1, 2)?;
|
||||
for conv_block in self.convnext.iter() {
|
||||
xs = conv_block.forward(&xs)?
|
||||
}
|
||||
xs.apply(&self.final_layer_norm)
|
||||
}
|
||||
}
|
@ -1,7 +1,14 @@
|
||||
// Audio processing code, adapted from whisper.cpp
|
||||
// https://github.com/ggerganov/whisper.cpp
|
||||
|
||||
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
|
||||
use candle::utils::get_num_threads;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
pub trait Float:
|
||||
num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
|
||||
{
|
||||
}
|
||||
|
||||
impl Float for f32 {}
|
||||
impl Float for f64 {}
|
||||
@ -102,22 +109,26 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
let half = T::from(0.5).unwrap();
|
||||
let mut fft_in = vec![zero; fft_size];
|
||||
let mut mel = vec![zero; n_len * n_mel];
|
||||
let n_samples = samples.len();
|
||||
let end = std::cmp::min(n_samples / fft_step + 1, n_len);
|
||||
|
||||
for i in (ith..n_len).step_by(n_threads) {
|
||||
for i in (ith..end).step_by(n_threads) {
|
||||
let offset = i * fft_step;
|
||||
|
||||
// apply Hanning window
|
||||
for j in 0..fft_size {
|
||||
fft_in[j] = if offset + j < samples.len() {
|
||||
hann[j] * samples[offset + j]
|
||||
} else {
|
||||
zero
|
||||
}
|
||||
for j in 0..std::cmp::min(fft_size, n_samples - offset) {
|
||||
fft_in[j] = hann[j] * samples[offset + j];
|
||||
}
|
||||
|
||||
// FFT -> mag^2
|
||||
// fill the rest with zeros
|
||||
if n_samples - offset < fft_size {
|
||||
fft_in[n_samples - offset..].fill(zero);
|
||||
}
|
||||
|
||||
// FFT
|
||||
let mut fft_out: Vec<T> = fft(&fft_in);
|
||||
|
||||
// Calculate modulus^2 of complex numbers
|
||||
for j in 0..fft_size {
|
||||
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
|
||||
}
|
||||
@ -136,8 +147,19 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
// mel spectrogram
|
||||
for j in 0..n_mel {
|
||||
let mut sum = zero;
|
||||
for k in 0..n_fft {
|
||||
let mut k = 0;
|
||||
// Unroll loop
|
||||
while k < n_fft.saturating_sub(3) {
|
||||
sum += fft_out[k] * filters[j * n_fft + k]
|
||||
+ fft_out[k + 1] * filters[j * n_fft + k + 1]
|
||||
+ fft_out[k + 2] * filters[j * n_fft + k + 2]
|
||||
+ fft_out[k + 3] * filters[j * n_fft + k + 3];
|
||||
k += 4;
|
||||
}
|
||||
// Handle remainder
|
||||
while k < n_fft {
|
||||
sum += fft_out[k] * filters[j * n_fft + k];
|
||||
k += 1;
|
||||
}
|
||||
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
|
||||
}
|
||||
@ -145,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
|
||||
mel
|
||||
}
|
||||
|
||||
fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
fn log_mel_spectrogram_<T: Float>(
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
fft_size: usize,
|
||||
@ -180,10 +202,55 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
samples_padded
|
||||
};
|
||||
|
||||
// Use a single thread for now.
|
||||
let mut mel = log_mel_spectrogram_w(
|
||||
0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
|
||||
);
|
||||
// ensure that the number of threads is even and less than 12
|
||||
let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
|
||||
|
||||
let hann = Arc::new(hann);
|
||||
let samples = Arc::new(samples);
|
||||
let filters = Arc::new(filters);
|
||||
|
||||
// use scope to allow for non static references to be passed to the threads
|
||||
// and directly collect the results into a single vector
|
||||
let all_outputs = thread::scope(|s| {
|
||||
(0..n_threads)
|
||||
// create threads and return their handles
|
||||
.map(|thread_id| {
|
||||
let hann = Arc::clone(&hann);
|
||||
let samples = Arc::clone(&samples);
|
||||
let filters = Arc::clone(&filters);
|
||||
// spawn new thread and start work
|
||||
s.spawn(move || {
|
||||
log_mel_spectrogram_w(
|
||||
thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
|
||||
n_mel, n_threads,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
// wait for each thread to finish and collect their results
|
||||
.map(|handle| handle.join().expect("Thread failed"))
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let l = all_outputs[0].len();
|
||||
let mut mel = vec![zero; l];
|
||||
|
||||
// iterate over mel spectrogram segments, dividing work by threads.
|
||||
for segment_start in (0..l).step_by(n_threads) {
|
||||
// go through each thread's output.
|
||||
for thread_output in all_outputs.iter() {
|
||||
// add each thread's piece to our mel spectrogram.
|
||||
for offset in 0..n_threads {
|
||||
let mel_index = segment_start + offset; // find location in mel.
|
||||
if mel_index < mel.len() {
|
||||
// Make sure we don't go out of bounds.
|
||||
mel[mel_index] += thread_output[mel_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mmax = mel
|
||||
.iter()
|
||||
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
|
||||
@ -197,11 +264,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
mel
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
cfg: &super::Config,
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> Vec<T> {
|
||||
pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
|
||||
log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
@ -211,3 +274,62 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fft() {
|
||||
let input = vec![0.0, 1.0, 0.0, 0.0];
|
||||
let output = fft(&input);
|
||||
assert_eq!(
|
||||
output,
|
||||
vec![
|
||||
1.0,
|
||||
0.0,
|
||||
6.123233995736766e-17,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-6.123233995736766e-17,
|
||||
1.0
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dft() {
|
||||
let input = vec![0.0, 1.0, 0.0, 0.0];
|
||||
let output = dft(&input);
|
||||
assert_eq!(
|
||||
output,
|
||||
vec![
|
||||
1.0,
|
||||
0.0,
|
||||
6.123233995736766e-17,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.2246467991473532e-16,
|
||||
-1.8369701987210297e-16,
|
||||
1.0
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log_mel_spectrogram() {
|
||||
let samples = vec![0.0; 1000];
|
||||
let filters = vec![0.0; 1000];
|
||||
let output = log_mel_spectrogram_(&samples, &filters, 100, 10, 10, false);
|
||||
assert_eq!(output.len(), 30_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tiny_log_mel_spectrogram() {
|
||||
let samples = vec![0.0; 100];
|
||||
let filters = vec![0.0; 100];
|
||||
let output = log_mel_spectrogram_(&samples, &filters, 20, 2, 2, false);
|
||||
assert_eq!(output.len(), 6_000);
|
||||
}
|
||||
}
|
||||
|
@ -129,6 +129,10 @@ impl MultiHeadAttention {
|
||||
.flatten_from(2)?;
|
||||
Ok(wv)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
|
||||
@ -193,16 +197,23 @@ impl ResidualAttentionBlock {
|
||||
)?;
|
||||
x + mlp
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.attn.reset_kv_cache();
|
||||
if let Some((attn, _)) = &mut self.cross_attn {
|
||||
attn.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, device)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
@ -246,7 +257,7 @@ impl AudioEncoder {
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
|
||||
@ -350,6 +361,12 @@ impl TextDecoder {
|
||||
};
|
||||
Ok(logits)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
for block in self.blocks.iter_mut() {
|
||||
block.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||
@ -370,4 +387,12 @@ impl Whisper {
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.encoder
|
||||
.blocks
|
||||
.iter_mut()
|
||||
.for_each(|b| b.reset_kv_cache());
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
@ -126,6 +126,10 @@ impl MultiHeadAttention {
|
||||
.flatten_from(2)?;
|
||||
Ok(wv)
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
|
||||
@ -189,16 +193,23 @@ impl ResidualAttentionBlock {
|
||||
.apply(&self.mlp_linear2)?;
|
||||
x + mlp
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.attn.reset_kv_cache();
|
||||
if let Some((attn, _)) = &mut self.cross_attn {
|
||||
attn.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, device)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
@ -242,7 +253,7 @@ impl AudioEncoder {
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
|
||||
@ -281,6 +292,12 @@ impl AudioEncoder {
|
||||
let x = self.ln_post.forward(&x)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
for block in self.blocks.iter_mut() {
|
||||
block.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
|
||||
@ -348,6 +365,12 @@ impl TextDecoder {
|
||||
};
|
||||
Ok(logits)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
for block in self.blocks.iter_mut() {
|
||||
block.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||
@ -368,4 +391,9 @@ impl Whisper {
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.encoder.reset_kv_cache();
|
||||
self.decoder.reset_kv_cache();
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
## Running Segment Anything Example
|
||||
|
||||
Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
|
||||
Here, we provide an example showing how to run the Segment Anything model in the
|
||||
browser.
|
||||
|
||||
### Vanilla JS and WebWorkers
|
||||
|
||||
|
Reference in New Issue
Block a user