Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-22 12:28:06 +00:00

Compare commits: 6 commits
spkemb ... ivarflakst
Commits (SHA1; author and date columns were empty in the source page):
8babfe0411
077e781f53
086b6ef6b6
2056866c25
1f4c54493e
d5902840e0
.github/workflows/ci_cuda.yaml (vendored, 74 changed lines)

@@ -5,15 +5,49 @@ on:
  pull_request:

jobs:
  start-runner:
    name: Start self-hosted EC2 runner
    runs-on: ubuntu-latest
    # Don't run on forks, they won't have access to secrets anyway.
    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
    env:
      AWS_REGION: us-east-1
      EC2_AMI_ID: ami-03cfed9ea28f4b002
      EC2_INSTANCE_TYPE: g5.xlarge
      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
      EC2_SECURITY_GROUP: sg-030175c435ac141d6
    outputs:
      label: ${{ steps.start-ec2-runner.outputs.label }}
      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}
      - name: Start EC2 runner
        id: start-ec2-runner
        uses: philschmid/philschmid-ec2-github-runner@main
        with:
          mode: start
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          ec2-image-id: ${{ env.EC2_AMI_ID }}
          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
          subnet-id: ${{ env.EC2_SUBNET_ID }}
          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
          aws-resource-tags: > # optional, requires additional permissions
            [
              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
            ]

  test-cuda:
    concurrency:
      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    runs-on: [single-gpu, nvidia-gpu, t4, ci]
    container:
      image: nvidia/cuda:12.3.1-devel-ubuntu22.04
      options: --gpus 0
    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
    needs: start-runner # required to start the main job when the runner is ready
    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
    permissions:
      contents: write
      packages: write
@@ -24,10 +58,32 @@ jobs:
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
      - name: Install dependencies
        run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
      - name: Install Rust Stable
        uses: actions-rust-lang/setup-rust-toolchain@v1
        run: curl https://sh.rustup.rs -sSf | sh -s -- -y
      - uses: Swatinem/rust-cache@v2
      - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
      - name: Test (cuda)
        run: cargo test --features cuda
        run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
  stop-runner:
    name: Stop self-hosted EC2 runner
    needs:
      - start-runner
      - test-cuda
    runs-on: ubuntu-latest
    env:
      AWS_REGION: us-east-1
    if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}
      - name: Stop EC2 runner
        uses: philschmid/philschmid-ec2-github-runner@main
        with:
          mode: stop
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          label: ${{ needs.start-runner.outputs.label }}
          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
Cargo.toml (19 changed lines)

@@ -19,7 +19,7 @@ exclude = [
resolver = "2"

[workspace.package]
version = "0.4.1"
version = "0.3.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -31,18 +31,17 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" }
candle-datasets = { path = "./candle-datasets", version = "0.4.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" }
candle-kernels = { path = "./candle-kernels", version = "0.4.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" }
candle-nn = { path = "./candle-nn", version = "0.4.1" }
candle-onnx = { path = "./candle-onnx", version = "0.4.1" }
candle-transformers = { path = "./candle-transformers", version = "0.4.1" }
candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
README.md (40 changed lines)

@@ -63,24 +63,17 @@ We also provide some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
  the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
  Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
  pre-trained on 1T tokens of English and code datasets. Also supports
  StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- [Mamba](./candle-examples/examples/mamba/): an inference only
  pre-trained on 1T tokens of English and code datasets.
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
  implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
  better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
  experts 8x7b general LLM with better performance than a Llama 2 70B model with
  much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and
  [StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
  performance.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
  (English/Chinese) general LLMs with 6b and 34b parameters.
@@ -110,12 +103,7 @@ We also provide some command line based examples using state of the art models

<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
  model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
  text-to-speech.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
  [JinaBert](./candle-examples/examples/jina-bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@@ -123,10 +111,9 @@ We also provide some command line based examples using state of the art models
  evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
  [RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
  generate captions for an image.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
  dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
  model, generates the translated text from the input text.

@@ -195,18 +182,15 @@ If you have an addition to this list, please submit a pull request.
- Language Models.
  - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
  - Falcon.
  - StarCoder, StarCoder2.
  - StarCoder.
  - Phi 1, 1.5, and 2.
  - Mamba, Minimal Mamba
  - Gemma 2b and 7b.
  - Minimal Mamba
  - Mistral 7b v0.1.
  - Mixtral 8x7b v0.1.
  - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
  - StableLM-3B-4E1T.
  - Replit-code-v1.5-3B.
  - Bert.
  - Yi-6B and Yi-34B.
  - Qwen1.5.
  - RWKV v5 and v6.
- Quantized LLMs.
  - Llama 7b, 13b, 70b, as well as the chat and code variants.
  - Mistral 7b, and 7b instruct.
@@ -216,22 +200,16 @@ If you have an addition to this list, please submit a pull request.
- Text to text.
  - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
  - Marian MT (Machine Translation).
  - Whisper (multi-lingual support).
- Text to image.
  - Stable Diffusion v1.5, v2.1, XL v1.0.
  - Wurstchen v2.
- Image to text.
  - BLIP.
  - TrOCR.
- Audio.
  - Whisper, multi-lingual speech-to-text.
  - EnCodec, audio compression model.
  - MetaVoice-1B, text-to-speech model.
- Computer Vision Models.
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
    ConvNeXTv2, MobileOne, EfficientVit (MSRA).
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
  - yolo-v3, yolo-v8.
  - Segment-Anything Model (SAM).
  - SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
@@ -1,9 +1,11 @@
mod benchmarks;

use criterion::criterion_main;

criterion_main!(
    benchmarks::affine::benches,
    benchmarks::matmul::benches,
    benchmarks::random::benches,
    benchmarks::where_cond::benches
    //benchmarks::affine::benches,
    //benchmarks::matmul::benches,
    //benchmarks::random::benches,
    benchmarks::reduce::benches,
    //benchmarks::where_cond::benches
);
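Note on the hunk above: `criterion_main!` only registers benchmark groups that were declared with `criterion_group!`, so commenting an entry out disables that whole group while `benchmarks::reduce::benches` stays active. As a self-contained sketch of the pattern (names below are illustrative, not from this commit):

use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn bench_sum(c: &mut Criterion) {
    // Toy workload standing in for the candle benchmarks registered above.
    let data: Vec<f32> = (0..1024).map(|i| i as f32).collect();
    c.bench_function("sum_1k_f32", |b| {
        b.iter(|| black_box(&data).iter().sum::<f32>())
    });
}

criterion_group!(benches, bench_sum);
criterion_main!(benches);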
@@ -1,6 +1,7 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod reduce;
pub(crate) mod where_cond;

use candle_core::{Device, Result};
candle-core/benches/benchmarks/reduce.rs (new file, 239 lines)

@@ -0,0 +1,239 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Storage, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::ops::Deref;
use std::time::Instant;

fn run_sum(a: &Tensor) {
    a.sum(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
    a.argmin(2).unwrap();
}

// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
    use candle_core::{backend::BackendStorage, DType};
    let (storage, layout) = a.storage_and_layout();

    let device = a.device();

    if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match a.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .unwrap();
    }
    Ok(())
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    let (lo, up) = (-1000.0f32, 1000.0f32);
    for device in handler.devices {
        run_softmax(c, &device, (lo, up));
        run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
        run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));

        run_reduce(c, &device, (lo, up), false);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_arg_reduce(c, &device, (lo, up), false);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_reduce(c, &device, (lo, up), true);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);

        run_arg_reduce(c, &device, (lo, up), true);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
    }
}

fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
    if !device.is_metal() {
        return;
    }

    let b = 1;
    let m = 1024;
    let k = 1024;
    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => "softmax_f32",
        DType::F16 => "softmax_f16",
        DType::BF16 => "softmax_bf16",
        _ => "softmax",
    };
    softmax(&a).unwrap();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                softmax(black_box(&a)).unwrap();
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn run_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), &device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
    };

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "reduce_f32_strided"
            } else {
                "reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "reduce_f16_strided"
            } else {
                "reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "reduce_bf16_strided"
            } else {
                "reduce_bf16"
            }
        }
        _ => "reduce",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_sum(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn run_arg_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), &device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
    };

    let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "arg_reduce_f32_strided"
            } else {
                "arg_reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "arg_reduce_f16_strided"
            } else {
                "arg_reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "arg_reduce_bf16_strided"
            } else {
                "arg_reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_arg_min(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
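A note on the timing pattern used throughout this new file: GPU work is enqueued asynchronously, so the benchmarks use `iter_custom` and only call `device.sync()` once after the loop; the elapsed time then covers the actual kernel execution rather than just launch overhead. A condensed, device-agnostic sketch (the two helpers are hypothetical stand-ins for the launch and sync calls above):

use std::time::{Duration, Instant};

fn launch_kernel() { /* enqueue asynchronous work */ }
fn device_sync() { /* block until all queued work has completed */ }

fn time_iters(iters: u64) -> Duration {
    let start = Instant::now();
    for _ in 0..iters {
        launch_kernel();
    }
    // Without this sync, only the (cheap) enqueue cost would be measured.
    device_sync();
    start.elapsed()
}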
@@ -5,32 +5,25 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use anyhow::Result;
use candle_core::{Device, Module, Tensor};

use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
    let q_cpu = q.to_device(&Device::Cpu)?;
    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
    let q = QMatMul::from_qtensor(q)?;
    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
    let res_q_cuda = q.forward(&x)?;
    println!("{res_q_cuda}");

    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
    let x_cpu = x.to_device(&Device::Cpu)?;
    let res_q_cpu = q_cpu.forward(&x_cpu)?;
    println!("{res_q_cpu}");

    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
        .abs()?
        .flatten_all()?
        .max(0)?;
    let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
    let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
    println!("{out_t}");
    let in_t = in_t.to_device(&Device::Cpu)?;
    let k_t = k_t.to_device(&Device::Cpu)?;
    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
    let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
        .sqr()?
        .sum_all()?;
    println!("{diff}");

    let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
    let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
    let res = t.conv2d(&w, 1, 1, 1, 1)?;
    println!("{res:?}");
    Ok(())
}
@@ -196,7 +196,7 @@ fn run_ls(
            }
        }
        Format::Pth => {
            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
            tensors.sort_by(|a, b| a.name.cmp(&b.name));
            for tensor_info in tensors.iter() {
                println!(
@@ -380,16 +380,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
    unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}

#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
    unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}

#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
    unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}

#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -412,28 +402,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
    }
}

#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v
    }
    vs_exp_inplace(ys);
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y)
    }
}

#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v
    }
    vd_exp_inplace(ys);
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y)
    }
}

macro_rules! binary_op {
    ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
        #[inline]
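The `vs_silu`/`vd_silu` helpers above compute silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) in two vectorized passes: fill the output with -x, exponentiate in place via the Accelerate exp routine, then divide. A small self-contained sketch of the same idea in plain Rust (no FFI), checked against the direct formula:

fn silu_two_pass(xs: &[f32], ys: &mut [f32]) {
    for (&x, y) in xs.iter().zip(ys.iter_mut()) {
        *y = -x;
    }
    for y in ys.iter_mut() {
        *y = y.exp(); // stand-in for the vectorized vvexpf call
    }
    for (&x, y) in xs.iter().zip(ys.iter_mut()) {
        *y = x / (1.0 + *y);
    }
}

fn main() {
    let xs = [-2.0f32, -0.5, 0.0, 0.5, 2.0];
    let mut ys = [0.0f32; 5];
    silu_two_pass(&xs, &mut ys);
    for (&x, &y) in xs.iter().zip(ys.iter()) {
        assert!((y - x / (1.0 + (-x).exp())).abs() < 1e-6);
    }
}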
@@ -113,7 +113,7 @@ impl Tensor {
                | Op::Unary(_node, UnaryOp::Floor)
                | Op::Unary(_node, UnaryOp::Round) => nodes,
                Op::Reshape(node)
                | Op::UpsampleNearest1D { arg: node, .. }
                | Op::UpsampleNearest1D(node)
                | Op::UpsampleNearest2D { arg: node, .. }
                | Op::AvgPool2D { arg: node, .. }
                | Op::MaxPool2D { arg: node, .. }
@@ -175,7 +175,7 @@ impl Tensor {
            // the backprop graph of the backprop itself. This would be an issue for second order
            // derivatives but these are out of scope at the moment.
            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
            let grad = if do_not_detach { grad } else { grad.detach() };
            let grad = if do_not_detach { grad } else { grad.detach()? };
            if let Some(op) = node.op() {
                match op {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
@@ -250,7 +250,6 @@ impl Tensor {
                            out_padding,
                            *stride,
                            *dilation,
                            /* groups */ 1,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
@@ -348,18 +347,9 @@ impl Tensor {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::UpsampleNearest1D { arg, target_size } => {
                        let (_n, c, size) = arg.dims3()?;
                        if target_size % size != 0 {
                            crate::bail!("backward not supported for non integer upscaling factors")
                        }
                        let scale = target_size / size;

                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = conv_sum;
                    }
                    Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                        op: "upsample-nearest1d",
                    })?,
                    Op::UpsampleNearest2D {
                        arg,
                        target_h,
@@ -599,13 +589,6 @@ impl Tensor {
                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Silu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
                        let sigmoid_arg = (*node / arg)?;
                        let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
                        *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
                    }
                    Op::Elu(arg, alpha) => {
                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
                        let sum_grad = grads.or_insert(arg)?;
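The SiLU gradient used in the hunk above can be sanity-checked numerically: with s(x) = 1 / (1 + exp(-x)), silu(x) = x * s(x) and silu'(x) = s(x) * (1 + x * (1 - s(x))). Note that silu(x) / x = s(x), which is why the code recovers the sigmoid as node / arg. A finite-difference check, independent of candle:

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn silu(x: f64) -> f64 {
    x * sigmoid(x)
}

// The analytic gradient from the backprop code above.
fn silu_grad(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 + x * (1.0 - s))
}

fn main() {
    let h = 1e-6;
    for &x in &[-3.0, -1.0, 0.0, 0.5, 2.0] {
        let numeric = (silu(x + h) - silu(x - h)) / (2.0 * h);
        assert!((numeric - silu_grad(x)).abs() < 1e-6);
    }
}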
@@ -187,16 +187,36 @@ impl Tensor {
        }
    }

    fn conv_transpose1d_single_group(
    /// Applies a 1D transposed convolution over the input tensor.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        params: &ParamsConvTranspose1D,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, l_in) = self.dims3()?;
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose1d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            params,
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
            arg,
@@ -210,49 +230,6 @@ impl Tensor {
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D transposed convolution over the input tensor.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        if c_in % groups != 0 {
            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in: c_in / groups,
            padding,
            output_padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv_transpose1d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
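For reference, a short usage sketch of the grouped `conv_transpose1d` (the 7-argument form in this diff). With groups = 2, the input channels are split in two, each block is convolved with its chunk of the kernel, and the results are concatenated, so a kernel of shape (c_in, c_out_per_group, k_size) yields c_out_per_group * groups output channels. The shapes below are illustrative:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input (batch=1, c_in=4, l_in=10); kernel (c_in=4, c_out_per_group=2, k_size=3).
    let x = Tensor::randn(0f32, 1.0, (1, 4, 10), &dev)?;
    let k = Tensor::randn(0f32, 1.0, (4, 2, 3), &dev)?;
    let y = x.conv_transpose1d(&k, 0, 0, 1, 1, 2)?;
    // l_out = (l_in - 1) * stride + dilation * (k_size - 1) + 1 = 12
    println!("{:?}", y.shape()); // (1, 4, 12)
    Ok(())
}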
@@ -1263,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
@@ -2575,7 +2574,7 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
        }
    }

@@ -2584,7 +2583,7 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
        }
    }

@@ -2601,7 +2600,7 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
        }
    }
@@ -1149,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        inp: &CudaSlice<T>,
        inp_l: &Layout,
        k: &CudaSlice<T>,
        k_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>> {
        // Kernel shape: (c_in_k, c_out, l_k)
        // Input shape: (b_size, c_in, l_in)
        let p = &self.0;
        let l_out = p.l_out();
        let dst_el = p.c_out * l_out * p.b_size;
        let inp = &inp.slice(inp_l.start_offset()..);
        let k = &k.slice(k_l.start_offset()..);
        let shape = inp_l.shape();
        let dims = shape.dims();
        let el = shape.elem_count();

        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
        let ds = if dims.len() == 3 {
            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
        } else {
            crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
        };
        let ds = dev.htod_copy(ds).w()?;
        let params = (
            el,
            l_out,
            p.stride,
            p.padding,
            p.output_padding,
            p.dilation,
            &ds,
            inp,
            k,
            &out,
        );
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1859,15 +1810,12 @@ impl BackendStorage for CudaStorage {

    fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
        _: &Layout,
        _: &Self,
        _: &Layout,
        _: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        let device = self.device().clone();
        let slice =
            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
        Ok(Self { slice, device })
        todo!()
    }

    #[cfg(not(feature = "cudnn"))]
@@ -129,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
    }
}

impl<M: Module> Module for Option<&M> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            None => Ok(xs.clone()),
            Some(m) => m.forward(xs),
        }
    }
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
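The blanket impl above makes an optional submodule behave as the identity when absent, which keeps forward passes free of `if let` noise. A toy sketch of why that is convenient (it relies on the `Option<&M>` impl shown above being present):

use candle_core::{DType, Device, Module, Result, Tensor};

// A toy module that doubles its input.
struct Double;
impl Module for Double {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs * 2.0
    }
}

fn run(layer: Option<&Double>, xs: &Tensor) -> Result<Tensor> {
    // No match needed: None simply passes the tensor through.
    layer.forward(xs)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::ones((2, 2), DType::F32, &dev)?;
    let doubled = run(Some(&Double), &xs)?; // all 2s
    let same = run(None, &xs)?; // unchanged
    println!("{doubled}\n{same}");
    Ok(())
}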
@@ -489,6 +489,7 @@ impl BackendStorage for MetalStorage {

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        let device = self.device.clone();

        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
        // Source dims and strides with the sum dims at the end.
@@ -502,13 +503,69 @@ impl BackendStorage for MetalStorage {
                stride.push(src_stride[dim_idx]);
            }
        }

        if layout.is_contiguous() {
            let (name, check_empty, return_index) = match (op, self.dtype) {
                (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
                (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
                (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
                (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
                (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
                (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
                (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
                (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
                (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
                (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
                (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
                (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
                (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
                (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
                (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
                (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
                (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
                (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
                (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
                (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
                (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
                (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
                (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
                (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
                (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
                (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
                (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
                (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
                (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
                (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
                (k, dtype) => {
                    crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
                }
            };
            if check_empty && layout.shape().elem_count() == 0 {
                Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
            }
            let dtype = if return_index { DType::U32 } else { self.dtype };
            let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
            let command_buffer = self.device.command_buffer()?;
            candle_metal_kernels::call_reduce_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                name,
                layout.shape().elem_count(),
                dst_el,
                &self.buffer,
                layout.start_offset() * self.dtype.size_in_bytes(),
                &buffer,
            )
            .map_err(MetalError::from)?;
            return Ok(Self::new(buffer, device, self.dtype));
        }

        for &dim_idx in sum_dims.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }

        // The reduction loop requires the shared array to be properly initialized and for
        // this we want the number of threads to be a power of two.
        let (name, check_empty, return_index) = match (op, self.dtype) {
            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
@@ -540,7 +597,7 @@ impl BackendStorage for MetalStorage {
            (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
            (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
            (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
            (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
            (k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
        };
        if check_empty && layout.shape().elem_count() == 0 {
            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@@ -679,7 +736,6 @@ impl BackendStorage for MetalStorage {
            ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
            ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
            ("uerf", DType::F32) => contiguous::erf::FLOAT,
            ("usilu", DType::F32) => contiguous::silu::FLOAT,
            ("uabs", DType::F32) => contiguous::abs::FLOAT,
            ("uceil", DType::F32) => contiguous::ceil::FLOAT,
            ("ufloor", DType::F32) => contiguous::floor::FLOAT,
@@ -697,7 +753,6 @@ impl BackendStorage for MetalStorage {
            ("ugelu", DType::F16) => contiguous::gelu::HALF,
            ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
            ("uerf", DType::F16) => contiguous::erf::HALF,
            ("usilu", DType::F16) => contiguous::silu::HALF,
            ("uabs", DType::F16) => contiguous::abs::HALF,
            ("uceil", DType::F16) => contiguous::ceil::HALF,
            ("ufloor", DType::F16) => contiguous::floor::HALF,
@@ -732,13 +787,11 @@ impl BackendStorage for MetalStorage {
            ("ugelu", DType::F32) => strided::gelu::FLOAT,
            ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
            ("uerf", DType::F32) => strided::erf::FLOAT,
            ("usilu", DType::F32) => strided::silu::FLOAT,
            ("uabs", DType::F32) => strided::abs::FLOAT,
            ("uceil", DType::F32) => strided::ceil::FLOAT,
            ("ufloor", DType::F32) => strided::floor::FLOAT,
            ("urelu", DType::F32) => strided::relu::FLOAT,
            ("uround", DType::F32) => strided::round::FLOAT,
            ("utanh", DType::F32) => strided::tanh::FLOAT,
            ("ucos", DType::F16) => strided::cos::HALF,
            ("usin", DType::F16) => strided::sin::HALF,
            ("usqr", DType::F16) => strided::sqr::HALF,
@@ -749,13 +802,11 @@ impl BackendStorage for MetalStorage {
            ("ugelu", DType::F16) => strided::gelu::HALF,
            ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
            ("uerf", DType::F16) => strided::erf::HALF,
            ("usilu", DType::F16) => strided::silu::HALF,
            ("uabs", DType::F16) => strided::abs::HALF,
            ("uceil", DType::F16) => strided::ceil::HALF,
            ("ufloor", DType::F16) => strided::floor::HALF,
            ("urelu", DType::F16) => strided::relu::HALF,
            ("uround", DType::F16) => strided::round::HALF,
            ("utanh", DType::F16) => strided::tanh::HALF,
            (name, dtype) => {
                crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
            }
@@ -829,9 +880,9 @@ impl BackendStorage for MetalStorage {
                layout.start_offset() * self.dtype.size_in_bytes(),
            ),
            &t.buffer,
            (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
            (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
            &f.buffer,
            (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
            (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
            &buffer,
        )
        .map_err(MetalError::from)?;
@@ -1266,7 +1317,7 @@ impl BackendStorage for MetalStorage {
        let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
        let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
        let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
        blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
        blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
        blit.end_encoding();
    } else {
        let src_shape = src_l.shape();
@@ -1638,7 +1689,7 @@ impl BackendDevice for MetalDevice {
            min as f32,
            max as f32,
            shape.elem_count(),
            &self.seed.lock().unwrap(),
            &*self.seed.lock().unwrap(),
            &buffer,
        )
        .map_err(MetalError::from)?;
@@ -1669,7 +1720,7 @@ impl BackendDevice for MetalDevice {
            mean as f32,
            stddev as f32,
            shape.elem_count(),
            &self.seed.lock().unwrap(),
            &*self.seed.lock().unwrap(),
            &buffer,
        )
        .map_err(MetalError::from)?;
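The key change in `reduce_op` above is an early fast path: when the layout is contiguous, a dedicated `fast_*` kernel is dispatched and the function returns before any of the strided setup runs. A minimal illustration of that dispatch shape (plain Rust, not the Metal backend's actual types):

struct Layout {
    contiguous: bool,
}

fn reduce_sum(data: &[f32], layout: &Layout, stride: usize) -> f32 {
    if layout.contiguous {
        // Analogue of the fast_sum_f32 contiguous kernel: one linear pass.
        return data.iter().sum();
    }
    // Analogue of fast_sum_f32_strided: walk the buffer with a stride.
    data.iter().step_by(stride).sum()
}

fn main() {
    let data = [1.0f32, 2.0, 3.0, 4.0];
    assert_eq!(reduce_sum(&data, &Layout { contiguous: true }, 1), 10.0);
    assert_eq!(reduce_sum(&data, &Layout { contiguous: false }, 2), 4.0);
}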
@@ -333,16 +333,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
    unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}

#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
    unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}

#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
    unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}

#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@@ -365,28 +355,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
    }
}

#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v
    }
    vs_exp_inplace(ys);
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y)
    }
}

#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = -v
    }
    vd_exp_inplace(ys);
    for (&v, y) in vs.iter().zip(ys.iter_mut()) {
        *y = v / (1.0 + *y)
    }
}

macro_rules! binary_op {
    ($fn_name:ident, $ty:ty, $mkl_name:ident) => {
        #[inline]
@@ -61,7 +61,6 @@ pub enum UnaryOp {
    GeluErf,
    Erf,
    Relu,
    Silu,
    Tanh,
    Floor,
    Ceil,
@@ -132,10 +131,7 @@ pub enum Op {
        stride: (usize, usize),
    },

    UpsampleNearest1D {
        arg: Tensor,
        target_size: usize,
    },
    UpsampleNearest1D(Tensor),
    UpsampleNearest2D {
        arg: Tensor,
        target_h: usize,
@@ -394,7 +390,6 @@ pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
@@ -729,77 +724,6 @@ impl UnaryOpT for Erf {
    }
}

/// Silu operation
impl UnaryOpT for Silu {
    const NAME: &'static str = "silu";
    const V: Self = Silu;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v / (bf16::ONE + (-v).exp())
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v / (f16::ONE + (-v).exp())
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v / (1.0 + (-v).exp())
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v / (1.0 + (-v).exp())
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
        0
    }
    #[inline(always)]
    fn u32(_: u32) -> u32 {
        0
    }
    #[inline(always)]
    fn i64(_: i64) -> i64 {
        0
    }
    const KERNEL: &'static str = "usilu";

    #[cfg(feature = "mkl")]
    const F32_VEC: bool = true;

    #[cfg(feature = "mkl")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::mkl::vs_silu(xs, ys)
    }

    #[cfg(feature = "mkl")]
    const F64_VEC: bool = true;

    #[cfg(feature = "mkl")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::mkl::vd_silu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F32_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        crate::accelerate::vs_silu(xs, ys)
    }

    #[cfg(feature = "accelerate")]
    const F64_VEC: bool = true;

    #[cfg(feature = "accelerate")]
    #[inline(always)]
    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
        crate::accelerate::vd_silu(xs, ys)
    }
}

impl UnaryOpT for Abs {
    const NAME: &'static str = "abs";
    const KERNEL: &'static str = "uabs";
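The `UnaryOpT` impl in the hunk above follows candle's pattern for unary ops: one scalar method per dtype, a kernel name for the GPU backends, and feature-gated vectorized overrides (MKL/Accelerate). A toy version of the same pattern, independent of candle:

trait UnaryOp {
    const NAME: &'static str;
    fn f32(v: f32) -> f32;
    // Default scalar fallback; MKL/Accelerate-style impls override this.
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        for (x, y) in xs.iter().zip(ys.iter_mut()) {
            *y = Self::f32(*x);
        }
    }
}

struct Silu;
impl UnaryOp for Silu {
    const NAME: &'static str = "silu";
    fn f32(v: f32) -> f32 {
        v / (1.0 + (-v).exp())
    }
}

fn main() {
    let xs = [0.0f32, 1.0, -1.0];
    let mut ys = [0.0f32; 3];
    Silu::f32_vec(&xs, &mut ys);
    // silu(0) = 0, silu(1) ≈ 0.731, silu(-1) ≈ -0.269
    println!("{}: {ys:?}", Silu::NAME);
}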
@@ -42,7 +42,7 @@ pub enum OpCode {
    Stop = b'.',
    NewObj = 0x81,
    EmptyList = b']',
    BinFloat = b'G',
    BinFloat = b'g',
    Append = b'a',
    Appends = b'e',
}
@@ -217,13 +217,6 @@ impl Object {
                let args = args.remove(1);
                (callable, args)
            }
            Object::Class {
                module_name,
                class_name,
            } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
                let mut args = args.tuple()?;
                args.remove(0).reduce()?
            }
            _ => (callable, args),
        };
        match callable {
@@ -234,11 +227,13 @@ impl Object {
            _ => return Ok(None),
        };
        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
        let mut path = dir_name.to_path_buf();
        path.push(file_path);
        Ok(Some(TensorInfo {
            name,
            dtype,
            layout,
            path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
            path: path.to_string_lossy().into_owned(),
            storage_size,
        }))
    }
@@ -350,10 +345,8 @@ impl Stack {
                module_name,
                class_name,
            } => {
                if module_name == "collections"
                    && (class_name == "OrderedDict" || class_name == "defaultdict")
                {
                    // TODO: have a separate ordered dict and a separate default dict.
                if module_name == "collections" && class_name == "OrderedDict" {
                    // TODO: have a separate ordered dict.
                    Some(Object::Dict(vec![]))
                } else {
                    None
@@ -462,10 +455,7 @@ impl Stack {
                self.push(Object::Int(arg))
            }
            OpCode::BinFloat => {
                // Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
                // https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
                // https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
                let arg = r.read_f64::<byteorder::BigEndian>()?;
                let arg = r.read_f64::<LittleEndian>()?;
                self.push(Object::Float(arg))
            }
            OpCode::BinUnicode => {
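The `BinFloat` change above is subtle: pickle's BINFLOAT opcode ('G') stores an 8-byte IEEE-754 float in big-endian order, while the integer opcodes use little-endian, hence the explicit `byteorder::BigEndian` read on one side of this diff. A self-contained illustration with the byteorder crate (already a workspace dependency):

use byteorder::{BigEndian, LittleEndian, ReadBytesExt};

fn main() -> std::io::Result<()> {
    // 1.0f64 as serialized by pickle's BINFLOAT opcode: big-endian bytes.
    let bytes = [0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];

    let mut rdr: &[u8] = &bytes;
    let be = rdr.read_f64::<BigEndian>()?;
    let mut rdr: &[u8] = &bytes;
    let le = rdr.read_f64::<LittleEndian>()?;

    assert_eq!(be, 1.0); // correct decoding
    assert_ne!(le, 1.0); // misread as a denormal near zero
    Ok(())
}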
@@ -637,16 +627,9 @@ pub struct TensorInfo {
    pub storage_size: usize,
}

/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
    file: P,
    verbose: bool,
    key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
    let file = std::fs::File::open(file)?;
    let zip_reader = std::io::BufReader::new(file);
@@ -668,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
        stack.read_loop(&mut reader)?;
        let obj = stack.finalize()?;
        if VERBOSE || verbose {
            println!("{obj:#?}");
            println!("{obj:?}");
        }

        let obj = match obj {
            Object::Build { callable, args } => match *callable {
                Object::Reduce { callable, args: _ } => match *callable {
@@ -684,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
            },
            obj => obj,
        };

        // If key is provided, then we need to extract the state_dict from the object.
        let obj = if let Some(key) = key {
            if let Object::Dict(key_values) = obj {
                key_values
                    .into_iter()
                    .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
                    .map(|(_, v)| v)
                    .ok_or_else(|| E::Msg(format!("key {key} not found")))?
            } else {
                obj
            }
        } else {
            obj
        };

        // If the object is a dict, then we can extract the tensor info from it.
        // NOTE: We are assuming that the `obj` is state_dict by this stage.
        if let Object::Dict(key_values) = obj {
            for (name, value) in key_values.into_iter() {
                match value.into_tensor_info(name, &dir_name) {
@@ -724,8 +688,8 @@ pub struct PthTensors {
}

impl PthTensors {
    pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
        let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
    pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
        let tensor_infos = tensor_infos
            .into_iter()
            .map(|ti| (ti.name.to_string(), ti))
@@ -748,12 +712,10 @@ impl PthTensors {
        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
        let mut zip = zip::ZipArchive::new(zip_reader)?;
        let mut reader = zip.by_name(&tensor_info.path)?;
        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
        let rank = tensor_info.layout.shape().rank();

        // Reading the data is a bit tricky as it can be strided, for now only support the basic
        // case and when the tensor is fortran contiguous.
        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
        // case.
        if !tensor_info.layout.is_contiguous() {
            crate::bail!(
                "cannot retrieve non-contiguous tensors {:?}",
                tensor_info.layout
@@ -771,33 +733,13 @@ impl PthTensors {
            tensor_info.dtype,
            &mut reader,
        )?;

        if rank > 1 && is_fortran_contiguous {
            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
            let tensor = tensor.reshape(shape_reversed)?;

            // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
            let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
            let tensor = tensor.permute(dim_indeces_reversed)?;
            Ok(Some(tensor))
        } else {
            Ok(Some(tensor))
        }
        Ok(Some(tensor))
    }
}

/// Read all the tensors from a PyTorch pth file with a given key.
///
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
///   contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
    path: P,
    key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
    let pth = PthTensors::new(path, key)?;
/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    let pth = PthTensors::new(path)?;
    let tensor_names = pth.tensor_infos.keys();
    let mut tensors = Vec::with_capacity(tensor_names.len());
    for name in tensor_names {
@@ -807,11 +749,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
    }
    Ok(tensors)
}

/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    read_all_with_key(path, None)
}
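For context on the API being diffed: the `key` argument exists because a .pth checkpoint frequently wraps its tensors in an outer dict (for instance under a "state_dict" entry), and `read_all_with_key` selects that entry before extracting tensors. A usage sketch against the keyed form of the API (the file name is illustrative):

use candle_core::pickle;

fn main() -> candle_core::Result<()> {
    let tensors = pickle::read_all_with_key("model.pth", Some("state_dict"))?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}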
@ -1,343 +0,0 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||
use crate::{CudaDevice, CudaStorage, Result};
|
||||
|
||||
use cudarc::driver::{CudaSlice, DeviceSlice};
|
||||
|
||||
pub struct QCudaStorage {
|
||||
data: CudaSlice<u8>,
|
||||
dtype: GgmlDType,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
pub const WARP_SIZE: usize = 32;
|
||||
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
|
||||
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
|
||||
pub const NWARPS_Q4_0_AMPERE: usize = 4;
|
||||
pub const GGML_CUDA_MMV_X: usize = 32;
|
||||
pub const GGML_CUDA_MMV_Y: usize = 1;
|
||||
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
||||
|
||||
fn dequantize(
|
||||
data: &CudaSlice<u8>,
|
||||
dtype: GgmlDType,
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
    let nb = (elem_count + 255) / 256;
    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
        GgmlDType::Q5_0 => {
            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
            (
                "dequantize_block_q5_0",
                false,
                CUDA_DEQUANTIZE_BLOCK_SIZE,
                nb,
            )
        }
        GgmlDType::Q5_1 => {
            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
            (
                "dequantize_block_q5_1",
                false,
                CUDA_DEQUANTIZE_BLOCK_SIZE,
                nb,
            )
        }
        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
    // See e.g.
    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks as u32, 1, 1),
        block_dim: (block_dim as u32, 1, 1),
        shared_mem_bytes: 0,
    };

    if is_k {
        let params = (data, &dst);
        unsafe { func.launch(cfg, params) }.w()?;
    } else {
        let nb32 = match dtype {
            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
            _ => elem_count / 32,
        };
        let params = (data, &dst, nb32 as i32);
        unsafe { func.launch(cfg, params) }.w()?;
    }
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mul_mat_vec(
    data: &CudaSlice<u8>,
    y: &cudarc::driver::CudaView<f32>,
    dtype: GgmlDType,
    ncols: usize,
    nrows: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    let kernel_name = match dtype {
        GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
        GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
        GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
        GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
        GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
        GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
        GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
        GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
        GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
        GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(nrows).w()?;
    let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (block_num_y as u32, 1, 1),
        block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
        shared_mem_bytes: 0,
    };

    let params = (data, y, &dst, ncols as i32, nrows as i32);
    unsafe { func.launch(cfg, params) }.w()?;
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

impl QCudaStorage {
    pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
        let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
        let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
        Ok(QCudaStorage {
            data,
            device: device.clone(),
            dtype,
        })
    }

    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    pub fn device(&self) -> &CudaDevice {
        &self.device
    }

    pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
        let fast_kernel = matches!(
            self.dtype,
            GgmlDType::Q4_0
                | GgmlDType::Q4_1
                | GgmlDType::Q5_0
                | GgmlDType::Q5_1
                | GgmlDType::Q8_0
                | GgmlDType::Q2K
                | GgmlDType::Q3K
                | GgmlDType::Q4K
                | GgmlDType::Q5K
                | GgmlDType::Q6K
                | GgmlDType::Q8K
        );
        if fast_kernel {
            return dequantize(&self.data, self.dtype, elem_count, self.device());
        }
        // Run the dequantization on cpu.
        use crate::quantized::k_quants::GgmlType;

        let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
        let mut out = vec![0.0; elem_count];
        let block_len = elem_count / self.dtype.block_size();
        match self.dtype {
            GgmlDType::F32 => {
                let slice =
                    unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
                out.copy_from_slice(slice)
            }
            GgmlDType::F16 => {
                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
                half::f16::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_0 => {
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_1 => {
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_0 => {
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_1 => {
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_0 => {
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_1 => {
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q2K => {
                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q3K => {
                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4K => {
                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5K => {
                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q6K => {
                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8K => {
                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
            }
        }

        self.device
            .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
    }

    pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
        // Run the quantization on cpu.
        let src = match &src.slice {
            crate::cuda_backend::CudaStorageSlice::F32(data) => {
                self.device.dtoh_sync_copy(data).w()?
            }
            _ => crate::bail!("only f32 can be quantized"),
        };
        let src_len = src.len();
        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
        qcpu_storage.quantize(&src)?;
        let data = qcpu_storage.data()?;
        let data = self.device.htod_sync_copy(data.as_ref()).w()?;
        self.data = data;
        Ok(())
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.data.len()
    }

    pub fn fwd(
        &self,
        self_shape: &crate::Shape,
        storage: &CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
            self.dequantize_matmul_vec(self_shape, storage, layout)
        } else {
            self.dequantize_matmul(self_shape, storage, layout)
        }
    }
}

impl QCudaStorage {
    fn dequantize_matmul_vec(
        &self,
        self_shape: &crate::Shape,
        rhs: &CudaStorage,
        rhs_l: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        let (nrows, ncols) = self_shape.dims2()?;
        let rhs = rhs.as_cuda_slice::<f32>()?;
        let rhs = match rhs_l.contiguous_offsets() {
            Some((o1, o2)) => rhs.slice(o1..o2),
            None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
        };
        let (with_batch, k) = match rhs_l.shape().dims() {
            [1, 1, k] => (true, k),
            [1, k] => (false, k),
            _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
        };
        if ncols != *k {
            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
        }

        let out =
            dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
        let out_shape = if with_batch {
            vec![1, 1, nrows]
        } else {
            vec![1, nrows]
        };
        Ok((out, out_shape.into()))
    }

    fn dequantize_matmul(
        &self,
        self_shape: &crate::Shape,
        storage: &CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        let (n, k) = self_shape.dims2()?;
        let (b, m, k2) = match layout.shape().dims() {
            &[b, m, k2] => (b, m, k2),
            &[m, k2] => (1, m, k2),
            s => crate::bail!("unexpected shape for input {s:?}"),
        };
        if k2 != k {
            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
        }

        let data_f32 = self.dequantize(n * k)?;
        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
        let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
        let mut out_shape = layout.shape().dims().to_vec();
        out_shape.pop();
        out_shape.push(n);
        Ok((out, out_shape.into()))
    }
}

fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
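    // Safety: the caller must ensure that `buffer` holds at least `n` properly
    // aligned values of type `T`.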
    let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
    slice.to_vec()
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    device: &CudaDevice,
    data: &[T],
) -> Result<super::QStorage> {
    let data = unsafe {
        std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
    };
    let data = device.htod_sync_copy(data).w()?;
    Ok(QStorage::Cuda(QCudaStorage {
        data,
        device: device.clone(),
        dtype: T::DTYPE,
    }))
}
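To make the quantize/dequantize flow above concrete, a minimal round-trip sketch through the public `QTensor` API; the `Q4_0` choice and the tensor size are arbitrary, and `dequantize` taking a device argument is assumed:

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let src: Vec<f32> = (0..128).map(|v| v as f32).collect();
    let src = Tensor::from_slice(&src, (128,), &device)?;
    // Quantize to Q4_0 blocks, then expand back to f32 for inspection.
    let quant = QTensor::quantize(&src, GgmlDType::Q4_0)?;
    let back = quant.dequantize(&device)?;
    println!("roundtrip shape: {:?}", back.shape());
    Ok(())
}
```
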
@@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};

pub struct QCudaStorage {
    dtype: GgmlDType,
    device: CudaDevice,
}

impl QCudaStorage {
    pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    pub fn device(&self) -> &CudaDevice {
        &self.device
    }

    pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        0
    }

    pub fn fwd(
        &self,
        _self_shape: &crate::Shape,
        _storage: &CudaStorage,
        _layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        Err(Error::NotCompiledWithCudaSupport)
    }
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    _device: &CudaDevice,
    _data: &[T],
) -> Result<super::QStorage> {
    Err(Error::NotCompiledWithCudaSupport)
}
@@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};

pub struct QMetalStorage {
    dtype: GgmlDType,
    device: MetalDevice,
}

impl QMetalStorage {
    pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    pub fn device(&self) -> &MetalDevice {
        &self.device
    }

    pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        0
    }

    pub fn fwd(
        &self,
        _self_shape: &crate::Shape,
        _storage: &MetalStorage,
        _layout: &crate::Layout,
    ) -> Result<(MetalStorage, crate::Shape)> {
        Err(Error::NotCompiledWithMetalSupport)
    }
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    _device: &MetalDevice,
    _data: &[T],
) -> Result<super::QStorage> {
    Err(Error::NotCompiledWithMetalSupport)
}
@@ -1,5 +1,7 @@
//! Support for the GGML file format.

#[cfg(feature = "metal")]
use super::metal::load_quantized_metal;
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
@@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
    let data: QStorage = match device {
        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
        Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
        Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
        #[cfg(feature = "metal")]
        Device::Metal(metal) => load_quantized_metal(metal, data)?,
        #[cfg(not(feature = "metal"))]
        Device::Metal(_metal) => {
            crate::bail!("Metal backend requires `metal` feature")
        }
        device => unimplemented!("Implement quantized tensor for device {device:?}"),
    };
    super::QTensor::new(data, dims)
}
@@ -226,7 +233,6 @@ pub struct Content {
    pub hparams: HParams,
    pub vocab: Vocab,
    pub tensors: HashMap<String, super::QTensor>,
    pub device: Device,
}

impl Content {
@@ -246,13 +252,11 @@ impl Content {
            let (name, tensor) = read_one_tensor(reader, magic, device)?;
            tensors.insert(name, tensor);
        }
        let device = device.clone();
        Ok(Self {
            magic,
            hparams,
            vocab,
            tensors,
            device,
        })
    }

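As a rough usage sketch for the `Content` reader touched above; the file name is a placeholder and the exact `Content::read` signature is assumed from the call sites in this diff:

```rust
use candle_core::quantized::ggml_file;
use candle_core::Device;

fn main() -> candle_core::Result<()> {
    let mut file = std::fs::File::open("model.ggml")?;
    // Parse hyper-parameters, vocabulary and quantized tensors from a GGML file.
    let content = ggml_file::Content::read(&mut file, &Device::Cpu)?;
    for (name, tensor) in content.tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}
```
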
@@ -1,6 +1,5 @@
use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage;
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use crate::{DType, MetalDevice, MetalStorage, Result};
use metal::Buffer;
use std::sync::Arc;

@@ -11,31 +10,23 @@ pub struct QMetalStorage {
}

impl QMetalStorage {
    pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
        let size = elem_count * dtype.type_size() / dtype.block_size();
        let buffer = device.allocate_zeros(size)?;
        Ok(Self {
            buffer,
            device: device.clone(),
            dtype,
        })
    }

    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    pub fn device(&self) -> &MetalDevice {
        &self.device
    }

    pub fn buffer(&self) -> &Buffer {
        &self.buffer
    }

    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
        use crate::quantized::k_quants::GgmlType;
    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
        Self {
            device,
            buffer,
            dtype,
        }
    }

    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
        let command_buffer = self.device.command_buffer()?;
        command_buffer.set_label("to_cpu");
@@ -45,62 +36,81 @@ impl QMetalStorage {
        blit.end_encoding();
        self.device.wait_until_completed()?;
        let mut out = vec![0.0; elem_count];
        let block_len = elem_count / self.dtype.block_size();
        match self.dtype {
            GgmlDType::F32 => {
                let vec: Vec<f32> = read_to_vec(&buffer, block_len);
                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                f32::to_float(&vec, &mut out)?;
            }
            GgmlDType::F16 => {
                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                half::f16::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_0 => {
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_1 => {
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_0 => {
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_1 => {
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_0 => {
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_1 => {
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q2K => {
                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ2K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q3K => {
                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ3K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4K => {
                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ4K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5K => {
                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ5K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q6K => {
                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ6K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8K => {
                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
                let vec: Vec<crate::quantized::BlockQ8K> =
                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
                use crate::quantized::k_quants::GgmlType;
                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
            }
        }
@@ -120,62 +130,9 @@ impl QMetalStorage {
        self.buffer = buffer;
        Ok(())
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.buffer.length() as usize
    }

    pub fn fwd(
        &self,
        self_shape: &Shape,
        storage: &MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(MetalStorage, Shape)> {
        use crate::MetalError;

        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        let (n, k) = self_shape.dims2()?;
        let mut dst_shape = src_shape.dims().to_vec();

        let (b, m) = match dst_shape.len() {
            3 => (dst_shape[0], dst_shape[1]),
            2 => (1, dst_shape[0]),
            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
        };
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        let device = storage.device().clone();
        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
        let command_buffer = device.command_buffer()?;
        candle_metal_kernels::call_quantized_matmul_t(
            device.device(),
            &command_buffer,
            device.kernels(),
            self.dtype.into(),
            (b, m, n, k),
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &self.buffer,
            &dst,
        )
        .map_err(MetalError::from)?;
        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
        Ok((dst_storage, dst_shape))
    }
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
    device: &MetalDevice,
    data: &[T],
) -> Result<QStorage> {
@@ -194,24 +151,3 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
    slice.to_vec()
}

impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
    fn from(value: GgmlDType) -> Self {
        match value {
            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
        }
    }
}

@@ -1,27 +1,16 @@
#[cfg(feature = "metal")]
use crate::{backend::BackendStorage, DType};
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
    pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
    pub use super::dummy_cuda::*;
}

#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
@@ -43,13 +32,22 @@ impl Device {
            let storage = dtype.cpu_zeros(elem_count);
            Ok(QStorage::Cpu(storage))
        }
        #[cfg(feature = "metal")]
        Device::Metal(metal) => {
            let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
            Ok(QStorage::Metal(storage))
            let size = elem_count * dtype.type_size() / dtype.block_size();
            let buffer = metal.allocate_zeros(size)?;
            Ok(QStorage::Metal(metal::QMetalStorage::new(
                buffer,
                metal.clone(),
                dtype,
            )))
        }
        Device::Cuda(cuda) => {
            let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
            Ok(QStorage::Cuda(storage))
        #[cfg(not(feature = "metal"))]
        Device::Metal(_metal) => {
            crate::bail!("Metal feature not activated");
        }
        Device::Cuda(_cuda) => {
            crate::bail!("Cuda ggml quantization not supported");
        }
    }
}
@@ -57,40 +55,32 @@

pub enum QStorage {
    Cpu(Box<dyn QuantizedType>),
    #[cfg(feature = "metal")]
    Metal(metal::QMetalStorage),
    Cuda(cuda::QCudaStorage),
}

impl QStorage {
    fn block_size(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.block_size(),
            #[cfg(feature = "metal")]
            QStorage::Metal(storage) => storage.dtype().block_size(),
            QStorage::Cuda(storage) => storage.dtype().block_size(),
        }
    }

    fn dtype(&self) -> GgmlDType {
        match self {
            QStorage::Cpu(storage) => storage.dtype(),
            #[cfg(feature = "metal")]
            QStorage::Metal(storage) => storage.dtype(),
            QStorage::Cuda(storage) => storage.dtype(),
        }
    }

    fn device(&self) -> Device {
        match self {
            QStorage::Cpu(_storage) => Device::Cpu,
            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    fn size_in_bytes(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
            #[cfg(feature = "metal")]
            QStorage::Metal(storage) => storage.buffer().length() as usize,
        }
    }

@@ -99,8 +89,8 @@ impl QStorage {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?)?;
            }
            #[cfg(feature = "metal")]
            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
            _ => crate::bail!("Invalid dequantize storage locations do not match"),
        }
        Ok(())
@@ -109,8 +99,8 @@ impl QStorage {
    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
        match self {
            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
            #[cfg(feature = "metal")]
            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
        }
    }

@@ -122,7 +112,8 @@ impl QStorage {
                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                Ok(Cow::from(data))
            }
            QStorage::Metal(_) | QStorage::Cuda(_) => {
            #[cfg(feature = "metal")]
            QStorage::Metal(_storage) => {
                crate::bail!("not implemented");
            }
        }
    }
@@ -345,10 +336,6 @@ impl QTensor {
        self.storage.dtype()
    }

    pub fn device(&self) -> Device {
        self.storage.device()
    }

    pub fn rank(&self) -> usize {
        self.shape.rank()
    }
@@ -440,7 +427,8 @@ impl crate::CustomOp1 for QTensor {
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
            #[cfg(feature = "metal")]
            _ => crate::bail!("Invalid storage"),
        };
        let slice = storage.as_slice::<f32>()?;
        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@@ -449,28 +437,79 @@ impl crate::CustomOp1 for QTensor {
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
        use crate::MetalError;

        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        let (n, k) = self.shape.dims2()?;
        let mut dst_shape = src_shape.dims().to_vec();

        let (b, m) = match dst_shape.len() {
            3 => (dst_shape[0], dst_shape[1]),
            2 => (1, dst_shape[0]),
            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
        };
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        let device = storage.device().clone();
        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
        let (buffer, dtype) = match &self.storage {
            QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
        let command_buffer = device.command_buffer()?;
        candle_metal_kernels::call_quantized_matmul_t(
            device.device(),
            &command_buffer,
            device.kernels(),
            dtype.into(),
            (b, m, n, k),
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            buffer,
            &dst,
        )
        .map_err(MetalError::from)?;
        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
        Ok((dst_storage, dst_shape))
    }
}

    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
#[cfg(feature = "metal")]
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
    fn from(value: GgmlDType) -> Self {
        match value {
            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
        }
    }
}

@@ -352,10 +352,6 @@ impl Storage {
                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
                Ok(Self::Cuda(s))
            }
            (Storage::Metal(inp), Storage::Metal(kernel)) => {
                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
                Ok(Self::Metal(s))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
@@ -508,7 +508,6 @@ impl Tensor {
    unary_op!(gelu_erf, GeluErf);
    unary_op!(erf, Erf);
    unary_op!(relu, Relu);
    unary_op!(silu, Silu);
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);
@@ -805,35 +804,6 @@ impl Tensor {
        }
    }

    /// Roll the tensor input along the given dimension.
    /// Elements that are shifted beyond the last position are re-introduced at the first position.
    ///
    /// ```rust
    /// # use candle_core::{Tensor, Device};
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let tensor = tensor.roll(1, 0)?;
    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
    /// let tensor = tensor.roll(-1, 0)?;
    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
    where
        D: Dim + Clone,
    {
        let dim = dim.to_index(self.shape(), "roll")?;
        let dim_size = self.dim(dim)?;
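        // rem_euclid maps a negative shift into [0, dim_size): e.g. a shift of -1 on a
        // dimension of size 3 becomes 2, so rolling left by one equals rolling right by two.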
        let shift = shift.rem_euclid(dim_size as i32) as usize;
        if shift == 0 {
            Ok(self.clone())
        } else {
            let a = self.narrow(dim, 0, dim_size - shift)?;
            let b = self.narrow(dim, dim_size - shift, shift)?;
            Tensor::cat(&[&b, &a], dim)
        }
    }

    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
    /// input dimensions.
    ///
@@ -1015,7 +985,7 @@ impl Tensor {
    /// tensor also has three dimensions, `(batch, channels, target_size)`.
    pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
        let (n, c, _l) = self.dims3()?;
        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
        let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
        let storage = self
            .storage()
            .upsample_nearest1d(self.layout(), target_size)?;
@@ -1883,9 +1853,9 @@ impl Tensor {
    /// this new node. The storage of this tensor is shared with the initial tensor.
    ///
    /// If the tensor is already detached from the computation graph, the same tensor is returned.
    pub fn detach(&self) -> Tensor {
    pub fn detach(&self) -> Result<Tensor> {
        if self.op.is_none() && !self.is_variable {
            self.clone()
            Ok(self.clone())
        } else {
            let tensor_ = Tensor_ {
                id: TensorId::new(),
@@ -1896,7 +1866,7 @@ impl Tensor {
                dtype: self.dtype,
                device: self.device.clone(),
            };
            Tensor(Arc::new(tensor_))
            Ok(Tensor(Arc::new(tensor_)))
        }
    }

@@ -107,10 +107,6 @@ impl Var {
        Ok(Self(inner))
    }

    pub fn as_detached_tensor(&self) -> Tensor {
        self.0.detach()
    }

    pub fn as_tensor(&self) -> &Tensor {
        &self.0
    }

@@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
print(res.shape)
print(res)
*/
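// The candle counterpart of the torch calls above is exercised in the test below; the
// trailing integer arguments are assumed to be, in order: padding, output_padding,
// stride, dilation and groups, e.g.:
//   let res = t.conv_transpose1d(&w_t, 0, 0, 1, 1, 2)?;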
fn conv1d(dev: &Device) -> Result<()> {
    let t = Tensor::new(
@@ -53,26 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
            4.7076, -5.9745, -0.8276, 1.621
        ],
    );
    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
    assert_eq!(res.dims(), [1, 4, 7]);
    assert_eq!(
        test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
        [
            [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
            [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
            [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
            [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
        ]
    );
    if dev.is_cpu() {
        let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
        assert_eq!(res.dims(), [1, 2, 7]);
        assert_eq!(
            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
            [
                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
                4.7076, -5.9745, -0.8276, 1.621
            ],
        );
    }
    Ok(())
}

Binary file not shown.
@@ -270,51 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
        [0.7358, 2.0000, 0.2707, 1.0000]
    );

    // testing compared to pytorch nn.Silu()
    let y = x.silu()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [2.8577, 0.7311, 3.9281, 0.0806]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [1.0881, 0.9277, 1.0527, 0.5747],
    );
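    // The expected gradient follows from silu(x) = x * sigmoid(x), hence
    // silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))); for x = 1 this gives
    // 0.7311 * (1 + 0.2689) ~= 0.9277, matching the second entry above.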
    if device.is_cpu() {
        let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
        let y = x.interpolate1d(12)?.reshape(36)?;

        let z = Tensor::new(
            &[
                1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
                17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
                33., 34., 35., 36.,
            ],
            device,
        )?;

        let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
        let grads = loss.backward()?;
        let grad_x = grads.get(&x).context("no grad for x")?;

        assert_eq!(
            test_utils::to_vec3_round(grad_x, 4)?,
            [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
        );
    }

    // manually checked: see comments
    let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
    let y = x.interpolate2d(6, 6)?.reshape(36)?;

    #[rustfmt::skip]
    let z = Tensor::new(
        &[
            1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
            18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
            35., 36.,
            1_f32, 02., 03., 04., 05., 06.,
            07., 08., 09., 10., 11., 12.,
            13., 14., 15., 16., 17., 18.,
            19., 20., 21., 22., 23., 24.,
            25., 26., 27., 28., 29., 30.,
            31., 32., 33., 34., 35., 36.,
        ],
        device,
    )?;
@@ -345,11 +313,15 @@ fn unary_grad(device: &Device) -> Result<()> {
    let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
    let y = x.interpolate2d(6, 6)?.reshape(36)?;

    #[rustfmt::skip]
    let z = Tensor::new(
        &[
            1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
            18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
            35., 36.,
            1_f32, 02., 03., 04., 05., 06.,
            07., 08., 09., 10., 11., 12.,
            13., 14., 15., 16., 17., 18.,
            19., 20., 21., 22., 23., 24.,
            25., 26., 27., 28., 29., 30.,
            31., 32., 33., 34., 35., 36.,
        ],
        device,
    )?;

@@ -1,37 +0,0 @@
import torch
from collections import OrderedDict

# Write a trivial tensor to a pt file
a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
o = OrderedDict()
o["test"] = a

# Write a trivial tensor to a pt file
torch.save(o, "test.pt")

############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")

############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np

# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4))

# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS'])  # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS'])  # Should be False

# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)

# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride())  # Stride will reflect the Fortran memory layout

# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')

print("3D Tensor saved with Fortran layout.")
@@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
    let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
    tensors.get("test").unwrap().unwrap();
}

#[test]
fn test_pth_with_key() {
    let tensors =
        candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
            .unwrap();
    tensors.get("test").unwrap().unwrap();
}

#[test]
fn test_pth_fortran_contiguous() {
    let tensors =
        candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
    let tensor = tensors.get("tensor_fortran").unwrap().unwrap();

    assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));

    assert_eq!(
        tensor.to_vec3::<i64>().unwrap(),
        [
            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
        ]
    );
}
@@ -178,6 +178,10 @@ test_device!(
);

fn quantize_q4_0(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();

    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
@@ -205,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
}

fn quantize_q4_1(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
@@ -231,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
}

fn quantize_q5_0(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@@ -257,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
}

fn quantize_q5_1(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
@@ -357,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32
}

fn quantize_q2k(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let dtype = GgmlDType::Q2K;

    let src = get_test_vector2(0.5, 1024, device)?;
@@ -391,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
}

fn quantize_q3k(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let dtype = GgmlDType::Q3K;
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -424,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
}

fn quantize_q4k(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let dtype = GgmlDType::Q4K;
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -457,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
}

fn quantize_q5k(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let dtype = GgmlDType::Q5K;
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -490,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
}

fn quantize_q6k(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let dtype = GgmlDType::Q6K;
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -523,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
}

fn quantize_q8k(device: &Device) -> Result<()> {
    // TODO Enable this later when we enable cuda.
    if device.is_cuda() {
        return Ok(());
    }
    let dtype = GgmlDType::Q8K;
    let src = get_test_vector2(0.5, 1024, device)?;
    let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -738,6 +778,10 @@ macro_rules! quantized_matmul {
    // stable. https://github.com/rust-lang/rust/issues/29599
    ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
        fn $fn_name(device: &Device) -> Result<()> {
            if device.is_cuda() {
                // TODO Enable Cuda GGML sometime maybe.
                return Ok(());
            }
            test_matmul(device, (1, 3, 4, 256), $dtype)?;
            Ok(())
        }

@@ -120,13 +120,6 @@ fn unary_op(device: &Device) -> Result<()> {
            [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
        ]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.silu()?, 4)?,
        [
            [-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
            [2.53, -0.2553, -0.1205, 1.5447, 2.6395]
        ]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
Binary file not shown.
Binary file not shown.
@@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true, optional = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
@@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
csv = "1.3.0"
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
hf-hub = { workspace = true, features = ["tokio"] }
hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
@@ -30,9 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@@ -45,6 +43,7 @@ rusttype = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"

@@ -62,7 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]

[[example]]
name = "llama_multiprocess"
@@ -79,27 +77,3 @@ required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

[[example]]
name = "whisper"
required-features = ["symphonia"]

[[example]]
name = "whisper-microphone"
required-features = ["microphone"]

[[example]]
name = "mnist-training"
required-features = ["candle-datasets"]

[[example]]
name = "llama2-c"
required-features = ["candle-datasets"]

[[example]]
name = "encodec"
required-features = ["symphonia"]

[[example]]
name = "metavoice"
required-features = ["symphonia"]

@@ -1,237 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::chatglm::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the chatglm model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }
        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the endoftext token"),
        };
        print!("{prompt}");
        std::io::stdout().flush()?;
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
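            // After the first step only the newly sampled token is fed back in; the
            // model is assumed to keep the earlier context in its internal cache.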
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input)?;
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
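                // Penalize the logits of the last `repeat_last_n` generated tokens
                // so that immediate repetition becomes less likely.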
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "THUDM/chatglm3-6b".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("lmz/candle-chatglm".to_string())
|
||||
.get("chatglm-tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::glm3_6b();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
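    // Memory-map the weight files; this is `unsafe` because the mapping is only
    // valid as long as the underlying files are not modified.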
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@ -1,23 +0,0 @@
# candle-convnext

[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).

This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 84.09%
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
maillot                 : 0.74%
crash helmet            : 0.54%
unicycle, monocycle     : 0.44%

```
@ -1,126 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convnext;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Atto,
    Femto,
    Pico,
    Nano,
    Tiny,
    Small,
    Base,
    Large,
    AttoV2,
    FemtoV2,
    PicoV2,
    NanoV2,
    TinyV2,
    BaseV2,
    LargeV2,
    XLarge,
    Huge,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::Atto => "convnext_atto.d2_in1k",
            Self::Femto => "convnext_femto.d1_in1k",
            Self::Pico => "convnext_pico.d1_in1k",
            Self::Nano => "convnext_nano.d1h_in1k",
            Self::Tiny => "convnext_tiny.fb_in1k",
            Self::Small => "convnext_small.fb_in1k",
            Self::Base => "convnext_base.fb_in1k",
            Self::Large => "convnext_large.fb_in1k",
            Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
            Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
            Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
            Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
            Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
            Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
            Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
            Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
            Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
        };

        format!("timm/{name}")
    }

    fn config(&self) -> convnext::Config {
        match self {
            Self::Atto | Self::AttoV2 => convnext::Config::atto(),
            Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
            Self::Pico | Self::PicoV2 => convnext::Config::pico(),
            Self::Nano | Self::NanoV2 => convnext::Config::nano(),
            Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
            Self::Small => convnext::Config::small(),
            Self::Base | Self::BaseV2 => convnext::Config::base(),
            Self::Large | Self::LargeV2 => convnext::Config::large(),
            Self::XLarge => convnext::Config::xlarge(),
            Self::Huge => convnext::Config::huge(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::Tiny)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = convnext::convnext(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
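    // Turn the logits into probabilities with a softmax over the class
    // dimension, then sort to report the five most likely ImageNet classes.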
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@ -1 +0,0 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
@ -1,20 +0,0 @@
# candle-efficientvit

[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).

This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle     : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet            : 2.25%
alp                     : 0.46%
```
@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientvit;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    M0,
    M1,
    M2,
    M3,
    M4,
    M5,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::M0 => "m0",
            Self::M1 => "m1",
            Self::M2 => "m2",
            Self::M3 => "m3",
            Self::M4 => "m4",
            Self::M5 => "m5",
        };
        format!("timm/efficientvit_{}.r224_in1k", name)
    }

    fn config(&self) -> efficientvit::Config {
        match self {
            Self::M0 => efficientvit::Config::m0(),
            Self::M1 => efficientvit::Config::m1(),
            Self::M2 => efficientvit::Config::m2(),
            Self::M3 => efficientvit::Config::m3(),
            Self::M4 => efficientvit::Config::m4(),
            Self::M5 => efficientvit::Config::m5(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::M0)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@ -1,20 +0,0 @@
# candle-encodec

[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.

## Running an example

```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
    candle-examples/examples/encodec/jfk-codes.safetensors \
    jfk.wav
```

This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
  containing EnCodec tokens for the input audio file.
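To peek at such a token file before decoding it, the codes tensor can be loaded
directly with candle's safetensors helper (a minimal sketch; the tensor name
`codes` matches what `audio-to-code` writes above):

```rust
use candle::Device;

fn main() -> anyhow::Result<()> {
    // Load the tensors written by `audio-to-code` and inspect the shape of the
    // discrete EnCodec codes.
    let tensors = candle::safetensors::load("jfk-codes.safetensors", &Device::Cpu)?;
    let codes = tensors.get("codes").expect("no codes tensor in input file");
    println!("codes shape: {:?}", codes.shape());
    Ok(())
}
```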
Binary file not shown.
@ -1,143 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
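    // Walk the container packet by packet, decode each packet, and convert
    // whatever sample format comes back into f32, keeping only channel 0.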
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
    AudioToAudio,
    AudioToCode,
    CodeToAudio,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The action to be performed, specifies the format for the input and output data.
    action: Action,

    /// The input file, either an audio file or some encodec tokens stored as safetensors.
    in_file: String,

    /// The output file, either a wave audio file or some encodec tokens stored as safetensors.
    out_file: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weight file, in safetensor format.
    #[arg(long)]
    model: Option<String>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => Api::new()?
            .model("facebook/encodec_24khz".to_string())
            .get("model.safetensors")?,
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
    let config = Config::default();
    let model = Model::new(&config, vb)?;

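    // Obtain the discrete EnCodec codes: load them from a safetensors file for
    // code-to-audio, or encode the input audio for the other two actions.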
    let codes = match args.action {
        Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
            let codes = codes.get("codes").expect("no codes in input file").i(0)?;
            codes
        }
        Action::AudioToCode | Action::AudioToAudio => {
            let (pcm, sample_rate) = pcm_decode(args.in_file)?;
            if sample_rate != 24_000 {
                println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
            }
            let pcm_len = pcm.len();
            let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
            println!("input pcm shape: {:?}", pcm.shape());
            model.encode(&pcm)?
        }
    };
    println!("codes shape: {:?}", codes.shape());

    match args.action {
        Action::AudioToCode => {
            codes.save_safetensors("codes", &args.out_file)?;
        }
        Action::AudioToAudio | Action::CodeToAudio => {
            let pcm = model.decode(&codes)?;
            println!("output pcm shape: {:?}", pcm.shape());
            let pcm = pcm.i(0)?.i(0)?;
            let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
            let pcm = pcm.to_vec1::<f32>()?;
            let mut output = std::fs::File::create(&args.out_file)?;
            candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
        }
    }
    Ok(())
}
@ -1,27 +0,0 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind

[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google DeepMind with a 2b and a 7b variant.

In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).

## Running the example

```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
    let mut primes = vec![true; max_n];
    for i in 2..=max_n {
        if primes[i] {
            for j in i * i..max_n {
                primes[j] = false;
            }
        }
    }
    primes.len()
}
```
@ -1,256 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::gemma::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
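        // Echo the prompt through the incremental detokenizer so the output
        // stream starts with exactly the text the model was given.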
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => match model_id.as_str() {
            "7b-it" => "google/gemma-7b-it".to_string(),
            "7b" => "google/gemma-7b".to_string(),
            "2b-it" => "google/gemma-2b-it".to_string(),
            "2b" => "google/gemma-2b".to_string(),
            _ => model_id.to_string(),
        },
        None => "google/gemma-2b".to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
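    // Prefer bf16 on CUDA for speed and memory savings; fall back to f32 on
    // other devices where bf16 kernels may not be available.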
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@ -57,7 +57,7 @@ struct Args {
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, default_value_t = 10000)]
    #[arg(long, default_value_t = 100)]
    sample_len: usize,

    /// Disable the key-value cache.
@ -120,7 +120,7 @@ fn main() -> Result<()> {
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    let (llama, tokenizer_filename, mut cache) = {
    let (llama, tokenizer_filename, cache) = {
        let api = Api::new()?;
        let model_id = args.model_id.unwrap_or_else(|| match args.which {
            Which::V1 => "Narsil/amall-7b".to_string(),
@ -143,10 +143,11 @@ fn main() -> Result<()> {
            }
            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
        };
        println!("building the model");
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (Llama::load(vb, &config)?, tokenizer_filename, cache)
        (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -156,7 +157,6 @@ fn main() -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    println!("starting the inference loop");
    print!("{prompt}");
@ -172,7 +172,7 @@ fn main() -> Result<()> {
        };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, context_index, &mut cache)?;
        let logits = llama.forward(&input, context_index)?;
        let logits = logits.squeeze(0)?;
        let logits = if args.repeat_penalty == 1. {
            logits
@ -190,16 +190,18 @@ fn main() -> Result<()> {
        token_generated += 1;
        tokens.push(next_token);

        // Extracting the last token as a string is complicated, here we just apply some simple
        // heuristics as it seems to work well enough for this example. See the following for more
        // details:
        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
        if let Some(text) = tokenizer.id_to_token(next_token) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
        if Some(next_token) == eos_token_id {
            break;
        }
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    let dt = start_gen.elapsed();
    println!(
@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;

use model::{Cache, Config, Llama};
use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;

@ -160,10 +160,10 @@ enum Model {
}

impl Model {
    fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
        match self {
            Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
            Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
            Self::Llama(l) => Ok(l.forward(xs, pos)?),
            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
        }
    }
}
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    let config = Config::from_reader(&mut file)?;
    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
    let vb = weights.var_builder(&config, &device)?;
    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, config)?;
    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;

    let tokens = match &args.pretokenized_dir {
        None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
    for inp_tgt in batch_iter {
        let (inp, tgt) = inp_tgt?;
        let logits = model.forward(&inp, 0, &mut cache)?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        println!("{}", loss.to_vec0::<f32>()?);
    }
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
    let is_safetensors = config_path
        .extension()
        .map_or(false, |v| v == "safetensors");
    let (model, config, mut cache) = if is_gguf {
    let (model, config) = if is_gguf {
        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
        let (_vocab_size, dim) = vb
            .get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
            &device,
        );
        let cache = model::Cache::new(true, &config, fake_vb)?;
        let model = Model::QLlama(QLlama::load(vb, config.clone())?);
        (model, config, cache)
        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
        (model, config)
    } else if is_safetensors {
        let config = Config::tiny_15m();
        let tensors = candle::safetensors::load(config_path, &device)?;
        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, config.clone())?);
        (model, config, cache)
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
    } else {
        let mut file = std::fs::File::open(config_path)?;
        let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, config.clone())?);
        (model, config, cache)
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
    };

    println!("starting the inference loop");
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    let start_gen = std::time::Instant::now();
    for index in 0.. {
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos, &mut cache)?;
        let logits = model.forward(&input, index_pos)?;
        let logits = logits.i((0, logits.dim(1)? - 1))?;
        let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
            logits
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
        // Extracting the last token as a string is complicated, here we just apply some simple
        // heuristics as it seems to work well enough for this example. See the following for more
        // details:
        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
        if let Some(text) = tokenizer.id_to_token(next_token) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    let dt = start_gen.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",

@ -8,7 +8,6 @@ fn valid_loss(
    model: &Llama,
    args: &crate::TrainingCmd,
    device: &Device,
    cache: &mut Cache,
) -> Result<f64> {
    let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -16,7 +15,7 @@ fn valid_loss(
    let mut cnt = 0usize;
    for inp_tgt in batch_iter.take(50) {
        let (inp, tgt) = inp_tgt?;
        let logits = model.forward(&inp, 0, cache)?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        sum_ce += loss.to_vec0::<f32>()? as f64;
        cnt += 1;
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, config)?;
    let cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;
    let params = candle_nn::ParamsAdamW {
        lr: args.learning_rate,
        ..Default::default()
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
    for (batch_index, batch) in batch_iter.enumerate() {
        let (inp, tgt) = batch?;
        let logits = model.forward(&inp, 0, &mut cache)?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        opt.backward_step(&loss)?;

        if batch_index > 0 && batch_index % 100 == 0 {
            // TODO: Add a way to deactivate the backprop graph tracking when computing the
            // validation loss.
            let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
            let loss = valid_loss(&dataset, &model, args, &device)?;
            println!("{batch_index} {loss}");
        }
        if batch_index > 0 && batch_index % 1000 == 0 {
@ -2,9 +2,6 @@

This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).

Compared to the mamba example, this version can handle training but is much
slower.

## Running the example

```bash
@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation

Candle implementation of *Mamba* [1], inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).

- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).

Compared to the mamba-minimal example, this version is far more efficient but
only supports inference.

## Running the example

```bash
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```
@ -1,299 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};

use candle_transformers::models::mamba::{Config, Model, State};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    config: Config,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        config: Config,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            config,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let mut state = State::new(1, &self.config, &self.device)?;
        let mut next_logits = None;
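        // Mamba is a recurrent model: feed the prompt one token at a time while
        // updating the state, and keep the logits of the final step to sample
        // the first generated token from.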
        for &t in tokens.iter() {
            let input = Tensor::new(&[t], &self.device)?;
            let logits = self.model.forward(&input, &mut state)?;
            next_logits = Some(logits);
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let start_gen = std::time::Instant::now();
        for _ in 0..sample_len {
            let logits = match next_logits.as_ref() {
                Some(logits) => logits,
                None => anyhow::bail!("cannot work on an empty prompt"),
            };
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }

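            // Advance the recurrent state by one step with the freshly sampled token.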
            let input = Tensor::new(&[next_token], &self.device)?;
            next_logits = Some(self.model.forward(&input, &mut state)?)
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
    Mamba130m,
    Mamba370m,
    Mamba790m,
    Mamba1_4b,
    Mamba2_8b,
    Mamba2_8bSlimPj,
}

impl std::fmt::Display for Which {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Mamba130m => "state-spaces/mamba-130m",
            Self::Mamba370m => "state-spaces/mamba-370m",
            Self::Mamba790m => "state-spaces/mamba-790m",
            Self::Mamba1_4b => "state-spaces/mamba-1.4b",
            Self::Mamba2_8b => "state-spaces/mamba-2.8b",
            Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
        }
    }

    fn revision(&self) -> &'static str {
        match self {
            Self::Mamba130m
            | Self::Mamba370m
            | Self::Mamba790m
            | Self::Mamba1_4b
            | Self::Mamba2_8bSlimPj => "refs/pr/1",
            Self::Mamba2_8b => "refs/pr/4",
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long, default_value = "mamba130m")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id
            .unwrap_or_else(|| args.which.model_id().to_string()),
        RepoType::Model,
        args.revision
            .unwrap_or_else(|| args.which.revision().to_string()),
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("EleutherAI/gpt-neox-20b".to_string())
            .get("tokenizer.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            vec![repo.get("model.safetensors")?]
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let model = Model::new(&config, vb.pp("backbone"))?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        config,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@ -1,18 +0,0 @@
# candle-metavoice

MetaVoice-1B is a text-to-speech model trained on 100K hours of speech; more
details can be found on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).

Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
  probably because of some implementation discrepancies.

## Run an example

```bash
cargo run --example metavoice --release -- \
  --prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```
@ -1,342 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::Parser;
use std::io::Write;

use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{
    adapters, gpt, speaker_encoder, tokenizers, transformer,
};

use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};

pub const ENCODEC_NTOKENS: u32 = 1024;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
    F32,
    F16,
    Bf16,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The guidance scale.
    #[arg(long, default_value_t = 3.0)]
    guidance_scale: f64,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 1.0)]
    temperature: f64,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The maximum number of tokens to generate for the first stage.
    #[arg(long, default_value_t = 2000)]
    max_tokens: u64,

    /// The output file using the wav format.
    #[arg(long, default_value = "out.wav")]
    out_file: String,

    #[arg(long)]
    first_stage_meta: Option<String>,

    #[arg(long)]
    first_stage_weights: Option<String>,

    #[arg(long)]
    second_stage_weights: Option<String>,

    #[arg(long)]
    speaker_encoder_weights: Option<String>,

    #[arg(long)]
    encodec_weights: Option<String>,

    /// The speaker embeddings, either an audio file in which case they are extracted, or a
    /// safetensors file with the embeddings already extracted.
    #[arg(long)]
    spk_emb: Option<String>,

    #[arg(long, default_value = "f32")]
    dtype: ArgDType,
}

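// Load the 40-band mel filterbank used by the speaker encoder, stored alongside
// the example as raw little-endian f32 values.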
fn mel_filters() -> Result<Vec<f32>> {
    let mel_bytes = include_bytes!("melfilters40.bytes").as_slice();
    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
    Ok(mel_filters)
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    let device = candle_examples::device(args.cpu)?;
    let api = Api::new()?;
    let repo = api.model("lmz/candle-metavoice".to_string());
    let first_stage_meta = match &args.first_stage_meta {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("first_stage.meta.json")?,
    };
    let first_stage_meta: serde_json::Value =
        serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
    let first_stage_tokenizer = match first_stage_meta.as_object() {
        None => anyhow::bail!("not a json object"),
        Some(j) => match j.get("tokenizer") {
            None => anyhow::bail!("no tokenizer key"),
            Some(j) => j,
        },
    };
    let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;

    let first_stage_weights = match &args.first_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("first_stage.safetensors")?,
    };
    let second_stage_weights = match &args.second_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("second_stage.safetensors")?,
    };
    let encodec_weights = match args.encodec_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => Api::new()?
            .model("facebook/encodec_24khz".to_string())
            .get("model.safetensors")?,
    };
    let dtype = match args.dtype {
        ArgDType::F32 => DType::F32,
        ArgDType::F16 => DType::F16,
        ArgDType::Bf16 => DType::BF16,
    };
    let first_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
    let first_stage_config = transformer::Config::cfg1b_v0_1();
    let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;

    let second_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
    let second_stage_config = gpt::Config::cfg1b_v0_1();
    let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;

let encodec_device = if device.is_metal() {
|
||||
&candle::Device::Cpu
|
||||
} else {
|
||||
&device
|
||||
};
|
||||
let encodec_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
|
||||
let encodec_config = encodec::Config::default();
|
||||
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
||||
|
||||
println!("prompt: '{}'", args.prompt);
|
||||
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
|
||||
let mut tokens = prompt_tokens.clone();
|
||||
println!("{tokens:?}");
|
||||
let safetensors_embeddings = args
|
||||
.spk_emb
|
||||
.as_ref()
|
||||
.map_or(true, |v| v.ends_with("safetensors"));
|
||||
let spk_emb = if safetensors_embeddings {
|
||||
let spk_emb_file = match &args.spk_emb {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("spk_emb.safetensors")?,
|
||||
};
|
||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
||||
match spk_emb.get("spk_emb") {
|
||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
||||
Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
|
||||
}
|
||||
} else {
|
||||
let weights = match &args.speaker_encoder_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("speaker_encoder.safetensors")?,
|
||||
};
|
||||
let mel_filters = mel_filters()?;
|
||||
let config = speaker_encoder::Config::cfg();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
|
||||
let model = speaker_encoder::Model::new(config, vb)?;
|
||||
let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
|
||||
if sample_rate != 16_000 {
|
||||
eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
|
||||
}
|
||||
model.embed_utterance(
|
||||
&pcm,
|
||||
&mel_filters,
|
||||
/* rate */ 1.3,
|
||||
/* min_c */ 0.75,
|
||||
&device,
|
||||
)?
|
||||
};
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
||||
|
||||
// First stage generation.
|
||||
for index in 0..args.max_tokens {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let input = Tensor::stack(&[&input, &input], 0)?;
|
||||
let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
|
||||
let logits0 = logits.i((0, 0))?;
|
||||
let logits1 = logits.i((1, 0))?;
|
||||
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
print!(".");
|
||||
std::io::stdout().flush()?;
|
||||
if next_token == 2048 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!();
|
||||
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
|
||||
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
|
||||
println!("text ids len: {}", text_ids.len());
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
|
||||
// TODO: Use the config rather than hardcoding the offset here.
|
||||
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
|
||||
let mut hierarchies_in1 =
|
||||
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
||||
let mut hierarchies_in2 = [
|
||||
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
|
||||
ids2.as_slice(),
|
||||
&[ENCODEC_NTOKENS],
|
||||
]
|
||||
.concat();
|
||||
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
|
||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
||||
let logits = second_stage_model.forward(&in_x)?;
|
||||
println!("sampling from logits...");
|
||||
let mut codes = vec![];
|
||||
for logits in logits.iter() {
|
||||
let logits = logits.squeeze(0)?;
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let mut codes_ = Vec::with_capacity(seq_len);
|
||||
for step in 0..seq_len {
|
||||
let logits = logits.i(step)?.to_dtype(DType::F32)?;
|
||||
let logits = &(&logits / 1.0)?;
|
||||
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
|
||||
let sample = distr.sample(&mut rng) as u32;
|
||||
codes_.push(sample)
|
||||
}
|
||||
codes.push(codes_)
|
||||
}
|
||||
|
||||
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
|
||||
let codes = Tensor::cat(&[in_x, codes], 1)?;
|
||||
println!("codes: {codes}");
|
||||
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
|
||||
let codes = codes.i(0)?.to_vec2::<u32>()?;
|
||||
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
|
||||
println!("text_ids len: {:?}", text_ids.len());
|
||||
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
|
||||
println!("audio_ids shape: {:?}", audio_ids.shape());
|
||||
let pcm = encodec_model.decode(&audio_ids)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
Ok(())
|
||||
}
|
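For orientation, a plausible way to invoke this example; the example name and flag spellings are assumptions read off the argument struct above rather than something this diff states:

```bash
cargo run --example metavoice --release -- \
  --prompt "This is a test of the metavoice text to speech model." \
  --out-file out.wav
```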
Binary file not shown.
@@ -152,7 +152,7 @@ struct Args {
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 10000)]
+    #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,

     #[arg(long)]
@@ -143,7 +143,7 @@ struct Args {
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 10000)]
+    #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,

     #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
580  candle-examples/examples/musicgen/encodec_model.rs  Normal file
@@ -0,0 +1,580 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};

// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py

#[derive(Debug, Clone, PartialEq)]
enum NormType {
    WeightNorm,
    TimeGroupNorm,
    None,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    target_bandwidths: Vec<f64>,
    sampling_rate: usize,
    audio_channels: usize,
    normalize: bool,
    chunk_length_s: Option<usize>,
    overlap: Option<usize>,
    hidden_size: usize,
    num_filters: usize,
    num_residual_layers: usize,
    upsampling_ratios: Vec<usize>,
    norm_type: NormType,
    kernel_size: usize,
    last_kernel_size: usize,
    residual_kernel_size: usize,
    dilation_growth_rate: usize,
    use_causal_conv: bool,
    pad_mode: &'static str,
    compress: usize,
    num_lstm_layers: usize,
    trim_right_ratio: f64,
    codebook_size: usize,
    codebook_dim: Option<usize>,
    use_conv_shortcut: bool,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
            sampling_rate: 24_000,
            audio_channels: 1,
            normalize: false,
            chunk_length_s: None,
            overlap: None,
            hidden_size: 128,
            num_filters: 32,
            num_residual_layers: 1,
            upsampling_ratios: vec![8, 5, 4, 2],
            norm_type: NormType::WeightNorm,
            kernel_size: 7,
            last_kernel_size: 7,
            residual_kernel_size: 3,
            dilation_growth_rate: 2,
            use_causal_conv: true,
            pad_mode: "reflect",
            compress: 2,
            num_lstm_layers: 2,
            trim_right_ratio: 1.0,
            codebook_size: 1024,
            codebook_dim: None,
            use_conv_shortcut: true,
        }
    }
}

impl Config {
    // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
    pub fn musicgen_small() -> Self {
        Self {
            audio_channels: 1,
            chunk_length_s: None,
            codebook_dim: Some(128),
            codebook_size: 2048,
            compress: 2,
            dilation_growth_rate: 2,
            hidden_size: 128,
            kernel_size: 7,
            last_kernel_size: 7,
            norm_type: NormType::WeightNorm,
            normalize: false,
            num_filters: 64,
            num_lstm_layers: 2,
            num_residual_layers: 1,
            overlap: None,
            pad_mode: "reflect",
            residual_kernel_size: 3,
            sampling_rate: 32_000,
            target_bandwidths: vec![2.2],
            trim_right_ratio: 1.0,
            upsampling_ratios: vec![8, 5, 4, 4],
            use_causal_conv: false,
            use_conv_shortcut: false,
        }
    }

    fn codebook_dim(&self) -> usize {
        self.codebook_dim.unwrap_or(self.codebook_size)
    }

    // Number of latent frames per second: the sampling rate divided by the total
    // upsampling factor, rounded up.
    fn frame_rate(&self) -> usize {
        let hop_length: usize = self.upsampling_ratios.iter().product();
        (self.sampling_rate + hop_length - 1) / hop_length
    }

    // The number of RVQ codebooks needed to reach the largest target bandwidth
    // (in kbit/s), at 10 bits per codebook per frame.
    fn num_quantizers(&self) -> usize {
        let num = 1000f64
            * self
                .target_bandwidths
                .last()
                .expect("empty target_bandwidths");
        (num as usize) / (self.frame_rate() * 10)
    }
}

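As a sanity check of the arithmetic above, here is a small sketch, only meaningful inside this module since both methods are private, that works through the `musicgen_small` numbers:

```rust
#[test]
fn musicgen_small_rates() {
    // hop_length      = 8 * 5 * 4 * 4                        = 640
    // frame_rate      = ceil(32_000 / 640)                   = 50
    // num_quantizers  = (1000 * 2.2) as usize / (50 * 10)    = 4
    let cfg = Config::musicgen_small();
    assert_eq!(cfg.frame_rate(), 50);
    assert_eq!(cfg.num_quantizers(), 4);
}
```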
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Debug)]
struct EncodecEuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    embed: Tensor,
    embed_avg: Tensor,
}

impl EncodecEuclideanCodebook {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let inited = vb.get(1, "inited")?;
        let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
        let e_shape = (cfg.codebook_size, cfg.codebook_dim());
        let embed = vb.get(e_shape, "embed")?;
        let embed_avg = vb.get(e_shape, "embed_avg")?;
        Ok(Self {
            inited,
            cluster_size,
            embed,
            embed_avg,
        })
    }

    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        let quantize = self.embed.embedding(embed_ind)?;
        Ok(quantize)
    }
}

#[derive(Debug)]
struct EncodecVectorQuantization {
    codebook: EncodecEuclideanCodebook,
}

impl EncodecVectorQuantization {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
        Ok(Self { codebook })
    }

    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        let quantize = self.codebook.decode(embed_ind)?;
        let quantize = quantize.transpose(1, 2)?;
        Ok(quantize)
    }
}

#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
    layers: Vec<EncodecVectorQuantization>,
}

impl EncodecResidualVectorQuantizer {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let vb = &vb.pp("layers");
        let layers = (0..cfg.num_quantizers())
            .map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
        if codes.dim(0)? != self.layers.len() {
            candle::bail!(
                "codes shape {:?} does not match the number of quantization layers {}",
                codes.shape(),
                self.layers.len()
            )
        }
        for (i, layer) in self.layers.iter().enumerate() {
            let quantized = layer.decode(&codes.i(i)?)?;
            quantized_out = quantized.broadcast_add(&quantized_out)?;
        }
        Ok(quantized_out)
    }
}

// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
    layers: Vec<candle_nn::LSTM>,
}

impl EncodecLSTM {
    fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let vb = &vb.pp("lstm");
        let mut layers = vec![];
        for layer_idx in 0..cfg.num_lstm_layers {
            let config = candle_nn::LSTMConfig {
                layer_idx,
                ..Default::default()
            };
            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
            layers.push(lstm)
        }
        Ok(Self { layers })
    }
}

impl Module for EncodecLSTM {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            let states = layer.seq(&xs)?;
            xs = layer.states_to_tensor(&states)?;
        }
        Ok(xs)
    }
}

#[derive(Debug)]
struct EncodecConvTranspose1d {
    weight_g: Tensor,
    weight_v: Tensor,
    bias: Tensor,
}

impl EncodecConvTranspose1d {
    fn load(
        in_c: usize,
        out_c: usize,
        k: usize,
        _stride: usize,
        vb: VarBuilder,
        _cfg: &Config,
    ) -> Result<Self> {
        let vb = &vb.pp("conv");
        let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
        let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
        let bias = vb.get(out_c, "bias")?;
        Ok(Self {
            weight_g,
            weight_v,
            bias,
        })
    }
}

impl Module for EncodecConvTranspose1d {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

#[derive(Debug)]
struct EncodecConv1d {
    causal: bool,
    conv: Conv1d,
    norm: Option<candle_nn::GroupNorm>,
}

impl EncodecConv1d {
    fn load(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        stride: usize,
        vb: VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let conv = match cfg.norm_type {
            NormType::WeightNorm => conv1d_weight_norm(
                in_c,
                out_c,
                kernel_size,
                Conv1dConfig {
                    padding: 0,
                    stride,
                    groups: 1,
                    dilation: 1,
                },
                vb.pp("conv"),
            )?,
            NormType::None | NormType::TimeGroupNorm => conv1d(
                in_c,
                out_c,
                kernel_size,
                Conv1dConfig {
                    padding: 0,
                    stride,
                    groups: 1,
                    dilation: 1,
                },
                vb.pp("conv"),
            )?,
        };
        let norm = match cfg.norm_type {
            NormType::None | NormType::WeightNorm => None,
            NormType::TimeGroupNorm => {
                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
                Some(gn)
            }
        };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
            norm,
        })
    }
}

impl Module for EncodecConv1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // TODO: padding, depending on causal.
        let xs = self.conv.forward(xs)?;
        match &self.norm {
            None => Ok(xs),
            Some(norm) => xs.apply(norm),
        }
    }
}

#[derive(Debug)]
struct EncodecResnetBlock {
    block_conv1: EncodecConv1d,
    block_conv2: EncodecConv1d,
    shortcut: Option<EncodecConv1d>,
}

impl EncodecResnetBlock {
    fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let h = dim / cfg.compress;
        let mut layer = Layer::new(vb.pp("block"));
        if dilations.len() != 2 {
            candle::bail!("expected dilations of size 2")
        }
        // TODO: Apply dilations!
        layer.inc();
        let block_conv1 =
            EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
        layer.inc();
        let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
        let shortcut = if cfg.use_conv_shortcut {
            let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
            Some(conv)
        } else {
            None
        };
        Ok(Self {
            block_conv1,
            block_conv2,
            shortcut,
        })
    }
}

impl Module for EncodecResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs.clone();
        let xs = xs.elu(1.)?;
        let xs = self.block_conv1.forward(&xs)?;
        let xs = xs.elu(1.)?;
        let xs = self.block_conv2.forward(&xs)?;
        let xs = match &self.shortcut {
            None => (xs + residual)?,
            Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
        };
        Ok(xs)
    }
}

// Small helper that mimics PyTorch's nn.Sequential naming: it hands out
// variable-builder prefixes "0", "1", ... and lets callers skip indices for
// parameter-free layers such as ELU.
struct Layer<'a> {
    vb: VarBuilder<'a>,
    cnt: usize,
}

impl<'a> Layer<'a> {
    fn new(vb: VarBuilder<'a>) -> Self {
        Self { vb, cnt: 0 }
    }

    fn inc(&mut self) {
        self.cnt += 1;
    }

    fn next(&mut self) -> VarBuilder {
        let vb = self.vb.pp(&self.cnt.to_string());
        self.cnt += 1;
        vb
    }
}

#[derive(Debug)]
struct EncodecEncoder {
    init_conv: EncodecConv1d,
    sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
    final_lstm: EncodecLSTM,
    final_conv: EncodecConv1d,
}

impl EncodecEncoder {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let init_conv = EncodecConv1d::load(
            cfg.audio_channels,
            cfg.num_filters,
            cfg.kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        let mut sampling_layers = vec![];
        let mut scaling = 1;
        for &ratio in cfg.upsampling_ratios.iter().rev() {
            let current_scale = scaling * cfg.num_filters;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::load(
                    current_scale,
                    &[cfg.dilation_growth_rate.pow(j), 1],
                    layer.next(),
                    cfg,
                )?;
                resnets.push(resnet)
            }
            layer.inc(); // ELU
            let conv1d = EncodecConv1d::load(
                current_scale,
                current_scale * 2,
                ratio * 2,
                ratio,
                layer.next(),
                cfg,
            )?;
            sampling_layers.push((resnets, conv1d));
            scaling *= 2;
        }
        let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::load(
            cfg.num_filters * scaling,
            cfg.hidden_size,
            cfg.last_kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        Ok(Self {
            init_conv,
            sampling_layers,
            final_conv,
            final_lstm,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?;
        for (resnets, conv) in self.sampling_layers.iter() {
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?;
            }
            xs = xs.elu(1.0)?.apply(conv)?;
        }
        xs.apply(&self.final_lstm)?
            .elu(1.0)?
            .apply(&self.final_conv)
    }
}

#[derive(Debug)]
struct EncodecDecoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}

impl EncodecDecoder {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
        let init_conv = EncodecConv1d::load(
            cfg.hidden_size,
            cfg.num_filters * scaling,
            cfg.last_kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
        let mut sampling_layers = vec![];
        for &ratio in cfg.upsampling_ratios.iter() {
            let current_scale = scaling * cfg.num_filters;
            layer.inc(); // ELU
            let conv1d = EncodecConvTranspose1d::load(
                current_scale,
                current_scale / 2,
                ratio * 2,
                ratio,
                layer.next(),
                cfg,
            )?;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::load(
                    current_scale / 2,
                    &[cfg.dilation_growth_rate.pow(j), 1],
                    layer.next(),
                    cfg,
                )?;
                resnets.push(resnet)
            }
            sampling_layers.push((conv1d, resnets));
            scaling /= 2;
        }
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::load(
            cfg.num_filters,
            cfg.audio_channels,
            cfg.last_kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        Ok(Self {
            init_conv,
            init_lstm,
            sampling_layers,
            final_conv,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
        for (conv, resnets) in self.sampling_layers.iter() {
            xs = xs.elu(1.)?.apply(conv)?;
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?
            }
        }
        xs.elu(1.)?.apply(&self.final_conv)
    }
}

#[derive(Debug)]
pub struct EncodecModel {
    encoder: EncodecEncoder,
    decoder: EncodecDecoder,
    quantizer: EncodecResidualVectorQuantizer,
}

impl EncodecModel {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
        let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
        let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
        Ok(Self {
            encoder,
            decoder,
            quantizer,
        })
    }

    pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}
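A minimal sketch of how this model could be instantiated from a safetensors checkpoint, assuming the weights use the `audio_encoder` prefix as in the musicgen checkpoints; this is illustrative, not code from the diff:

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn load_audio_encoder(weights: &std::path::Path) -> Result<EncodecModel> {
    let device = Device::Cpu;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    // The "audio_encoder" prefix matches how musicgen stores the encodec weights.
    EncodecModel::load(vb.pp("audio_encoder"), &Config::musicgen_small())
}
```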
@@ -10,7 +10,9 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

+mod encodec_model;
 mod musicgen_model;
+mod nn;

 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

@@ -1,9 +1,10 @@
+use crate::encodec_model;
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
     VarBuilder,
 };
-use candle_transformers::models::{encodec, t5};
+use candle_transformers::models::t5;

 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
     pub text_encoder: t5::T5EncoderModel,
-    pub audio_encoder: encodec::Model,
+    pub audio_encoder: crate::encodec_model::EncodecModel,
     pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
 }
@@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
 pub struct GenConfig {
     musicgen: Config,
     t5: t5::Config,
-    encodec: encodec::Config,
+    encodec: crate::encodec_model::Config,
 }

 impl GenConfig {
     pub fn small() -> Self {
-        // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
-        let encodec = encodec::Config {
-            audio_channels: 1,
-            chunk_length_s: None,
-            codebook_dim: Some(128),
-            codebook_size: 2048,
-            compress: 2,
-            dilation_growth_rate: 2,
-            hidden_size: 128,
-            kernel_size: 7,
-            last_kernel_size: 7,
-            norm_type: encodec::NormType::WeightNorm,
-            normalize: false,
-            num_filters: 64,
-            num_lstm_layers: 2,
-            num_residual_layers: 1,
-            overlap: None,
-            // This should be Reflect and not Replicate but Reflect does not work yet.
-            pad_mode: encodec::PadMode::Replicate,
-            residual_kernel_size: 3,
-            sampling_rate: 32_000,
-            target_bandwidths: vec![2.2],
-            trim_right_ratio: 1.0,
-            upsampling_ratios: vec![8, 5, 4, 4],
-            use_causal_conv: false,
-            use_conv_shortcut: false,
-        };
         Self {
             musicgen: Config::musicgen_small(),
             t5: t5::Config::musicgen_small(),
-            encodec,
+            encodec: encodec_model::Config::musicgen_small(),
         }
     }
 }
@@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {

     pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
         let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
-        let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
+        let audio_encoder =
+            encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
         let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
         Ok(Self {
             text_encoder,
20  candle-examples/examples/musicgen/nn.rs  Normal file
@@ -0,0 +1,20 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};

// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    config: Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
    let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
    // weight = g * v / ||v||, with the norm taken over each output channel.
    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
    let bias = vb.get(out_c, "bias")?;
    Ok(Conv1d::new(weight, Some(bias), config))
}
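A short usage sketch, assuming a checkpoint that stores `weight_g`, `weight_v`, and `bias` under a `conv` prefix; the channel counts and padding below are placeholders, not values from the diff:

```rust
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};

// Sketch: build a weight-normalized Conv1d, recomputing the weight once at load time.
fn load_conv(vb: VarBuilder) -> Result<Conv1d> {
    let config = Conv1dConfig {
        padding: 3,
        stride: 1,
        groups: 1,
        dilation: 1,
    };
    conv1d_weight_norm(1, 32, 7, config, vb.pp("conv"))
}
```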
@@ -1,39 +1,10 @@
 ## Using ONNX models in Candle

-This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle.
-
-It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf).
-
-You can run the examples with the following commands:
+This example demonstrates how to run ONNX based models in Candle, the model
+being used here is a small squeezenet variant.
+
+You can run the example with the following command:

 ```bash
-cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
-```
-
-Use the `--which` flag to specify explicitly which network to use, i.e.
-
-```bash
-$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
-
-    Finished release [optimized] target(s) in 0.21s
-     Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
-loaded image Tensor[dims 3, 224, 224; f32]
-unicycle, monocycle : 83.23%
-ballplayer, baseball player : 3.68%
-bearskin, busby, shako : 1.54%
-military uniform : 0.78%
-cowboy hat, ten-gallon hat : 0.76%
-```
-
-```bash
-$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
-
-    Finished release [optimized] target(s) in 0.20s
-     Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
-loaded image Tensor[dims 224, 224, 3; f32]
-bicycle-built-for-two, tandem bicycle, tandem : 99.16%
-mountain bike, all-terrain bike, off-roader : 0.60%
-unicycle, monocycle : 0.17%
-crash helmet : 0.02%
-alp : 0.02%
+cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
 ```
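For readers following along, a minimal sketch of what running an ONNX graph with the `candle-onnx` crate can look like; `read_file` and `simple_eval` are the crate's entry points, while the input name and the error handling below are assumptions that depend on the exported model:

```rust
use std::collections::HashMap;

use candle::Tensor;

// Sketch: evaluate an ONNX graph on a single image tensor.
fn run_onnx(model_path: &str, image: Tensor) -> anyhow::Result<Tensor> {
    let model = candle_onnx::read_file(model_path)?;
    let mut inputs = HashMap::new();
    // "data" is a common input name for image classifiers; check your model's graph.
    inputs.insert("data".to_string(), image.unsqueeze(0)?);
    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
    // Take the single graph output, whatever it is named.
    let (_name, logits) = outputs
        .drain()
        .next()
        .ok_or_else(|| anyhow::anyhow!("no output returned"))?;
    Ok(logits)
}
```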
@@ -212,14 +212,6 @@ struct Args {
     #[arg(long)]
     verbose_prompt: bool,

-    /// Process prompt elements separately.
-    #[arg(long)]
-    split_prompt: bool,
-
-    /// Run on CPU rather than GPU even if a GPU is available.
-    #[arg(long)]
-    cpu: bool,
-
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -369,7 +361,7 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
-    let device = candle_examples::device(args.cpu)?;
+    let device = candle_examples::device(false)?;

     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
@@ -495,20 +487,11 @@ fn main() -> anyhow::Result<()> {
     let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);

     let start_prompt_processing = std::time::Instant::now();
-    let mut next_token = if !args.split_prompt {
+    let mut next_token = {
         let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
         logits_processor.sample(&logits)?
-    } else {
-        let mut next_token = 0;
-        for (pos, token) in prompt_tokens.iter().enumerate() {
-            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
-            let logits = model.forward(&input, pos)?;
-            let logits = logits.squeeze(0)?;
-            next_token = logits_processor.sample(&logits)?
-        }
-        next_token
     };
     let prompt_dt = start_prompt_processing.elapsed();
     all_tokens.push(next_token);
@@ -1,281 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum WhichModel {
    #[value(name = "0.5b")]
    W0_5b,
    #[value(name = "1.8b")]
    W1_8b,
    #[value(name = "4b")]
    W4b,
    #[value(name = "7b")]
    W7b,
    #[value(name = "14b")]
    W14b,
    #[value(name = "72b")]
    W72b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    #[arg(long, default_value = "0.5b")]
    model: WhichModel,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => {
            let size = match args.model {
                WhichModel::W0_5b => "0.5B",
                WhichModel::W1_8b => "1.8B",
                WhichModel::W4b => "4B",
                WhichModel::W7b => "7B",
                WhichModel::W14b => "14B",
                WhichModel::W72b => "72B",
            };
            format!("Qwen/Qwen1.5-{size}")
        }
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match args.model {
            WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
            WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        },
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config_file = repo.get("config.json")?;
    let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -411,7 +411,7 @@ impl DDPG<'_> {
     pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
         let actions = self
             .actor
-            .forward(&state.detach().unsqueeze(0)?)?
+            .forward(&state.detach()?.unsqueeze(0)?)?
             .squeeze(0)?;
         let actions = if self.train {
             (actions + self.ou_noise.sample()?)?
@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
     loop {
         let action = {
             let action_probs: Vec<f32> =
-                softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
+                softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
                     .squeeze(0)?
                     .to_vec1()?;
             weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {

     let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
         .to_dtype(DType::F32)?
-        .detach();
+        .detach()?;

     let actions_mask = {
         let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
             .unwrap()
         })
         .collect();
-        Tensor::stack(&actions_mask, 0)?.detach()
+        Tensor::stack(&actions_mask, 0)?.detach()?
     };

     let states = {
         let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
-        Tensor::stack(&states, 0)?.detach()
+        Tensor::stack(&states, 0)?.detach()?
     };

     let log_probs = actions_mask
@@ -1,17 +0,0 @@
## candle-rwkv

The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available; candle implements the v5 and v6 versions, which can be used with
Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).

```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ(2) = 2.
The smallest composite is ϕ(3) = 3.
The smallest perfect number is ϕ(5) = 5.
The smallest perfect square is ϕ(4) = 4.
The smallest perfect cube is ϕ(6) = 6.
```
@@ -1,330 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};

const EOS_TOKEN_ID: u32 = 261;

enum Model {
    M5(M5),
    Q5(Q5),
    M6(M6),
    Q6(Q6),
}

impl Model {
    fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
        match self {
            Self::M5(m) => m.forward(xs, state),
            Self::Q5(m) => m.forward(xs, state),
            Self::M6(m) => m.forward(xs, state),
            Self::Q6(m) => m.forward(xs, state),
        }
    }
}

struct TextGeneration {
    model: Model,
    config: Config,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        config: Config,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            config,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        let mut tokens = self.tokenizer.encode(prompt)?;
        let mut generated_tokens = 0usize;
        let mut state = State::new(1, &self.config, &self.device)?;
        let mut next_logits = None;
        for &t in tokens.iter() {
            let input = Tensor::new(&[[t]], &self.device)?;
            let logits = self.model.forward(&input, &mut state)?;
            next_logits = Some(logits);
            print!("{}", self.tokenizer.decode(&[t])?)
        }
        std::io::stdout().flush()?;

        let start_gen = std::time::Instant::now();
        for _ in 0..sample_len {
            let logits = match next_logits.as_ref() {
                Some(logits) => logits,
                None => anyhow::bail!("cannot work on an empty prompt"),
            };
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == EOS_TOKEN_ID || next_token == 0 {
                break;
            }
            print!("{}", self.tokenizer.decode(&[next_token])?);
            std::io::stdout().flush()?;

            let input = Tensor::new(&[[next_token]], &self.device)?;
            next_logits = Some(self.model.forward(&input, &mut state)?)
        }
        let dt = start_gen.elapsed();
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
    Eagle7b,
    World1b5,
    World3b,
    World6_1b6,
}

impl std::fmt::Display for Which {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
            Self::World3b => "RWKV/rwkv-5-world-3b",
            Self::World6_1b6 => "paperfun/rwkv",
        }
    }

    fn revision(&self) -> &'static str {
        match self {
            Self::Eagle7b => "refs/pr/1",
            Self::World1b5 | Self::World3b => "refs/pr/2",
            Self::World6_1b6 => "main",
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 5000)]
    sample_len: usize,

    #[arg(long, default_value = "world1b5")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        args.model_id
            .unwrap_or_else(|| args.which.model_id().to_string()),
        RepoType::Model,
        args.revision
            .unwrap_or_else(|| args.which.revision().to_string()),
    ));
    let tokenizer = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("lmz/candle-rwkv".to_string())
            .get("rwkv_vocab_v20230424.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![match args.which {
                    Which::World1b5 => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("world1b5-q4k.gguf")?,
                    Which::World3b => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("world3b-q4k.gguf")?,
                    Which::Eagle7b => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("eagle7b-q4k.gguf")?,
                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
                }]
            } else {
                vec![match args.which {
                    Which::World1b5 | Which::World3b | Which::Eagle7b => {
                        repo.get("model.safetensors")?
                    }
                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
                }]
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::new(tokenizer)?;

    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let model = if args.quantized {
        let filename = &filenames[0];
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        match args.which {
            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
            Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
        }
    } else {
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
        match args.which {
            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
            Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
        }
    };
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        config,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -1,28 +0,0 @@
# candle-segformer

- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]

## How to run the example

If you want, you can use the example images from this [pull request][pr]: download them and supply the path to an image as an argument to the example.

```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```

Example output for classification:

```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```

[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
@ -1,752 +0,0 @@
[
  { "index": 1, "color": "#787878", "label": "wall" },
  { "index": 2, "color": "#B47878", "label": "building;edifice" },
  { "index": 3, "color": "#06E6E6", "label": "sky" },
  { "index": 4, "color": "#503232", "label": "floor;flooring" },
  { "index": 5, "color": "#04C803", "label": "tree" },
  { "index": 6, "color": "#787850", "label": "ceiling" },
  { "index": 7, "color": "#8C8C8C", "label": "road;route" },
  { "index": 8, "color": "#CC05FF", "label": "bed" },
  { "index": 9, "color": "#E6E6E6", "label": "windowpane;window" },
  { "index": 10, "color": "#04FA07", "label": "grass" },
  { "index": 11, "color": "#E005FF", "label": "cabinet" },
  { "index": 12, "color": "#EBFF07", "label": "sidewalk;pavement" },
  { "index": 13, "color": "#96053D", "label": "person;individual;someone;somebody;mortal;soul" },
  { "index": 14, "color": "#787846", "label": "earth;ground" },
  { "index": 15, "color": "#08FF33", "label": "door;double;door" },
  { "index": 16, "color": "#FF0652", "label": "table" },
  { "index": 17, "color": "#8FFF8C", "label": "mountain;mount" },
  { "index": 18, "color": "#CCFF04", "label": "plant;flora;plant;life" },
  { "index": 19, "color": "#FF3307", "label": "curtain;drape;drapery;mantle;pall" },
  { "index": 20, "color": "#CC4603", "label": "chair" },
  { "index": 21, "color": "#0066C8", "label": "car;auto;automobile;machine;motorcar" },
  { "index": 22, "color": "#3DE6FA", "label": "water" },
  { "index": 23, "color": "#FF0633", "label": "painting;picture" },
  { "index": 24, "color": "#0B66FF", "label": "sofa;couch;lounge" },
  { "index": 25, "color": "#FF0747", "label": "shelf" },
  { "index": 26, "color": "#FF09E0", "label": "house" },
  { "index": 27, "color": "#0907E6", "label": "sea" },
  { "index": 28, "color": "#DCDCDC", "label": "mirror" },
  { "index": 29, "color": "#FF095C", "label": "rug;carpet;carpeting" },
  { "index": 30, "color": "#7009FF", "label": "field" },
  { "index": 31, "color": "#08FFD6", "label": "armchair" },
  { "index": 32, "color": "#07FFE0", "label": "seat" },
  { "index": 33, "color": "#FFB806", "label": "fence;fencing" },
  { "index": 34, "color": "#0AFF47", "label": "desk" },
  { "index": 35, "color": "#FF290A", "label": "rock;stone" },
  { "index": 36, "color": "#07FFFF", "label": "wardrobe;closet;press" },
  { "index": 37, "color": "#E0FF08", "label": "lamp" },
  { "index": 38, "color": "#6608FF", "label": "bathtub;bathing;tub;bath;tub" },
  { "index": 39, "color": "#FF3D06", "label": "railing;rail" },
  { "index": 40, "color": "#FFC207", "label": "cushion" },
  { "index": 41, "color": "#FF7A08", "label": "base;pedestal;stand" },
  { "index": 42, "color": "#00FF14", "label": "box" },
  { "index": 43, "color": "#FF0829", "label": "column;pillar" },
  { "index": 44, "color": "#FF0599", "label": "signboard;sign" },
  { "index": 45, "color": "#0633FF", "label": "chest;of;drawers;chest;bureau;dresser" },
  { "index": 46, "color": "#EB0CFF", "label": "counter" },
  { "index": 47, "color": "#A09614", "label": "sand" },
  { "index": 48, "color": "#00A3FF", "label": "sink" },
  { "index": 49, "color": "#8C8C8C", "label": "skyscraper" },
  { "index": 50, "color": "#FA0A0F", "label": "fireplace;hearth;open;fireplace" },
  { "index": 51, "color": "#14FF00", "label": "refrigerator;icebox" },
  { "index": 52, "color": "#1FFF00", "label": "grandstand;covered;stand" },
  { "index": 53, "color": "#FF1F00", "label": "path" },
  { "index": 54, "color": "#FFE000", "label": "stairs;steps" },
  { "index": 55, "color": "#99FF00", "label": "runway" },
  { "index": 56, "color": "#0000FF", "label": "case;display;case;showcase;vitrine" },
  { "index": 57, "color": "#FF4700", "label": "pool;table;billiard;table;snooker;table" },
  { "index": 58, "color": "#00EBFF", "label": "pillow" },
  { "index": 59, "color": "#00ADFF", "label": "screen;door;screen" },
  { "index": 60, "color": "#1F00FF", "label": "stairway;staircase" },
  { "index": 61, "color": "#0BC8C8", "label": "river" },
  { "index": 62, "color": "#FF5200", "label": "bridge;span" },
  { "index": 63, "color": "#00FFF5", "label": "bookcase" },
  { "index": 64, "color": "#003DFF", "label": "blind;screen" },
  { "index": 65, "color": "#00FF70", "label": "coffee;table;cocktail;table" },
  { "index": 66, "color": "#00FF85", "label": "toilet;can;commode;crapper;pot;potty;stool;throne" },
  { "index": 67, "color": "#FF0000", "label": "flower" },
  { "index": 68, "color": "#FFA300", "label": "book" },
  { "index": 69, "color": "#FF6600", "label": "hill" },
  { "index": 70, "color": "#C2FF00", "label": "bench" },
  { "index": 71, "color": "#008FFF", "label": "countertop" },
  { "index": 72, "color": "#33FF00", "label": "stove;kitchen;stove;range;kitchen;range;cooking;stove" },
  { "index": 73, "color": "#0052FF", "label": "palm;palm;tree" },
  { "index": 74, "color": "#00FF29", "label": "kitchen;island" },
  { "index": 75, "color": "#00FFAD", "label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system" },
  { "index": 76, "color": "#0A00FF", "label": "swivel;chair" },
  { "index": 77, "color": "#ADFF00", "label": "boat" },
  { "index": 78, "color": "#00FF99", "label": "bar" },
  { "index": 79, "color": "#FF5C00", "label": "arcade;machine" },
  { "index": 80, "color": "#FF00FF", "label": "hovel;hut;hutch;shack;shanty" },
  { "index": 81, "color": "#FF00F5", "label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle" },
  { "index": 82, "color": "#FF0066", "label": "towel" },
  { "index": 83, "color": "#FFAD00", "label": "light;light;source" },
  { "index": 84, "color": "#FF0014", "label": "truck;motortruck" },
  { "index": 85, "color": "#FFB8B8", "label": "tower" },
  { "index": 86, "color": "#001FFF", "label": "chandelier;pendant;pendent" },
  { "index": 87, "color": "#00FF3D", "label": "awning;sunshade;sunblind" },
  { "index": 88, "color": "#0047FF", "label": "streetlight;street;lamp" },
  { "index": 89, "color": "#FF00CC", "label": "booth;cubicle;stall;kiosk" },
  { "index": 90, "color": "#00FFC2", "label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box" },
  { "index": 91, "color": "#00FF52", "label": "airplane;aeroplane;plane" },
  { "index": 92, "color": "#000AFF", "label": "dirt;track" },
  { "index": 93, "color": "#0070FF", "label": "apparel;wearing;apparel;dress;clothes" },
  { "index": 94, "color": "#3300FF", "label": "pole" },
  { "index": 95, "color": "#00C2FF", "label": "land;ground;soil" },
  { "index": 96, "color": "#007AFF", "label": "bannister;banister;balustrade;balusters;handrail" },
  { "index": 97, "color": "#00FFA3", "label": "escalator;moving;staircase;moving;stairway" },
  { "index": 98, "color": "#FF9900", "label": "ottoman;pouf;pouffe;puff;hassock" },
  { "index": 99, "color": "#00FF0A", "label": "bottle" },
  { "index": 100, "color": "#FF7000", "label": "buffet;counter;sideboard" },
  { "index": 101, "color": "#8FFF00", "label": "poster;posting;placard;notice;bill;card" },
  { "index": 102, "color": "#5200FF", "label": "stage" },
  { "index": 103, "color": "#A3FF00", "label": "van" },
  { "index": 104, "color": "#FFEB00", "label": "ship" },
  { "index": 105, "color": "#08B8AA", "label": "fountain" },
  { "index": 106, "color": "#8500FF", "label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter" },
  { "index": 107, "color": "#00FF5C", "label": "canopy" },
  { "index": 108, "color": "#B800FF", "label": "washer;automatic;washer;washing;machine" },
  { "index": 109, "color": "#FF001F", "label": "plaything;toy" },
  { "index": 110, "color": "#00B8FF", "label": "swimming;pool;swimming;bath;natatorium" },
  { "index": 111, "color": "#00D6FF", "label": "stool" },
  { "index": 112, "color": "#FF0070", "label": "barrel;cask" },
  { "index": 113, "color": "#5CFF00", "label": "basket;handbasket" },
  { "index": 114, "color": "#00E0FF", "label": "waterfall;falls" },
  { "index": 115, "color": "#70E0FF", "label": "tent;collapsible;shelter" },
  { "index": 116, "color": "#46B8A0", "label": "bag" },
  { "index": 117, "color": "#A300FF", "label": "minibike;motorbike" },
  { "index": 118, "color": "#9900FF", "label": "cradle" },
  { "index": 119, "color": "#47FF00", "label": "oven" },
  { "index": 120, "color": "#FF00A3", "label": "ball" },
  { "index": 121, "color": "#FFCC00", "label": "food;solid;food" },
  { "index": 122, "color": "#FF008F", "label": "step;stair" },
  { "index": 123, "color": "#00FFEB", "label": "tank;storage;tank" },
  { "index": 124, "color": "#85FF00", "label": "trade;name;brand;name;brand;marque" },
  { "index": 125, "color": "#FF00EB", "label": "microwave;microwave;oven" },
  { "index": 126, "color": "#F500FF", "label": "pot;flowerpot" },
  { "index": 127, "color": "#FF007A", "label": "animal;animate;being;beast;brute;creature;fauna" },
  { "index": 128, "color": "#FFF500", "label": "bicycle;bike;wheel;cycle" },
  { "index": 129, "color": "#0ABED4", "label": "lake" },
  { "index": 130, "color": "#D6FF00", "label": "dishwasher;dish;washer;dishwashing;machine" },
  { "index": 131, "color": "#00CCFF", "label": "screen;silver;screen;projection;screen" },
  { "index": 132, "color": "#1400FF", "label": "blanket;cover" },
  { "index": 133, "color": "#FFFF00", "label": "sculpture" },
  { "index": 134, "color": "#0099FF", "label": "hood;exhaust;hood" },
  { "index": 135, "color": "#0029FF", "label": "sconce" },
  { "index": 136, "color": "#00FFCC", "label": "vase" },
  { "index": 137, "color": "#2900FF", "label": "traffic;light;traffic;signal;stoplight" },
  { "index": 138, "color": "#29FF00", "label": "tray" },
  { "index": 139, "color": "#AD00FF", "label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin" },
  { "index": 140, "color": "#00F5FF", "label": "fan" },
  { "index": 141, "color": "#4700FF", "label": "pier;wharf;wharfage;dock" },
  { "index": 142, "color": "#7A00FF", "label": "crt;screen" },
  { "index": 143, "color": "#00FFB8", "label": "plate" },
  { "index": 144, "color": "#005CFF", "label": "monitor;monitoring;device" },
  { "index": 145, "color": "#B8FF00", "label": "bulletin;board;notice;board" },
  { "index": 146, "color": "#0085FF", "label": "shower" },
  { "index": 147, "color": "#FFD600", "label": "radiator" },
  { "index": 148, "color": "#19C2C2", "label": "glass;drinking;glass" },
  { "index": 149, "color": "#66FF00", "label": "clock" },
  { "index": 150, "color": "#5C00FF", "label": "flag" }
]
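Each entry above maps an ADE20K class index to a display color and a semicolon-separated list of synonyms. As a quick way to inspect the (now removed) file, assuming `jq` is available and the file is still present in an older checkout, one could look up a single class:

```bash
# print the entry for class 21 (car) from the removed label file
jq '.[] | select(.index == 21)' candle-examples/examples/segformer/assets/labels.json
```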
@ -1,155 +0,0 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
    Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;

#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
    #[arg(long, help = "use cpu")]
    cpu: bool,
    #[command(subcommand)]
    command: Commands,
}

#[derive(Args, Debug)]
struct SegmentationArgs {
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
    )]
    model_name: String,
    #[arg(
        long,
        help = "path to the label file in json format",
        default_value = "candle-examples/examples/segformer/assets/labels.json"
    )]
    label_path: PathBuf,
    #[arg(long, help = "path for the output mask image")]
    output_path: PathBuf,
    #[arg(help = "path to image as input")]
    image: PathBuf,
}

#[derive(Args, Debug)]
struct ClassificationArgs {
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "paolinox/segformer-finetuned-food101"
    )]
    model_name: String,
    #[arg(help = "path to image as input")]
    image: PathBuf,
}

#[derive(Subcommand, Debug)]
enum Commands {
    Segment(SegmentationArgs),
    Classify(ClassificationArgs),
}

fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
    println!("loading model {} via huggingface hub", model_name);
    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model(model_name.clone());
    let model_file = api.get("model.safetensors")?;
    println!("model {} downloaded and loaded", model_name);
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
    let config = std::fs::read_to_string(api.get("config.json")?)?;
    let config: Config = serde_json::from_str(&config)?;
    println!("{:?}", config);
    Ok((vb, config))
}

#[derive(Debug, serde::Deserialize)]
struct LabelItem {
    index: u32,
    color: String,
}

fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
    let label_file = std::fs::read_to_string(&args.label_path)?;
    let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
    let label_colors: HashMap<u32, Rgb<u8>> = label_items
        .iter()
        .map(|x| {
            (x.index - 1, {
                let color = x.color.trim_start_matches('#');
                let r = u8::from_str_radix(&color[0..2], 16).unwrap();
                let g = u8::from_str_radix(&color[2..4], 16).unwrap();
                let b = u8::from_str_radix(&color[4..6], 16).unwrap();
                Rgb([r, g, b])
            })
        })
        .collect();

    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
    let num_labels = label_items.len();

    let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
    let segmentations = model.forward(&image)?;

    // generate a mask image
    let mask = &segmentations.squeeze(0)?.argmax(0)?;
    let (h, w) = mask.dims2()?;
    let mask = mask.flatten_all()?.to_vec1::<u32>()?;
    let mask = mask
        .iter()
        .flat_map(|x| label_colors[x].data())
        .collect::<Vec<u8>>();
    let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
    // resize
    let mask = image::DynamicImage::from(mask);
    let mask = mask.resize_to_fill(
        w as u32 * 4,
        h as u32 * 4,
        image::imageops::FilterType::CatmullRom,
    );
    mask.save(args.output_path.clone())?;
    println!("mask image saved to {:?}", args.output_path);
    Ok(())
}

fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
    let num_labels = 7;
    let model = ImageClassificationModel::new(&config, num_labels, vb)?;
    let classification = model.forward(&image)?;
    let classification = candle_nn::ops::softmax_last_dim(&classification)?;
    let classification = classification.squeeze(0)?;
    println!(
        "classification logits {:?}",
        classification.to_vec1::<f32>()?
    );
    let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
    let label_id = format!("{}", label_id);
    println!("label: {}", config.id2label[&label_id]);
    Ok(())
}

pub fn main() -> anyhow::Result<()> {
    let args = CliArgs::parse();
    let device = candle_examples::device(args.cpu)?;
    if let Commands::Segment(args) = args.command {
        segmentation_task(args, &device)?
    } else if let Commands::Classify(args) = args.command {
        classification_task(args, &device)?
    }
    Ok(())
}
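For reference, a hypothetical invocation of this (now removed) example, assuming it was registered with cargo as `--example segformer`; the input image path is a placeholder:

```bash
cargo run --example segformer --release -- \
  segment --output-path mask.png path/to/image.jpg
```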
@ -57,7 +57,7 @@ The downside is some long compilation time. You can set the
`/home/user/.candle` to ensure that the compilation artifacts are properly
cached.

Enabling flash-attention requires both a feature flag, `--features flash-attn`
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`.

Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs
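Concretely, the two flags are combined in a single invocation; a minimal sketch, assuming the llama example (any example exposing `--use-flash-attn` works the same way):

```bash
# build the flash-attn CUDA kernels (slow the first time), then enable them at runtime
cargo run --example llama --release --features flash-attn -- --use-flash-attn
```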
@ -8,13 +8,6 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated, so you will have to request access on the Hub in
order to be able to use it.

Other available models are Stable-Code-3B, StableLM-2, and Zephyr variants.

StableLM-2 uses a Tiktoken-based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.
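A sketch of that workaround, assuming the example is invoked as `--example stable-lm` and that `--which v2` selects a StableLM-2 variant:

```bash
# download the (somewhat compatible) GPT-4 tokenizer and point the example at it
wget -O tokenizer.json "https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true"
cargo run --example stable-lm --release -- \
  --which v2 --tokenizer-file tokenizer.json --prompt "What is the capital of France?"
```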

## Running some examples

```bash
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use clap::Parser;

use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@ -122,16 +122,6 @@ impl TextGeneration {
    }
}

#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
    V1Orig,
    V1,
    V1Zephyr,
    V2,
    V2Zephyr,
    Code,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -162,18 +152,15 @@ struct Args {
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 1000)]
    #[arg(long, short = 'n', default_value_t = 100)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,
    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
    model_id: String,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long, default_value = "v2")]
    which: Which,

    #[arg(long)]
    tokenizer_file: Option<String>,

@ -220,80 +207,33 @@ fn main() -> Result<()> {

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => match args.which {
            Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
            Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
            Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
            Which::Code => "stabilityai/stable-code-3b".to_string(),
            Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
            Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
        },
    };

    let repo = api.repo(Repo::with_revision(
        model_id,
        args.model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => match args.which {
            Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
                repo.get("tokenizer.json")?
            }
            Which::V2 | Which::V2Zephyr => api
                .model("lmz/candle-stablelm".to_string())
                .get("tokenizer-gpt4.json")?,
        },
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => match (args.which, args.quantized) {
            (Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
            (Which::V2, true) => {
                let gguf = api
                    .model("lmz/candle-stablelm".to_string())
                    .get("stablelm-2-1_6b-q4k.gguf")?;
                vec![gguf]
            }
            (Which::V2Zephyr, true) => {
                let gguf = api
                    .model("lmz/candle-stablelm".to_string())
                    .get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
                vec![gguf]
            }
            (Which::V1Zephyr | Which::Code, true) => {
                anyhow::bail!("Quantized {:?} variant not supported.", args.which)
            }
            (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
        None => {
            if args.quantized {
                vec![repo.get("model-q4k.gguf")?]
            } else {
                vec![repo.get("model.safetensors")?]
            }
            (Which::Code, false) => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        },
        }
    };

    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = match args.which {
        Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
        Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
            let config_filename = repo.get("config.json")?;
            let config = std::fs::read_to_string(config_filename)?;
            let mut config: Config = serde_json::from_str(&config)?;
            config.set_use_flash_attn(args.use_flash_attn);
            config
        }
    };

    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
    let device = candle_examples::device(args.cpu)?;
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
@ -1,253 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::starcoder2::Model;

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|endoftext|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => "bigcode/starcoder2-3b".to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let config_file = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let tokenizer_file = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => vec![repo.get("model.safetensors")?],
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
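For reference, a hypothetical run of this (now removed) example, assuming it was registered with cargo as `--example starcoder2`:

```bash
cargo run --example starcoder2 --release -- \
  --prompt "fn fibonacci(n: u64) -> u64 {" --sample-len 100
```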
Binary file not shown.
Before: image, 7.5 KiB
@ -10,36 +10,15 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::{trocr, vit};
use candle_transformers::models::trocr;

use tokenizers::Tokenizer;
mod image_processor;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    #[value(name = "base")]
    BaseHandwritten,
    #[value(name = "large")]
    LargeHandwritten,
    BasePrinted,
    LargePrinted,
}

impl Which {
    fn repo_and_branch_name(&self) -> (&str, &str) {
        match self {
            Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
            Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
            Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
            Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
        }
    }
}

#[derive(Debug, Clone, serde::Deserialize)]
struct Config {
    encoder: vit::Config,
    decoder: trocr::TrOCRConfig,
    Base,
    Large,
}

#[derive(Parser, Debug)]
@ -55,64 +34,63 @@ struct Args {
    #[arg(long)]
    cpu: bool,

    /// The image file to be processed.
    /// Text to be translated
    #[arg(long)]
    image: String,

    /// Tokenization config.
    #[arg(long)]
    tokenizer: Option<String>,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();
    let api = hf_hub::api::sync::Api::new()?;

    let mut tokenizer_dec = {
        let tokenizer_file = match args.tokenizer {
            None => api
                .model(String::from("ToluClassics/candle-trocr-tokenizer"))
                .get("tokenizer.json")?,
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
        };
        let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
        TokenOutputStream::new(tokenizer)
    let tokenizer_dec = {
        let tokenizer = Api::new()?
            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
            .get("tokenizer.json")?;

        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;

    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => {
                let (repo, branch) = args.which.repo_and_branch_name();
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    branch.to_string(),
                ))
                .get("model.safetensors")?
            }
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-base-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/3".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Large => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "microsoft/trocr-large-handwritten".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/6".to_string(),
                    ))
                    .get("model.safetensors")?,
            },
        };
        println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
    };

    let (encoder_config, decoder_config) = {
        let (repo, branch) = args.which.repo_and_branch_name();
        let config_filename = api
            .repo(hf_hub::Repo::with_revision(
                repo.to_string(),
                hf_hub::RepoType::Model,
                branch.to_string(),
            ))
            .get("config.json")?;
        let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
        (config.encoder, config.decoder)
    let encoder_config = match args.which {
        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
        Which::Large => {
            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
        }
    };

    let decoder_config = trocr::TrOCRConfig::default();
    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

    let processor_config = image_processor::ProcessorConfig::default();
    let processor = image_processor::ViTImageProcessor::new(&processor_config);
    let config = image_processor::ProcessorConfig::default();
    let processor = image_processor::ViTImageProcessor::new(&config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;
@ -5,27 +5,12 @@ transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.

Supported models include:

- `--which base`: small handwritten OCR model.
- `--which large`: large handwritten OCR model.
- `--which base-printed`: small printed OCR model.
- `--which large-printed`: large printed OCR model.

## Running an example

```bash
cargo run --example trocr --release -- --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which large --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which base-printed --image candle-examples/examples/trocr/assets/noto.png
cargo run --example trocr --release -- --which large-printed --image candle-examples/examples/trocr/assets/noto.png
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```

### Outputs

```
industry , Mr. Brown commented icily . " Let us have a
industry , " Mr. Brown commented icily . " Let us have a
THE QUICK BROWN FOR JUMPS OVER THE LAY DOG
THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```
@ -1,673 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;

mod multilingual;

use candle_transformers::models::whisper::{self as m, audio, Config};

use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};

pub enum Model {
    Normal(m::model::Whisper),
    Quantized(m::quantized_model::Whisper),
}

// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
    pub fn config(&self) -> &Config {
        match self {
            Self::Normal(m) => &m.config,
            Self::Quantized(m) => &m.config,
        }
    }

    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
        match self {
            Self::Normal(m) => m.encoder.forward(x, flush),
            Self::Quantized(m) => m.encoder.forward(x, flush),
        }
    }

    pub fn decoder_forward(
        &mut self,
        x: &Tensor,
        xa: &Tensor,
        flush: bool,
    ) -> candle::Result<Tensor> {
        match self {
            Self::Normal(m) => m.decoder.forward(x, xa, flush),
            Self::Quantized(m) => m.decoder.forward(x, xa, flush),
        }
    }

    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
        match self {
            Self::Normal(m) => m.decoder.final_linear(x),
            Self::Quantized(m) => m.decoder.final_linear(x),
        }
    }
}

#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
    tokens: Vec<u32>,
    text: String,
    avg_logprob: f64,
    no_speech_prob: f64,
    temperature: f64,
    compression_ratio: f64,
}

#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
    start: f64,
    duration: f64,
    dr: DecodingResult,
}

struct Decoder {
    model: Model,
    rng: rand::rngs::StdRng,
    task: Option<Task>,
    timestamps: bool,
    verbose: bool,
    tokenizer: Tokenizer,
    suppress_tokens: Tensor,
    sot_token: u32,
    transcribe_token: u32,
    translate_token: u32,
    eot_token: u32,
    no_speech_token: u32,
    no_timestamps_token: u32,
    language_token: Option<u32>,
}

impl Decoder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        device: &Device,
        language_token: Option<u32>,
        task: Option<Task>,
        timestamps: bool,
        verbose: bool,
    ) -> Result<Self> {
        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
        // Suppress the notimestamps token when in timestamps mode.
        // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
            .map(|i| {
                if model.config().suppress_tokens.contains(&i)
                    || timestamps && i == no_timestamps_token
                {
                    f32::NEG_INFINITY
                } else {
                    0f32
                }
            })
            .collect();
        let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
        let no_speech_token = m::NO_SPEECH_TOKENS
            .iter()
            .find_map(|token| token_id(&tokenizer, token).ok());
        let no_speech_token = match no_speech_token {
            None => anyhow::bail!("unable to find any non-speech token"),
            Some(n) => n,
        };
        Ok(Self {
            model,
            rng: rand::rngs::StdRng::seed_from_u64(seed),
            tokenizer,
            task,
            timestamps,
            verbose,
            suppress_tokens,
            sot_token,
            transcribe_token,
            translate_token,
            eot_token,
            no_speech_token,
            language_token,
            no_timestamps_token,
        })
    }

    fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
        let model = &mut self.model;
        let audio_features = model.encoder_forward(mel, true)?;
        if self.verbose {
            println!("audio features: {:?}", audio_features.dims());
        }
        let sample_len = model.config().max_target_positions / 2;
        let mut sum_logprob = 0f64;
        let mut no_speech_prob = f64::NAN;
        let mut tokens = vec![self.sot_token];
        if let Some(language_token) = self.language_token {
            tokens.push(language_token);
        }
        match self.task {
            None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
            Some(Task::Translate) => tokens.push(self.translate_token),
        }
        if !self.timestamps {
            tokens.push(self.no_timestamps_token);
        }
        for i in 0..sample_len {
            let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;

            // The model expects a batch dim but this inference loop does not handle
            // it so we add it at this point.
            let tokens_t = tokens_t.unsqueeze(0)?;
            let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;

            // Extract the no speech probability on the first iteration by looking at the first
            // token logits and the probability for the according token.
            if i == 0 {
                let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
                no_speech_prob = softmax(&logits, 0)?
                    .i(self.no_speech_token as usize)?
                    .to_scalar::<f32>()? as f64;
            }

            let (_, seq_len, _) = ys.dims3()?;
            let logits = model
                .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
                .i(0)?
                .i(0)?;
            // TODO: Besides suppress tokens, we should apply the heuristics from
            // ApplyTimestampRules, i.e.:
            // - Timestamps come in pairs, except before EOT.
            // - Timestamps should be non-decreasing.
            // - If the sum of the probabilities of timestamps is higher than any other tokens,
            //   only consider timestamps when sampling.
            // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
            let logits = logits.broadcast_add(&self.suppress_tokens)?;
            let next_token = if t > 0f64 {
                let prs = softmax(&(&logits / t)?, 0)?;
                let logits_v: Vec<f32> = prs.to_vec1()?;
                let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
                distr.sample(&mut self.rng) as u32
            } else {
                let logits_v: Vec<f32> = logits.to_vec1()?;
                logits_v
                    .iter()
                    .enumerate()
                    .max_by(|(_, u), (_, v)| u.total_cmp(v))
                    .map(|(i, _)| i as u32)
                    .unwrap()
            };
            tokens.push(next_token);
            let prob = softmax(&logits, candle::D::Minus1)?
                .i(next_token as usize)?
                .to_scalar::<f32>()? as f64;
            if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
                break;
            }
            sum_logprob += prob.ln();
        }
        let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
        let avg_logprob = sum_logprob / tokens.len() as f64;

        Ok(DecodingResult {
            tokens,
            text,
            avg_logprob,
            no_speech_prob,
            temperature: t,
            compression_ratio: f64::NAN,
        })
    }

    fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
            let dr: Result<DecodingResult> = self.decode(segment, t);
            if i == m::TEMPERATURES.len() - 1 {
                return dr;
            }
            // On errors, we try again with a different temperature.
            match dr {
                Ok(dr) => {
                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                        return Ok(dr);
                    }
                }
                Err(err) => {
                    println!("Error running at {t}: {err}")
                }
            }
        }
        unreachable!()
    }

    fn run(&mut self, mel: &Tensor, times: Option<(f64, f64)>) -> Result<Vec<Segment>> {
        let (_, _, content_frames) = mel.dims3()?;
        let mut seek = 0;
        let mut segments = vec![];
        while seek < content_frames {
            let start = std::time::Instant::now();
            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
            let mel_segment = mel.narrow(2, seek, segment_size)?;
            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
            let dr = self.decode_with_fallback(&mel_segment)?;
            seek += segment_size;
            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                println!("no speech detected, skipping {seek} {dr:?}");
                continue;
            }
            let segment = Segment {
                start: time_offset,
                duration: segment_duration,
                dr,
            };
            if self.timestamps {
                println!(
                    "{:.1}s -- {:.1}s",
                    segment.start,
                    segment.start + segment.duration,
                );
                let mut tokens_to_decode = vec![];
                let mut prev_timestamp_s = 0f32;
                for &token in segment.dr.tokens.iter() {
                    if token == self.sot_token || token == self.eot_token {
                        continue;
                    }
                    // The no_timestamp_token is the last before the timestamp ones.
                    if token > self.no_timestamps_token {
                        let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
                        if !tokens_to_decode.is_empty() {
                            let text = self
                                .tokenizer
                                .decode(&tokens_to_decode, true)
                                .map_err(E::msg)?;
                            println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
                            tokens_to_decode.clear()
                        }
                        prev_timestamp_s = timestamp_s;
                    } else {
                        tokens_to_decode.push(token)
                    }
                }
                if !tokens_to_decode.is_empty() {
                    let text = self
                        .tokenizer
                        .decode(&tokens_to_decode, true)
                        .map_err(E::msg)?;
                    if !text.is_empty() {
                        println!(" {:.1}s-...: {}", prev_timestamp_s, text);
                    }
                    tokens_to_decode.clear()
                }
            } else {
                match times {
                    Some((start, end)) => {
                        println!("{:.1}s -- {:.1}s: {}", start, end, segment.dr.text)
                    }
                    None => {
                        println!(
                            "{:.1}s -- {:.1}s: {}",
                            segment.start,
                            segment.start + segment.duration,
                            segment.dr.text,
                        )
                    }
                }
            }
            if self.verbose {
                println!("{seek}: {segment:?}, in {:?}", start.elapsed());
            }
            segments.push(segment)
        }
        Ok(segments)
    }

    fn set_language_token(&mut self, language_token: Option<u32>) {
        self.language_token = language_token;
    }

    #[allow(dead_code)]
    fn reset_kv_cache(&mut self) {
        match &mut self.model {
            Model::Normal(m) => m.reset_kv_cache(),
            Model::Quantized(m) => m.reset_kv_cache(),
        }
    }

    fn model(&mut self) -> &mut Model {
        &mut self.model
    }
}

pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
    match tokenizer.token_to_id(token) {
        None => candle::bail!("no token-id for {token}"),
        Some(id) => Ok(id),
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Task {
    Transcribe,
    Translate,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum WhichModel {
    Tiny,
    #[value(name = "tiny.en")]
    TinyEn,
    Base,
    #[value(name = "base.en")]
    BaseEn,
    Small,
    #[value(name = "small.en")]
    SmallEn,
    Medium,
    #[value(name = "medium.en")]
    MediumEn,
    Large,
    LargeV2,
    LargeV3,
    #[value(name = "distil-medium.en")]
    DistilMediumEn,
    #[value(name = "distil-large-v2")]
    DistilLargeV2,
}

impl WhichModel {
    fn is_multilingual(&self) -> bool {
        match self {
            Self::Tiny
            | Self::Base
            | Self::Small
            | Self::Medium
            | Self::Large
            | Self::LargeV2
            | Self::LargeV3
            | Self::DistilLargeV2 => true,
            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
                false
            }
        }
    }

    fn model_and_revision(&self) -> (&'static str, &'static str) {
        match self {
            Self::Tiny => ("openai/whisper-tiny", "main"),
            Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
            Self::Base => ("openai/whisper-base", "refs/pr/22"),
            Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
            Self::Small => ("openai/whisper-small", "main"),
            Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
            Self::Medium => ("openai/whisper-medium", "main"),
            Self::MediumEn => ("openai/whisper-medium.en", "main"),
            Self::Large => ("openai/whisper-large", "refs/pr/36"),
            Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
            Self::LargeV3 => ("openai/whisper-large-v3", "main"),
            Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
            Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    model_id: Option<String>,

    /// The model to use, check out available models:
    /// https://huggingface.co/models?search=whisper
    #[arg(long)]
    revision: Option<String>,

    /// The model to be used, can be tiny, small, medium.
    #[arg(long, default_value = "tiny.en")]
    model: WhichModel,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    quantized: bool,

    /// Language.
    #[arg(long)]
    language: Option<String>,

    /// Task, when no task is specified, the input tokens contain only the sot token which can
    /// improve things when in no-timestamp mode.
    #[arg(long)]
    task: Option<Task>,

    /// Timestamps mode, this is not fully implemented yet.
    #[arg(long)]
    timestamps: bool,

    /// Print the full DecodingResult structure rather than just the text.
    #[arg(long)]
    verbose: bool,
}

pub fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let device = candle_examples::device(args.cpu)?;
    let (default_model, default_revision) = if args.quantized {
        ("lmz/candle-whisper", "main")
    } else {
        args.model.model_and_revision()
    };
    let default_model = default_model.to_string();
    let default_revision = default_revision.to_string();
    let (model_id, revision) = match (args.model_id, args.revision) {
        (Some(model_id), Some(revision)) => (model_id, revision),
        (Some(model_id), None) => (model_id, "main".to_string()),
        (None, Some(revision)) => (default_model, revision),
        (None, None) => (default_model, default_revision),
    };

    let (config_filename, tokenizer_filename, weights_filename) = {
        let api = Api::new()?;
        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
        let (config, tokenizer, model) = if args.quantized {
            let ext = match args.model {
                WhichModel::TinyEn => "tiny-en",
                WhichModel::Tiny => "tiny",
                _ => unimplemented!("no quantized support for {:?}", args.model),
            };
            (
                repo.get(&format!("config-{ext}.json"))?,
                repo.get(&format!("tokenizer-{ext}.json"))?,
                repo.get(&format!("model-{ext}-q80.gguf"))?,
            )
        } else {
            let config = repo.get("config.json")?;
            let tokenizer = repo.get("tokenizer.json")?;
            let model = repo.get("model.safetensors")?;
            (config, tokenizer, model)
        };
        (config, tokenizer, model)
    };
    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let model = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            &weights_filename,
            &device,
        )?;
        Model::Quantized(m::quantized_model::Whisper::load(&vb, config.clone())?)
    } else {
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
        Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
    };
    let language_token = None;
    let mut dc = Decoder::new(
        model,
        tokenizer.clone(),
        args.seed,
        &device,
        language_token,
        args.task,
        args.timestamps,
        args.verbose,
    )?;

    let mel_bytes = match config.num_mel_bins {
        80 => include_bytes!("../whisper/melfilters.bytes").as_slice(),
        128 => include_bytes!("../whisper/melfilters128.bytes").as_slice(),
        nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
    };
    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

    // Set up the input device and stream with the default input config.
    let host = cpal::default_host();
    let _device = "default";
    let _device = if _device == "default" {
        host.default_input_device()
    } else {
        host.input_devices()?
            .find(|x| x.name().map(|y| y == _device).unwrap_or(false))
    }
    .expect("failed to find input device");

    let _config = _device
        .default_input_config()
        .expect("Failed to get default input config");

    let channel_count = _config.channels() as usize;

    let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
    let audio_ring_buffer_2 = audio_ring_buffer.clone();

    std::thread::spawn(move || loop {
        let data = record_audio(&_device, &_config, 300).unwrap();
        audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
        let max_len = data.len() * 16;
        let data_len = data.len();
        let len = audio_ring_buffer.lock().unwrap().len();
        if len > max_len {
            let mut data = audio_ring_buffer.lock().unwrap();
            let new_data = data[data_len..].to_vec();
            *data = new_data;
        }
    });

    // loop to process the audio data forever (until the user stops the program)
    println!("Transcribing audio...");
    for (i, _) in iter::repeat(()).enumerate() {
        std::thread::sleep(std::time::Duration::from_millis(1000));
        let data = audio_ring_buffer_2.lock().unwrap().clone();
        let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
            .iter()
            .map(|v| *v as f32 / 32768.)
            .collect();
        let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
        let mel_len = mel.len();
        let mel = Tensor::from_vec(
            mel,
            (1, config.num_mel_bins, mel_len / config.num_mel_bins),
            &device,
        )?;

        // on the first iteration, we detect the language and set the language token.
        if i == 0 {
            let language_token = match (args.model.is_multilingual(), args.language.clone()) {
                (true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
                (false, None) => None,
                (true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
                    Ok(token_id) => Some(token_id),
                    Err(_) => anyhow::bail!("language {language} is not supported"),
                },
                (false, Some(_)) => {
                    anyhow::bail!("a language cannot be set for non-multilingual models")
                }
            };
            println!("language_token: {:?}", language_token);
            dc.set_language_token(language_token);
        }
        dc.run(
            &mel,
            Some((
                i as f64,
                i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
            )),
        )?;
        dc.reset_kv_cache();
    }

    Ok(())
}

fn record_audio(
    device: &cpal::Device,
    config: &cpal::SupportedStreamConfig,
    milliseconds: u64,
) -> Result<Vec<i16>> {
    let writer = Arc::new(Mutex::new(Vec::new()));
    let writer_2 = writer.clone();
    let stream = device.build_input_stream(
        &config.config(),
        move |data: &[f32], _: &cpal::InputCallbackInfo| {
            let processed = data
                .iter()
                .map(|v| (v * 32768.0) as i16)
                .collect::<Vec<i16>>();
            writer_2.lock().unwrap().extend_from_slice(&processed);
        },
        move |err| {
            eprintln!("an error occurred on stream: {}", err);
        },
        None,
    )?;
    stream.play()?;
    std::thread::sleep(std::time::Duration::from_millis(milliseconds));
    drop(stream);
    let data = writer.lock().unwrap().clone();
    let step = 3;
|
||||
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
|
||||
Ok(data)
|
||||
}
|
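// Illustrative sketch, not part of the diff above: the spawned thread keeps at
// most 16 chunks of ~300ms in the ring buffer, i.e. roughly 4.8s of audio.
// The same trimming policy on a plain Vec<i16>:
fn trim_ring_buffer(buffer: &mut Vec<i16>, chunk_len: usize, max_chunks: usize) {
    if buffer.len() > chunk_len * max_chunks {
        // Drop the oldest chunk so the buffer length stays bounded.
        let trimmed = buffer[chunk_len..].to_vec();
        *buffer = trimmed;
    }
}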
@ -1,137 +0,0 @@
use crate::{token_id, Model};
use candle::{IndexOp, Result, Tensor, D};
use candle_transformers::models::whisper::{self as m};
use tokenizers::Tokenizer;

const LANGUAGES: [(&str, &str); 99] = [
    ("en", "english"),
    ("zh", "chinese"),
    ("de", "german"),
    ("es", "spanish"),
    ("ru", "russian"),
    ("ko", "korean"),
    ("fr", "french"),
    ("ja", "japanese"),
    ("pt", "portuguese"),
    ("tr", "turkish"),
    ("pl", "polish"),
    ("ca", "catalan"),
    ("nl", "dutch"),
    ("ar", "arabic"),
    ("sv", "swedish"),
    ("it", "italian"),
    ("id", "indonesian"),
    ("hi", "hindi"),
    ("fi", "finnish"),
    ("vi", "vietnamese"),
    ("he", "hebrew"),
    ("uk", "ukrainian"),
    ("el", "greek"),
    ("ms", "malay"),
    ("cs", "czech"),
    ("ro", "romanian"),
    ("da", "danish"),
    ("hu", "hungarian"),
    ("ta", "tamil"),
    ("no", "norwegian"),
    ("th", "thai"),
    ("ur", "urdu"),
    ("hr", "croatian"),
    ("bg", "bulgarian"),
    ("lt", "lithuanian"),
    ("la", "latin"),
    ("mi", "maori"),
    ("ml", "malayalam"),
    ("cy", "welsh"),
    ("sk", "slovak"),
    ("te", "telugu"),
    ("fa", "persian"),
    ("lv", "latvian"),
    ("bn", "bengali"),
    ("sr", "serbian"),
    ("az", "azerbaijani"),
    ("sl", "slovenian"),
    ("kn", "kannada"),
    ("et", "estonian"),
    ("mk", "macedonian"),
    ("br", "breton"),
    ("eu", "basque"),
    ("is", "icelandic"),
    ("hy", "armenian"),
    ("ne", "nepali"),
    ("mn", "mongolian"),
    ("bs", "bosnian"),
    ("kk", "kazakh"),
    ("sq", "albanian"),
    ("sw", "swahili"),
    ("gl", "galician"),
    ("mr", "marathi"),
    ("pa", "punjabi"),
    ("si", "sinhala"),
    ("km", "khmer"),
    ("sn", "shona"),
    ("yo", "yoruba"),
    ("so", "somali"),
    ("af", "afrikaans"),
    ("oc", "occitan"),
    ("ka", "georgian"),
    ("be", "belarusian"),
    ("tg", "tajik"),
    ("sd", "sindhi"),
    ("gu", "gujarati"),
    ("am", "amharic"),
    ("yi", "yiddish"),
    ("lo", "lao"),
    ("uz", "uzbek"),
    ("fo", "faroese"),
    ("ht", "haitian creole"),
    ("ps", "pashto"),
    ("tk", "turkmen"),
    ("nn", "nynorsk"),
    ("mt", "maltese"),
    ("sa", "sanskrit"),
    ("lb", "luxembourgish"),
    ("my", "myanmar"),
    ("bo", "tibetan"),
    ("tl", "tagalog"),
    ("mg", "malagasy"),
    ("as", "assamese"),
    ("tt", "tatar"),
    ("haw", "hawaiian"),
    ("ln", "lingala"),
    ("ha", "hausa"),
    ("ba", "bashkir"),
    ("jw", "javanese"),
    ("su", "sundanese"),
];

/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
    let (_bsize, _, seq_len) = mel.dims3()?;
    let mel = mel.narrow(
        2,
        0,
        usize::min(seq_len, model.config().max_source_positions),
    )?;
    let device = mel.device();
    let language_token_ids = LANGUAGES
        .iter()
        .map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
        .collect::<Result<Vec<_>>>()?;
    let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
    let audio_features = model.encoder_forward(&mel, true)?;
    let tokens = Tensor::new(&[[sot_token]], device)?;
    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
    let ys = model.decoder_forward(&tokens, &audio_features, true)?;
    let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
    let logits = logits.index_select(&language_token_ids, 0)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    let probs = probs.to_vec1::<f32>()?;
    let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
    probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for ((_, language), p) in probs.iter().take(5) {
        println!("{language}: {p}")
    }
    let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
    Ok(language)
}
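// Illustrative sketch, not part of the diff above: `detect_language` boils
// down to a softmax over the logits of the 99 language tokens only. The same
// selection/normalization on a plain f32 slice:
fn softmax_probs(logits: &[f32]) -> Vec<f32> {
    // Subtract the max for numerical stability before exponentiating.
    let max = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}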
@ -18,8 +18,6 @@ use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod multilingual;
mod pcm_decode;

use candle_transformers::models::whisper::{self as m, audio, Config};

pub enum Model {
@ -537,10 +535,17 @@ fn main() -> Result<()> {
    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

    let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
    if sample_rate != m::SAMPLE_RATE as u32 {
        anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
    let mut input = std::fs::File::open(input)?;
    let (header, data) = wav::read(&mut input)?;
    println!("loaded wav data: {header:?}");
    if header.sampling_rate != m::SAMPLE_RATE as u32 {
        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
    }
    let data = data.as_sixteen().expect("expected 16 bit wav file");
    let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
        .iter()
        .map(|v| *v as f32 / 32768.)
        .collect();
    println!("pcm data loaded {}", pcm_data.len());
    let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
    let mel_len = mel.len();
@ -1,74 +0,0 @@
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    // Open the media source.
    let src = std::fs::File::open(path)?;

    // Create the media source stream.
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());

    // Create a probe hint using the file's extension. [Optional]
    let hint = symphonia::core::probe::Hint::new();

    // Use the default options for metadata and format readers.
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();

    // Probe the media source.
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    // Get the instantiated format reader.
    let mut format = probed.format;

    // Find the first audio track with a known (decodeable) codec.
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .expect("no supported audio tracks");

    // Use the default options for the decoder.
    let dec_opts: DecoderOptions = Default::default();

    // Create a decoder for the track.
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &dec_opts)
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    // The decode loop.
    while let Ok(packet) = format.next_packet() {
        // Consume any new metadata that has been read since the last packet.
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }

        // If the packet does not belong to the selected track, skip over it.
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}
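// Illustrative sketch, not part of the diff above: the `conv` helper maps any
// integer sample type to f32 via symphonia's FromSample. For i16 input this is
// equivalent to the manual scaling used elsewhere in the example:
fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
    samples.iter().map(|&s| s as f32 / 32768.0).collect()
}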
@ -104,7 +104,6 @@ impl TextGeneration {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                let t = t.replace("<|im_end|>", "\n");
                print!("{t}");
                std::io::stdout().flush()?;
            }
@ -216,7 +216,7 @@ fn detect(
    xs: &Tensor,
    image_height: usize,
    classes: usize,
    anchors: &[(usize, usize)],
    anchors: &Vec<(usize, usize)>,
) -> Result<Tensor> {
    let (bsize, _channels, height, _width) = xs.dims4()?;
    let stride = image_height / height;
@ -1,29 +0,0 @@
use candle::{Result, Tensor};

// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
pub fn normalize_loudness(
    wav: &Tensor,
    sample_rate: u32,
    loudness_compressor: bool,
) -> Result<Tensor> {
    let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
    if energy < 2e-3 {
        return Ok(wav.clone());
    }
    let wav_array = wav.to_vec1::<f32>()?;
    let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
    meter.push(wav_array.into_iter());
    let power = meter.as_100ms_windows();
    let loudness = match crate::bs1770::gated_mean(power) {
        None => return Ok(wav.clone()),
        Some(gp) => gp.loudness_lkfs() as f64,
    };
    let delta_loudness = -14. - loudness;
    let gain = 10f64.powf(delta_loudness / 20.);
    let wav = (wav * gain)?;
    if loudness_compressor {
        wav.tanh()
    } else {
        Ok(wav)
    }
}
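// Illustrative sketch, not part of the diff above: the gain applied by
// `normalize_loudness` follows directly from the loudness delta in dB.
fn loudness_gain(measured_lufs: f64, target_lufs: f64) -> f64 {
    10f64.powf((target_lufs - measured_lufs) / 20.)
}
// e.g. moving a -20 LUFS signal to the -14 LUFS target needs a gain of
// 10^(6/20), roughly 2x in amplitude.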
@ -1,506 +0,0 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.

//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! #     [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//!     let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//!     meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//!     meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//!     channel_power[0].as_ref(),
//!     channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//!     stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```

use std::f32;

/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
    a1: f32,
    a2: f32,
    b0: f32,
    b1: f32,
    b2: f32,

    // The past two input and output samples.
    x1: f32,
    x2: f32,
    y1: f32,
    y2: f32,
}

impl Filter {
    /// Stage 1 of the BS.1770-4 pre-filter.
    pub fn high_shelf(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let gain_db = 3.999_843_8;
        let q = 0.707_175_25;
        let center_hz = 1_681.974_5;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        let vh = 10.0_f32.powf(gain_db / 20.0);
        let vb = vh.powf(0.499_666_78);
        let a0 = 1.0 + k / q + k * k;
        Filter {
            b0: (vh + vb * k / q + k * k) / a0,
            b1: 2.0 * (k * k - vh) / a0,
            b2: (vh - vb * k / q + k * k) / a0,
            a1: 2.0 * (k * k - 1.0) / a0,
            a2: (1.0 - k / q + k * k) / a0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Stage 2 of the BS.1770-4 pre-filter.
    pub fn high_pass(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let q = 0.500_327_05;
        let center_hz = 38.135_47;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        Filter {
            a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
            a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
            b0: 1.0,
            b1: -2.0,
            b2: 1.0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Feed the next input sample, get the next output sample.
    #[inline(always)]
    pub fn apply(&mut self, x0: f32) -> f32 {
        let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
            - self.a1 * self.y1
            - self.a2 * self.y2;

        self.x2 = self.x1;
        self.x1 = x0;
        self.y2 = self.y1;
        self.y1 = y0;

        y0
    }
}

/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
    sum: f32,
    residue: f32,
}

impl Sum {
    #[inline(always)]
    fn zero() -> Sum {
        Sum {
            sum: 0.0,
            residue: 0.0,
        }
    }

    #[inline(always)]
    fn add(&mut self, x: f32) {
        let sum = self.sum + (self.residue + x);
        self.residue = (self.residue + x) - (sum - self.sum);
        self.sum = sum;
    }
}
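// Illustrative sketch, not part of the diff above: `Sum` implements
// Kahan-style compensated summation. Naive f32 accumulation loses small
// addends entirely once the running sum is much larger than them:
fn compensated_vs_naive() {
    let mut naive = 1.0_f32;
    for _ in 0..1000 {
        naive += 1e-8; // rounds back to 1.0 every time in f32
    }
    assert_eq!(naive, 1.0);
    // `Sum::add` keeps those 1e-8 contributions in the `residue` field instead.
}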

/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);

impl Power {
    /// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
    ///
    /// This is the inverse of `loudness_lkfs`.
    pub fn from_lkfs(lkfs: f32) -> Power {
        // The inverse of the formula below.
        Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
    }

    /// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
    ///
    /// This is the inverse of `from_lkfs`.
    pub fn loudness_lkfs(&self) -> f32 {
        // Equation 2 (p.5) of BS.1770-4.
        -0.691 + 10.0 * self.0.log10()
    }
}
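// Illustrative sketch, not part of the diff above: `from_lkfs` and
// `loudness_lkfs` are exact inverses, e.g. for the common -23 LKFS target:
#[cfg(test)]
mod power_round_trip {
    use super::Power;
    #[test]
    fn lkfs_round_trip() {
        let p = Power::from_lkfs(-23.0); // 10^((-23 + 0.691) / 10)
        assert!((p.loudness_lkfs() + 23.0).abs() < 1e-3);
    }
}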

/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
    pub inner: T,
}

impl<T> Windows100ms<T> {
    /// Wrap a new empty vector.
    pub fn new() -> Windows100ms<Vec<T>> {
        Windows100ms { inner: Vec::new() }
    }

    /// Apply `as_ref` to the inner value.
    pub fn as_ref(&self) -> Windows100ms<&[Power]>
    where
        T: AsRef<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_ref(),
        }
    }

    /// Apply `as_mut` to the inner value.
    pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
    where
        T: AsMut<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_mut(),
        }
    }

    #[allow(clippy::len_without_is_empty)]
    /// Apply `len` to the inner value.
    pub fn len(&self) -> usize
    where
        T: AsRef<[Power]>,
    {
        self.inner.as_ref().len()
    }
}

/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
///     .unwrap_or(bs1770::Power(0.0))
///     .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
    /// The number of samples that fit in 100ms of audio.
    samples_per_100ms: u32,

    /// Stage 1 filter (head effects, high shelf).
    filter_stage1: Filter,

    /// Stage 2 filter (high-pass).
    filter_stage2: Filter,

    /// Sum of the squares over non-overlapping windows of 100ms.
    windows: Windows100ms<Vec<Power>>,

    /// The number of samples in the current unfinished window.
    count: u32,

    /// The sum of the squares of the samples in the current unfinished window.
    square_sum: Sum,
}

impl ChannelLoudnessMeter {
    /// Construct a new loudness meter for the given sample rate.
    pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
        ChannelLoudnessMeter {
            samples_per_100ms: sample_rate_hz / 10,
            filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
            filter_stage2: Filter::high_pass(sample_rate_hz as f32),
            windows: Windows100ms::new(),
            count: 0,
            square_sum: Sum::zero(),
        }
    }

    /// Feed input samples for loudness analysis.
    ///
    /// # Full scale
    ///
    /// Full scale for the input samples is the interval [-1.0, 1.0]. If your
    /// input consists of signed integer samples, you can convert as follows:
    ///
    /// ```ignore
    /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
    /// # let bits_per_sample = 16_usize;
    /// # let samples = &[0_i16];
    /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
    /// // one bit is the sign bit.
    /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
    /// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
    /// ```
    ///
    /// # Repeated calls
    ///
    /// You can call `push` multiple times to feed multiple batches of samples.
    /// This is equivalent to feeding a single chained iterator. The leftover of
    /// samples that did not fill a full 100ms window is not discarded:
    ///
    /// ```ignore
    /// # use std::iter;
    /// # use bs1770::ChannelLoudnessMeter;
    /// let sample_rate_hz = 44_100;
    /// let samples_per_100ms = sample_rate_hz / 10;
    /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    ///
    /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
    /// assert_eq!(meter.as_100ms_windows().len(), 0);
    ///
    /// meter.push(iter::once(0.0));
    /// assert_eq!(meter.as_100ms_windows().len(), 1);
    /// ```
    pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
        let normalizer = 1.0 / self.samples_per_100ms as f32;

        // LLVM, if you could go ahead and inline those apply calls, and then
        // unroll and vectorize the loop, that'd be terrific.
        for x in samples {
            let y = self.filter_stage1.apply(x);
            let z = self.filter_stage2.apply(y);

            self.square_sum.add(z * z);
            self.count += 1;

            // TODO: Should this branch be marked cold?
            if self.count == self.samples_per_100ms {
                let mean_squares = Power(self.square_sum.sum * normalizer);
                self.windows.inner.push(mean_squares);
                // We intentionally do not reset the residue. That way, leftover
                // energy from this window is not lost, so for the file overall,
                // the sum remains more accurate.
                self.square_sum.sum = 0.0;
                self.count = 0;
            }
        }
    }

    /// Return a reference to the 100ms windows analyzed so far.
    pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
        self.windows.as_ref()
    }

    /// Return all 100ms windows analyzed so far.
    pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
        self.windows
    }
}

/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
    left: Windows100ms<&[Power]>,
    right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    let mut result = Vec::with_capacity(left.len());
    for (l, r) in left.inner.iter().zip(right.inner) {
        result.push(Power(l.0 + r.0));
    }
    Windows100ms { inner: result }
}

/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    for (l, r) in left.inner.iter_mut().zip(right.inner) {
        l.0 += r.0;
    }
}

/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
    let mut gating_blocks = Vec::with_capacity(windows_100ms.len());

    // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
    let absolute_threshold = Power::from_lkfs(-70.0);

    // Iterate over all 400ms windows.
    for window in windows_100ms.inner.windows(4) {
        // Note that the sum over channels has already been performed at this point.
        let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());

        if gating_block_power > absolute_threshold {
            gating_blocks.push(gating_block_power);
        }
    }

    if gating_blocks.is_empty() {
        return None;
    }

    // Compute the loudness after applying the absolute gate, in order to
    // determine the threshold for the relative gate.
    let mut sum_power = Sum::zero();
    for &gating_block_power in &gating_blocks {
        sum_power.add(gating_block_power.0);
    }
    let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));

    // Stage 2: Apply the relative gate.
    let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
    let mut sum_power = Sum::zero();
    let mut n_blocks = 0_usize;
    for &gating_block_power in &gating_blocks {
        if gating_block_power > relative_threshold {
            sum_power.add(gating_block_power.0);
            n_blocks += 1;
        }
    }

    if n_blocks == 0 {
        return None;
    }

    let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
    Some(relative_gated_power)
}
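// Illustrative numeric check, not part of the diff above: the absolute gate at
// -70 LKFS corresponds to a mean-square power of 10^((-70 + 0.691) / 10),
// about 1.17e-7, and the relative gate sits 10 LU below the absolute-gated
// loudness.
fn gate_thresholds(absolute_gated_lkfs: f32) -> (Power, Power) {
    let absolute = Power::from_lkfs(-70.0);
    let relative = Power::from_lkfs(absolute_gated_lkfs - 10.0);
    (absolute, relative)
}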
@ -1,9 +1,6 @@
pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;

use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
@ -40,7 +40,7 @@ impl TokenOutputStream {
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
@ -1,56 +0,0 @@
use std::io::prelude::*;

pub trait Sample {
    fn to_i16(&self) -> i16;
}

impl Sample for f32 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for f64 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for i16 {
    fn to_i16(&self) -> i16 {
        *self
    }
}

pub fn write_pcm_as_wav<W: Write, S: Sample>(
    w: &mut W,
    samples: &[S],
    sample_rate: u32,
) -> std::io::Result<()> {
    let len = 12u32; // header
    let len = len + 24u32; // fmt
    let len = len + samples.len() as u32 * 2 + 8; // data
    let n_channels = 1u16;
    let bytes_per_second = sample_rate * 2 * n_channels as u32;
    w.write_all(b"RIFF")?;
    w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
    w.write_all(b"WAVE")?;

    // Format block
    w.write_all(b"fmt ")?;
    w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
    w.write_all(&1u16.to_le_bytes())?; // PCM
    w.write_all(&n_channels.to_le_bytes())?; // one channel
    w.write_all(&sample_rate.to_le_bytes())?;
    w.write_all(&bytes_per_second.to_le_bytes())?;
    w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
    w.write_all(&16u16.to_le_bytes())?; // bits per sample

    // Data block
    w.write_all(b"data")?;
    w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
    for sample in samples.iter() {
        w.write_all(&sample.to_i16().to_le_bytes())?
    }
    Ok(())
}
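// Illustrative usage sketch, not part of the diff above: write one second of
// 16kHz silence with the writer defined here (the file name is made up).
fn write_silence() -> std::io::Result<()> {
    let samples = vec![0i16; 16_000];
    let mut file = std::fs::File::create("silence.wav")?;
    write_pcm_as_wav(&mut file, &samples, 16_000)
}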
@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.1"
version = "0.3.3"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.1"
version = "0.3.3"
edition = "2021"

description = "CUDA kernels for Candle"
@ -71,6 +71,7 @@ __device__ void im2col1d(
  }
  const size_t *src_dims = info;
  const size_t *src_s = info + 3;
  const size_t b_in = src_dims[0];
  const size_t c_in = src_dims[1];
  const size_t l_in = src_dims[2];

@ -119,6 +120,7 @@ __device__ void im2col(
  }
  const size_t *src_dims = info;
  const size_t *src_s = info + 4;
  const size_t b_in = src_dims[0];
  const size_t c_in = src_dims[1];
  const size_t h_in = src_dims[2];
  const size_t w_in = src_dims[3];
@ -223,60 +225,6 @@ __device__ void conv2d(
  dst[dst_i] = static_cast<T>(d);
}

// Naive implementation of conv_transpose1d.
template <typename T, typename A>
__device__ void conv_transpose1d(
  const size_t src_numel,
  const size_t l_out,
  const size_t stride,
  const size_t padding,
  const size_t out_padding,
  const size_t dilation,
  const size_t *info,
  const T *src,
  const T *kernel,
  T *dst
) {
  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
  // src: (b_size, c_in, l_in)
  // k: (c_in, c_out, l_k)
  const size_t *src_dims = info;
  const size_t *src_s = info + 3;
  const size_t *k_dims = info + 6;
  const size_t *k_s = info + 9;
  const size_t l_k = k_dims[2];
  const size_t c_out = k_dims[1];
  const size_t c_in = src_dims[1];
  const size_t l_in = src_dims[2];
  if (dst_i >= src_dims[0] * c_out * l_out) {
    return;
  }

  // TODO
  const size_t b_idx = dst_i / (l_out * c_out);
  const size_t dst_c_idx = (dst_i / l_out) % c_out;
  // NCL layout.
  const size_t out_x = dst_i % l_out;

  const size_t src_idx0 = b_idx * src_s[0];
  A d = 0;
  for (int k_x = 0; k_x < (int)l_k; ++k_x) {
    // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
    int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
    if (inp_x_stride < 0 || inp_x_stride % stride) {
      continue;
    }
    int inp_x = inp_x_stride / stride;
    if (inp_x >= l_in) continue;
    for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
      const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2];
      const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2];
      d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
    }
  }
  dst[dst_i] = static_cast<T>(d);
}
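// Illustrative Rust sketch, not part of the diff above: the kernel gathers,
// for each output position, the (k_x, inp_x) pairs for which the forward
// relation out_x = inp_x * stride + k_x * dilation - padding holds, i.e.
// those where (out_x + padding - k_x * dilation) is a non-negative multiple
// of stride.
fn contributing_taps(
    out_x: usize,
    padding: usize,
    dilation: usize,
    stride: usize,
    l_k: usize,
    l_in: usize,
) -> Vec<(usize, usize)> {
    (0..l_k)
        .filter_map(|k_x| {
            let t = out_x as isize + padding as isize - (k_x * dilation) as isize;
            if t < 0 || t % stride as isize != 0 {
                return None;
            }
            let inp_x = (t / stride as isize) as usize;
            (inp_x < l_in).then_some((k_x, inp_x))
        })
        .collect()
}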

// Naive implementation of conv_transpose2d.
template <typename T, typename A>
__device__ void conv_transpose2d(
@ -559,22 +507,6 @@ extern "C" __global__ void FN_NAME( \
    im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \

#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t src_numel, \
    const size_t l_out, \
    const size_t stride, \
    const size_t padding, \
    const size_t out_padding, \
    const size_t dilation, \
    const size_t *info, \
    const TYPENAME *src, \
    const TYPENAME *kernel, \
    TYPENAME *dst \
) { \
  conv_transpose1d<TYPENAME, TYPEACC>(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \

#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t src_numel, \
@ -636,7 +568,6 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16)
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
@ -648,7 +579,6 @@ IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
CONVT1D_OP(__half, float, conv_transpose1d_f16)
CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
@ -667,11 +597,6 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)

CONVT1D_OP(float, float, conv_transpose1d_f32)
CONVT1D_OP(double, double, conv_transpose1d_f64)
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)

CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(double, double, conv_transpose2d_f64)
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
File diff suppressed because it is too large
@ -55,11 +55,6 @@ __device__ __forceinline__ T relu_fwd(T x) {
  return maxg(x, zero);
}

template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
  return x / (static_cast<T>(1) + expg(-x));
}
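// Illustrative Rust reference for the kernel above, not part of the diff:
// SiLU is x * sigmoid(x), written here as x / (1 + exp(-x)).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}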

#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
@ -108,7 +103,6 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
#endif

@ -133,7 +127,6 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
#endif

@ -180,7 +173,5 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))
@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.1"
version = "0.3.3"
edition = "2021"

description = "Metal kernels for Candle"
@ -183,7 +183,7 @@ macro_rules! ops{
pub mod unary {
    ops!(
        cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
        tanh, recip, silu
        tanh, recip
    );
}
pub mod binary {
@ -623,7 +623,8 @@ pub fn call_reduce_strided(
            strides,
            elements_to_sum,
            (input, input_offset),
            output
            output,
            out_length
        )
    );

@ -1364,12 +1365,13 @@ pub fn call_gemm(
    // TODO byte_stride_d
    let byte_stride_d = 0;

    let buffer: Vec<u64> = vec![
        byte_stride_a as _,
        byte_stride_b as _,
        byte_stride_c as _,
        byte_stride_d as _,
    ];
    let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
    for i in 0..b {
        buffer.push((i * byte_stride_a) as u64);
        buffer.push((i * byte_stride_b) as u64);
        buffer.push((i * byte_stride_c) as u64);
        buffer.push((i * byte_stride_d) as u64);
    }
    encoder.set_bytes(
        10,
        (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
Binary file not shown.
Some files were not shown because too many files have changed in this diff.