Compare commits

..

4 Commits

Author SHA1 Message Date
c2261d0222 Merge. 2024-01-07 20:27:33 +01:00
06d186355b Change more consitently the test. 2024-01-06 15:20:55 +01:00
2bbd544832 Non random for better quantization quality 2024-01-06 15:16:01 +01:00
504d0b9ac7 Potential bug on q4k. 2024-01-05 14:15:47 +01:00
151 changed files with 1104 additions and 14535 deletions

View File

@ -1,7 +0,0 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5

View File

@ -5,15 +5,49 @@ on:
pull_request:
jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
# Don't run on forks, they won't have access to secrets anyway.
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
test-cuda:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
needs: start-runner # required to start the main job when the runner is ready
runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
permissions:
contents: write
packages: write
@ -24,10 +58,32 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda)
run: cargo test --features cuda
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.4.0"
version = "0.3.3"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -31,18 +31,18 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
candle-nn = { path = "./candle-nn", version = "0.4.0" }
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -50,20 +50,20 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "50.0.0" }
parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false }
tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"

View File

@ -65,9 +65,8 @@ We also provide a some command line based examples using state of the art models
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- [Mamba](./candle-examples/examples/mamba/): an inference only
pre-trained on 1T tokens of English and code datasets.
- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
@ -75,9 +74,6 @@ We also provide a some command line based examples using state of the art models
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
@ -113,12 +109,8 @@ We also provide a some command line based examples using state of the art models
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
@ -189,15 +181,13 @@ If you have an addition to this list, please submit a pull request.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba
- Minimal Mamba
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5.
- RWKV.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
@ -213,10 +203,8 @@ If you have an addition to this list, please submit a pull request.
- Wurstchen v2.
- Image to text.
- BLIP.
- TrOCR.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.

View File

@ -46,5 +46,6 @@ accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"
name = "matmul"
harness = false

View File

@ -1,9 +0,0 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
);

View File

@ -1,43 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.affine(12.34, 56.78).unwrap();
}
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,66 +0,0 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;
use candle_core::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
fn sync(&self) -> Result<()> {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
}
}
}
fn bench_name<S: Into<String>>(&self, name: S) -> String {
match self {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
};
format!("{}_{}", cpu_type, name.into())
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
}
}
}
struct BenchDeviceHandler {
devices: Vec<Device>,
}
impl BenchDeviceHandler {
pub fn new() -> Result<Self> {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);
Ok(Self { devices })
}
}

View File

@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,64 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
a.where_cond(b, c).unwrap();
}
const fn create_cond_arr<const N: usize>() -> [u8; N] {
let mut arr = [0u8; N];
let mut i = 0;
while i < N {
arr[i] = (i % 2) as u8;
i += 1;
}
arr
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box(&tensor),
black_box(&on_true),
black_box(&on_false),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,25 +1,25 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device) {
fn criterion_benchmark(c: &mut Criterion) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let device = Device::new_metal(0).unwrap();
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name("matmul"));
let mut group = c.benchmark_group("matmul_metal");
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
@ -27,18 +27,16 @@ fn run_bench(c: &mut Criterion, device: &Device) {
for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs));
}
device.sync().unwrap();
if let Device::Metal(device) = &device {
device.wait_until_completed().unwrap();
} else {
panic!("Expected metal device");
}
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@ -1,5 +1,5 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{Device, Result};
use candle_core::quantized::{gguf_file, k_quants, QTensor};
use candle_core::{Device, Result, Tensor};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
@ -11,7 +11,12 @@ enum QuantizationMode {
}
impl QuantizationMode {
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
fn quantize(
&self,
name: &str,
tensor: QTensor,
default: fn(&Tensor) -> Result<QTensor>,
) -> Result<QTensor> {
match self {
Self::Llama => {
// Same behavior as the llama.cpp quantization.
@ -19,9 +24,9 @@ impl QuantizationMode {
if should_quantize {
let tensor = tensor.dequantize(&Device::Cpu)?;
if name == "output.weight" {
QTensor::quantize(&tensor, GgmlDType::Q6K)
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
} else {
QTensor::quantize(&tensor, dtype)
default(&tensor)
}
} else {
Ok(tensor)
@ -55,27 +60,6 @@ enum Quantization {
F32,
}
impl Quantization {
fn dtype(&self) -> GgmlDType {
match self {
Quantization::Q4_0 => GgmlDType::Q4_0,
Quantization::Q4_1 => GgmlDType::Q4_1,
Quantization::Q5_0 => GgmlDType::Q5_0,
Quantization::Q5_1 => GgmlDType::Q5_1,
Quantization::Q8_0 => GgmlDType::Q8_0,
Quantization::Q8_1 => GgmlDType::Q8_1,
Quantization::Q2k => GgmlDType::Q2K,
Quantization::Q3k => GgmlDType::Q3K,
Quantization::Q4k => GgmlDType::Q4K,
Quantization::Q5k => GgmlDType::Q5K,
Quantization::Q6k => GgmlDType::Q6K,
Quantization::Q8k => GgmlDType::Q8K,
Quantization::F16 => GgmlDType::F16,
Quantization::F32 => GgmlDType::F32,
}
}
}
#[derive(ValueEnum, Debug, Clone)]
enum Format {
Safetensors,
@ -118,7 +102,7 @@ enum Command {
},
Quantize {
/// The input file(s), in safetensors format.
/// The input file, in gguf format.
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
@ -133,15 +117,6 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
Dequantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in safetensors format.
#[arg(long)]
out_file: std::path::PathBuf,
},
}
#[derive(Parser, Debug, Clone)]
@ -150,12 +125,7 @@ struct Args {
command: Command,
}
fn run_ls(
file: &std::path::PathBuf,
format: Option<Format>,
verbose: bool,
device: &Device,
) -> Result<()> {
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
@ -196,7 +166,7 @@ fn run_ls(
}
}
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
@ -221,7 +191,7 @@ fn run_ls(
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
@ -262,8 +232,37 @@ fn run_quantize_safetensors(
}
println!("tensors: {}", tensors.len());
let dtype = q.dtype();
let block_size = dtype.block_size();
let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let block_size = match q {
Quantization::Q4_0 => k_quants::QK4_0,
Quantization::Q4_1 => k_quants::QK4_1,
Quantization::Q5_0 => k_quants::QK5_0,
Quantization::Q5_1 => k_quants::QK5_1,
Quantization::Q8_0 => k_quants::QK8_0,
Quantization::Q8_1 => k_quants::QK8_1,
Quantization::Q2k
| Quantization::Q3k
| Quantization::Q4k
| Quantization::Q5k
| Quantization::Q6k
| Quantization::Q8k => k_quants::QK_K,
Quantization::F16 | Quantization::F32 => 1,
};
let qtensors = tensors
.into_par_iter()
@ -271,9 +270,9 @@ fn run_quantize_safetensors(
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
QTensor::quantize(&tensor, dtype)?
quantize_fn(&tensor)?
} else {
QTensor::quantize(&tensor, GgmlDType::F32)?
QTensor::quantize::<f32>(&tensor)?
};
Ok((name, tensor))
})
@ -286,29 +285,11 @@ fn run_quantize_safetensors(
Ok(())
}
fn run_dequantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
device: &Device,
) -> Result<()> {
let mut in_file = std::fs::File::open(in_file)?;
let content = gguf_file::Content::read(&mut in_file)?;
let mut tensors = std::collections::HashMap::new();
for (tensor_name, _) in content.tensor_infos.iter() {
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
Ok(())
}
fn run_quantize(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
device: &Device,
) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
@ -334,15 +315,31 @@ fn run_quantize(
let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len());
let dtype = q.dtype();
let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let qtensors = content
.tensor_infos
.par_iter()
.map(|(name, _)| {
println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name, device)?;
let tensor = qmode.quantize(name, tensor, dtype)?;
let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
@ -362,7 +359,6 @@ fn run_quantize(
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = Device::Cpu;
match args.command {
Command::Ls {
files,
@ -374,7 +370,7 @@ fn main() -> anyhow::Result<()> {
if multiple_files {
println!("--- {file:?} ---");
}
run_ls(file, format.clone(), verbose, &device)?
run_ls(file, format.clone(), verbose)?
}
}
Command::Quantize {
@ -382,8 +378,7 @@ fn main() -> anyhow::Result<()> {
out_file,
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
} => run_quantize(&in_file, out_file, quantization, mode)?,
}
Ok(())
}

View File

@ -380,16 +380,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -412,28 +402,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]

View File

@ -175,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach() };
let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -589,13 +589,6 @@ impl Tensor {
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (*node / arg)?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;

View File

@ -1149,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, l_k)
// Input shape: (b_size, c_in, l_in)
let p = &self.0;
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
l_out,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1859,15 +1810,12 @@ impl BackendStorage for CudaStorage {
fn conv_transpose1d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
let device = self.device().clone();
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
todo!()
}
#[cfg(not(feature = "cudnn"))]

View File

@ -72,7 +72,7 @@ pub mod utils;
mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation, NdArray};
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;

View File

@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::sync::{Arc, RwLock, TryLockError};
/// Simple way to catch lock error without
/// depending on T
@ -85,8 +84,13 @@ pub struct MetalDevice {
command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
/// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
/// execution order to be linear.
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
/// compute graph.
fence: metal::Fence,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
/// Heavily used by [`candle_metal_kernels`], both fences need to match
kernels: Arc<candle_metal_kernels::Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
@ -102,8 +106,6 @@ pub struct MetalDevice {
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1).
buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
@ -219,8 +221,10 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?;
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
blit.end_encoding();
// This is necessary, for mmaped safetensors
@ -228,33 +232,12 @@ impl MetalDevice {
// The slice might not live long enough for metal
// To actually fill the GPU buffer.
// Putting this wait forces the GPU buffer to be filled
// with the actual data allowing the CPU storage to do
// with the actual data allowing the CPU storage todo
// deallocate properly.
self.wait_until_completed()?;
Ok(real)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
@ -325,14 +308,35 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
let length = self.buffer.length() as usize;
let size = self.dtype.size_in_bytes();
if length % size != 0 {
crate::bail!(
"The Metal buffer length is not aligned with dtype {:?}",
self.dtype
);
}
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence);
blit.end_encoding();
}
self.device.wait_until_completed()?;
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
}
}
@ -349,7 +353,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine(
@ -368,7 +371,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided",
DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine_strided(
@ -588,27 +590,14 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
@ -680,14 +669,12 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
@ -698,14 +685,12 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("usilu", DType::F16) => contiguous::silu::HALF,
("uabs", DType::F16) => contiguous::abs::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
@ -733,11 +718,9 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("usilu", DType::F32) => strided::silu::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
@ -749,11 +732,9 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("usilu", DType::F16) => strided::silu::HALF,
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
@ -809,7 +790,6 @@ impl BackendStorage for MetalStorage {
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
(DType::U8, DType::U32) => "where_u8_u32",
@ -1147,12 +1127,8 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
@ -1265,7 +1241,7 @@ impl BackendStorage for MetalStorage {
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let src_shape = src_l.shape();
@ -1342,7 +1318,6 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@ -1353,18 +1328,6 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@ -1375,7 +1338,6 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@ -1386,7 +1348,6 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@ -1397,7 +1358,6 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
@ -1431,7 +1391,6 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@ -1444,20 +1403,6 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@ -1470,7 +1415,6 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),
("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@ -1483,7 +1427,6 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),
("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@ -1496,7 +1439,6 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}
@ -1522,28 +1464,6 @@ impl MetalStorage {
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype))
}
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
let length = self.buffer.length() as usize;
let size = self.dtype.size_in_bytes();
if length % size != 0 {
crate::bail!(
"The Metal buffer length is not aligned with dtype {:?}",
self.dtype
);
}
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
}
self.device.wait_until_completed()?;
Ok(read_to_vec(&buffer, length / size))
}
}
impl BackendDevice for MetalDevice {
@ -1556,29 +1476,29 @@ impl BackendDevice for MetalDevice {
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone()));
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 10,
_ => 20,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self {
device,
fence,
command_queue,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
kernels,
seed,
})
}
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("Metal set_seed not implemented")
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize,
@ -1590,8 +1510,21 @@ impl BackendDevice for MetalDevice {
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let size = shape.elem_count() * dtype.size_in_bytes();
let buffer = self.allocate_zeros(size)?;
let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.update_fence(&self.fence);
blit.end_encoding();
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}
@ -1618,31 +1551,12 @@ impl BackendDevice for MetalDevice {
&self,
shape: &Shape,
dtype: DType,
min: f64,
max: f64,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
let name = match dtype {
DType::F32 => "rand_uniform_f32",
DType::F16 => "rand_uniform_f16",
DType::BF16 => "rand_uniform_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_uniform(
&self.device,
&command_buffer,
&self.kernels,
name,
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(
@ -1652,43 +1566,9 @@ impl BackendDevice for MetalDevice {
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
let name = match dtype {
DType::F32 => "rand_normal_f32",
DType::F16 => "rand_normal_f16",
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
}
fn set_seed(&self, seed: u64) -> Result<()> {
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
Ok(())
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
}

View File

@ -333,16 +333,6 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -365,28 +355,6 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline]

View File

@ -61,7 +61,6 @@ pub enum UnaryOp {
GeluErf,
Erf,
Relu,
Silu,
Tanh,
Floor,
Ceil,
@ -391,7 +390,6 @@ pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
@ -726,77 +724,6 @@ impl UnaryOpT for Erf {
}
}
/// Silu operation
impl UnaryOpT for Silu {
const NAME: &'static str = "silu";
const V: Self = Silu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v / (bf16::ONE + (-v).exp())
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v / (f16::ONE + (-v).exp())
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_silu(xs, ys)
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";

View File

@ -217,13 +217,6 @@ impl Object {
let args = args.remove(1);
(callable, args)
}
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args),
};
match callable {
@ -234,11 +227,13 @@ impl Object {
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
path: path.to_string_lossy().into_owned(),
storage_size,
}))
}
@ -350,10 +345,8 @@ impl Stack {
module_name,
class_name,
} => {
if module_name == "collections"
&& (class_name == "OrderedDict" || class_name == "defaultdict")
{
// TODO: have a separate ordered dict and a separate default dict.
if module_name == "collections" && class_name == "OrderedDict" {
// TODO: have a separate ordered dict.
Some(Object::Dict(vec![]))
} else {
None
@ -634,16 +627,9 @@ pub struct TensorInfo {
pub storage_size: usize,
}
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
@ -665,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE || verbose {
println!("{obj:#?}");
println!("{obj:?}");
}
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
@ -681,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
},
obj => obj,
};
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) {
@ -721,8 +688,8 @@ pub struct PthTensors {
}
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
@ -736,7 +703,6 @@ impl PthTensors {
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
@ -745,56 +711,27 @@ impl PthTensors {
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic
// case and when the tensor is fortran contiguous.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
// For now only support the basic case.
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
&mut reader,
)?;
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor))
}
Ok(Some(tensor))
}
}
/// Read all the tensors from a PyTorch pth file with a given key.
///
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
@ -804,11 +741,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
}
Ok(tensors)
}
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}

View File

@ -1,43 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,9 +1,7 @@
//! Support for the GGML file format.
#[cfg(feature = "metal")]
use super::metal::load_quantized_metal;
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use super::{k_quants, GgmlDType};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
@ -123,22 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
#[cfg(feature = "metal")]
Device::Metal(metal) => load_quantized_metal(metal, data)?,
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal backend requires `metal` feature")
}
device => unimplemented!("Implement quantized tensor for device {device:?}"),
};
super::QTensor::new(data, dims)
super::QTensor::new(data.to_vec(), dims)
}
/// Creates a [Tensor] from a raw GGML tensor.
@ -146,50 +133,29 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let block_size = ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
let blck_size = ggml_dtype.blck_size();
if tensor_elems % blck_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
)
}
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@ -197,7 +163,6 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
@ -218,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
@ -233,14 +198,10 @@ pub struct Content {
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
@ -250,16 +211,14 @@ impl Content {
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.insert(name, tensor);
}
let device = device.clone();
Ok(Self {
magic,
hparams,
vocab,
tensors,
device,
})
}

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -59,25 +59,19 @@ impl TensorInfo {
&self,
reader: &mut R,
tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let block_size = self.ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
let blck_size = self.ggml_dtype.blck_size();
if tensor_elems % blck_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
)
}
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
}
}
@ -466,13 +460,12 @@ impl Content {
&self,
reader: &mut R,
name: &str,
device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset, device)
tensor_info.read(reader, self.tensor_data_offset)
}
}
@ -524,9 +517,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
)
}
let data = tensor.data()?;
let size_in_bytes = data.len();
w.write_all(&data)?;
let data_ptr = tensor.as_ptr();
let size_in_bytes = tensor.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
w.write_all(data)?;
let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?;
}

View File

@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
let d2 = d * sc as f32;
let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
let to_add = if qh & u1 != 0 { 16 } else { 1 };
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
ys_index += 1;
}
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
let to_add = if qh & u2 != 0 { 16 } else { 1 };
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
ys_index += 1;
}
is += 2;

View File

@ -1,234 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage;
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use metal::Buffer;
use std::sync::Arc;
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
buffer: Arc<Buffer>,
}
impl QMetalStorage {
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = device.allocate_zeros(size)?;
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
match self.dtype {
GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
f32::to_float(&vec, &mut out)?;
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
let buffer = self.device.new_buffer_with_data(&out)?;
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
}
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
// Quantization only happens on CPU for now.
let src = src.to_cpu::<f32>()?;
let elem_count = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
qcpu_storage.quantize(&src)?;
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
self.buffer = buffer;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
pub fn fwd(
&self,
self_shape: &Shape,
storage: &MetalStorage,
layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
Ok((dst_storage, dst_shape))
}
}
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice,
data: &[T],
) -> Result<QStorage> {
let buffer = device.new_buffer_with_data(data)?;
let device = device.clone();
Ok(QStorage::Metal(QMetalStorage {
dtype: T::DTYPE,
device,
buffer,
}))
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
}

View File

@ -1,118 +1,23 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
use crate::{Device, Result, Shape, Tensor};
#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::f16;
pub use k_quants::GgmlType;
pub struct QTensor {
storage: QStorage,
data: Box<dyn QuantizedType>,
shape: Shape,
}
impl Device {
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
match self {
Device::Cpu => {
let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage))
}
Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage))
}
Device::Cuda(_cuda) => {
crate::bail!("Cuda ggml quantization not supported");
}
}
}
}
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage),
}
impl QStorage {
fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(),
}
}
fn dtype(&self) -> GgmlDType {
match self {
QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(),
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
}
}
fn quantize(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
}
Ok(())
}
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
}
}
fn data(&self) -> Result<Cow<[u8]>> {
match self {
QStorage::Cpu(storage) => {
let data_ptr = storage.as_ptr();
let size_in_bytes = storage.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
QStorage::Metal(_storage) => {
crate::bail!("not implemented");
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
F32,
@ -172,25 +77,6 @@ impl GgmlDType {
}
}
/// The block dtype
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
}
}
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
use k_quants::*;
@ -214,7 +100,7 @@ impl GgmlDType {
}
/// The block size, i.e. the number of elements stored in each block.
pub fn block_size(&self) -> usize {
pub fn blck_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
@ -233,13 +119,9 @@ impl GgmlDType {
pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8;
fn block_size(&self) -> usize;
#[allow(clippy::wrong_self_convention)]
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
fn size(&self) -> usize;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@ -247,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
}
fn size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
T::from_float(xs, self)
}
fn dtype(&self) -> GgmlDType {
T::DTYPE
}
fn block_size(&self) -> usize {
T::BLCK_SIZE
}
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
let mut ys = vec![0.0f32; elem_count];
T::to_float(self.as_slice(), &mut ys)?;
Ok(CpuStorage::F32(ys))
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
T::to_float(self.as_slice(), ys)
}
fn storage_size_in_bytes(&self) -> usize {
@ -284,53 +152,56 @@ impl std::fmt::Debug for QTensor {
}
}
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
let dims = shape.dims();
if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}")
}
if dims[dims.len() - 1] % block_size != 0 {
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
block_size
T::BLCK_SIZE
)
}
Ok(())
}
impl QTensor {
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
) -> Result<Self> {
let shape = shape.into();
check_shape(&shape, storage.block_size())?;
Ok(Self { storage, shape })
check_shape::<T>(&shape)?;
Ok(Self {
data: Box::new(data),
shape,
})
}
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
check_shape::<T>(shape)?;
let src = src
.to_dtype(crate::DType::F32)?
.flatten_all()?
.to_vec1::<f32>()?;
if src.len() % T::BLCK_SIZE != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
T::BLCK_SIZE
)
}
let mut storage = src.device().qzeros(elem_count, dtype)?;
storage.quantize(&src.storage())?;
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
T::from_float(&src, &mut data)?;
Ok(Self {
storage,
data: Box::new(data),
shape: shape.clone(),
})
}
pub fn dtype(&self) -> GgmlDType {
self.storage.dtype()
}
pub fn device(&self) -> Device {
self.storage.device()
self.data.dtype()
}
pub fn rank(&self) -> usize {
@ -342,19 +213,21 @@ impl QTensor {
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
let is_variable = false;
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
.to_device(device)
let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device)
}
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
self.data.matmul_t(mkn, lhs, dst)
}
pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes()
self.data.storage_size_in_bytes()
}
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
self.storage.data()
pub fn as_ptr(&self) -> *const u8 {
self.data.as_ptr()
}
}
@ -421,29 +294,17 @@ impl crate::CustomOp1 for QTensor {
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
#[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage,
QStorage::Metal(_) => crate::bail!("Invalid storage"),
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let storage = storage.as_slice::<f32>()?;
let storage =
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
self.matmul_t(
(dst_shape.elem_count() / n, k, n),
storage,
&mut dst_storage,
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Metal(metal) => metal,
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
}
impl crate::Module for QMatMul {

View File

@ -508,7 +508,6 @@ impl Tensor {
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
unary_op!(silu, Silu);
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
@ -805,35 +804,6 @@ impl Tensor {
}
}
/// Roll the tensor input along the given dimension.
/// Elements that are shifted beyond the last position are re-introduced at the first position.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(-1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
where
D: Dim + Clone,
{
let dim = dim.to_index(self.shape(), "roll")?;
let dim_size = self.dim(dim)?;
let shift = shift.rem_euclid(dim_size as i32) as usize;
if shift == 0 {
Ok(self.clone())
} else {
let a = self.narrow(dim, 0, dim_size - shift)?;
let b = self.narrow(dim, dim_size - shift, shift)?;
Tensor::cat(&[&b, &a], dim)
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions.
///
@ -1883,9 +1853,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Tensor {
pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable {
self.clone()
Ok(self.clone())
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
@ -1896,7 +1866,7 @@ impl Tensor {
dtype: self.dtype,
device: self.device.clone(),
};
Tensor(Arc::new(tensor_))
Ok(Tensor(Arc::new(tensor_)))
}
}
@ -2608,21 +2578,11 @@ impl Tensor {
}
/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
/// Pointwise pow operation.
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.mul(&self.log()?)?.exp()
}
/// Broadcasting version of `pow`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
}
macro_rules! bin_trait {

View File

@ -107,10 +107,6 @@ impl Var {
Ok(Self(inner))
}
pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}
pub fn as_tensor(&self) -> &Tensor {
&self.0
}

View File

@ -50,15 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
4.7076, -5.9745, -0.8276, 1.621
],
);
}
Ok(())
}

View File

@ -270,19 +270,6 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000]
);
// testing compared to pytorch nn.Silu()
let y = x.silu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.8577, 0.7311, 3.9281, 0.0806]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0881, 0.9277, 1.0527, 0.5747],
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;

View File

@ -1,37 +0,0 @@
import torch
from collections import OrderedDict
# Write a trivial tensor to a pt file
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
# Write a trivial tensor to a pt file
torch.save(o, "test.pt")
############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")
############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)
# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
print("3D Tensor saved with Fortran layout.")

View File

@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_fortran_congiguous() {
let tensors =
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
assert_eq!(
tensor.to_vec3::<i64>().unwrap(),
[
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
]
);
}

View File

@ -1,7 +1,6 @@
use candle_core::{
bail,
quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
};
@ -15,48 +14,16 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
fn test_matmul(
device: &Device,
(b, m, n, k): (usize, usize, usize, usize),
dtype: GgmlDType,
) -> Result<()> {
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
let mm = lhs.matmul(&rhs)?;
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&lhs)?;
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
.sum_all()?
.to_scalar()?;
let error = error / (b * m * n) as f32;
assert!(
error <= 0.02,
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
);
Ok(())
}
fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
#[test]
fn quantized_matmul() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
@ -66,7 +33,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
341876.0, 994283.0, 1655709.0, 2301518.0
]
);
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
@ -77,49 +43,35 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
);
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84946.0, 214126.0, 344757.0, 473798.0],
[213458.0, 604350.0, 1000469.0, 1387990.0],
[341970.0, 994574.0, 1656181.0, 2302182.0]
]
),
_ => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
[213475.0, 604465.0, 1000686.0, 1388317.0],
[341876.0, 994283.0, 1655709.0, 2301518.0]
]
),
}
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
[213475.0, 604465.0, 1000686.0, 1388317.0],
[341876.0, 994283.0, 1655709.0, 2301518.0]
]
);
Ok(())
}
fn quantized_matmul_neg(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
#[test]
fn quantized_matmul_neg() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
.map(|v| v as f32 - (k * n) as f32 / 3.0)
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
@ -139,56 +91,32 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
]
);
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243666.0, -19714.0, -285433.0, -550453.0],
[23782.0, 21654.0, 19400.0, 18369.0],
[-196102.0, 63022.0, 324233.0, 587191.0]
]
),
_ => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
[23777.0, 21651.0, 19398.0, 18367.0],
[-196472.0, 63012.0, 324585.0, 587902.0]
]
),
}
assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
[23777.0, 21651.0, 19398.0, 18367.0],
[-196472.0, 63012.0, 324585.0, 587902.0]
]
);
Ok(())
}
test_device!(
quantized_matmul,
quantized_matmul_cpu,
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
#[test]
fn quantize_q4_0() -> Result<()> {
use k_quants::BlockQ4_0;
fn quantize_q4_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
let dst = quant.dequantize(device)?;
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ4_0::zeros(); 4];
BlockQ4_0::from_float(&src, &mut quant)?;
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
dst.to_vec1::<f32>()?,
dst,
&[
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
@ -204,21 +132,21 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
127.0, 127.0
]
);
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn quantize_q4_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
#[test]
fn quantize_q4_1() -> Result<()> {
use k_quants::BlockQ4_1;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
let dst = quant.dequantize(device)?;
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ4_1::zeros(); 4];
BlockQ4_1::from_float(&src, &mut quant)?;
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
round_vector(&dst),
&[
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
@ -234,21 +162,21 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
]
);
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn quantize_q5_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
#[test]
fn quantize_q5_0() -> Result<()> {
use k_quants::BlockQ5_0;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
let dst = quant.dequantize(device)?;
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ5_0::zeros(); 4];
BlockQ5_0::from_float(&src, &mut quant)?;
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
round_vector(&dst),
&[
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
@ -264,21 +192,21 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
]
);
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn quantize_q5_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
#[test]
fn quantize_q5_1() -> Result<()> {
use k_quants::BlockQ5_1;
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
let dst = quant.dequantize(device)?;
let mut dst = vec![0f32; 32 * 4];
let mut quant = vec![BlockQ5_1::zeros(); 4];
BlockQ5_1::from_float(&src, &mut quant)?;
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
dst,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
@ -292,11 +220,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
124.0, 125.0, 126.0, 127.0
]
);
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
assert!(
size % crate::quantized::k_quants::QK_K == 0,
"size must be a multiple of {}",
@ -306,8 +236,10 @@ fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor>
let src = (0..size)
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
.collect::<Vec<_>>();
let dst = vec![0f32; size];
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
Tensor::from_vec(src, (size,), device)
(src, dst)
}
/// Round a vector
@ -356,12 +288,11 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
/// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error {
bail!(
"Quantization error {} exceeds max error {}",
@ -372,19 +303,19 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
Ok(())
}
fn quantize_q2k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q2K;
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
T::from_float(src, &mut quant)?;
T::to_float(&quant, dst)?;
Ok(quant)
}
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
#[test]
fn quantize_q2k() -> Result<()> {
use k_quants::BlockQ2K;
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
// Test some specific values
@ -398,30 +329,20 @@ fn quantize_q2k(device: &Device) -> Result<()> {
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
);
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
Ok(())
}
fn quantize_q3k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q3K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
#[test]
fn quantize_q3k() -> Result<()> {
use k_quants::BlockQ3K;
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
// Test some specific values
@ -435,30 +356,20 @@ fn quantize_q3k(device: &Device) -> Result<()> {
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
);
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
Ok(())
}
fn quantize_q4k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q4K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
#[test]
fn quantize_q4k() -> Result<()> {
use k_quants::BlockQ4K;
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
// Test some specific values
@ -472,31 +383,21 @@ fn quantize_q4k(device: &Device) -> Result<()> {
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
);
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn quantize_q5k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q5K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
#[test]
fn quantize_q5k() -> Result<()> {
use k_quants::BlockQ5K;
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values
assert_eq!(
@ -506,33 +407,24 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
);
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn quantize_q6k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q6K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
#[test]
fn quantize_q6k() -> Result<()> {
use k_quants::BlockQ6K;
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values
@ -546,31 +438,22 @@ fn quantize_q6k(device: &Device) -> Result<()> {
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
);
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
fn quantize_q8k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q8K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
#[test]
fn quantize_q8k() -> Result<()> {
use k_quants::BlockQ8K;
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
let (src, mut dst) = get_test_vector(0.5, 1024);
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
// Test some specific values
assert_eq!(
@ -583,79 +466,15 @@ fn quantize_q8k(device: &Device) -> Result<()> {
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
);
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
test_device!(
quantize_q4_0,
quantize_q4_0_cpu,
quantize_q4_0_cuda,
quantize_q4_0_metal
);
test_device!(
quantize_q4_1,
quantize_q4_1_cpu,
quantize_q4_1_cuda,
quantize_q4_1_metal
);
test_device!(
quantize_q5_0,
quantize_q5_0_cpu,
quantize_q5_0_cuda,
quantize_q5_0_metal
);
test_device!(
quantize_q5_1,
quantize_q5_1_cpu,
quantize_q5_1_cuda,
quantize_q5_1_metal
);
test_device!(
quantize_q2k,
quantize_q2k_cpu,
quantize_q2k_cuda,
quantize_q2k_metal
);
test_device!(
quantize_q3k,
quantize_q3k_cpu,
quantize_q3k_cuda,
quantize_q3k_metal
);
test_device!(
quantize_q4k,
quantize_q4k_cpu,
quantize_q4k_cuda,
quantize_q4k_metal
);
test_device!(
quantize_q5k,
quantize_q5k_cpu,
quantize_q5k_cuda,
quantize_q5k_metal
);
test_device!(
quantize_q6k,
quantize_q6k_cpu,
quantize_q6k_cuda,
quantize_q6k_metal
);
test_device!(
quantize_q8k,
quantize_q8k_cpu,
quantize_q8k_cuda,
quantize_q8k_metal
);
/// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
@ -739,6 +558,26 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
Ok(())
}
fn get_small_tensors(
m: usize,
k: usize,
n: usize,
device: &Device,
) -> Result<(Tensor, Tensor, Tensor)> {
let lhs = (0..m * k)
.map(|i| i as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..n * k)
.map(|i| i as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
let mm = lhs.matmul(&rhs.t()?)?;
Ok((lhs, rhs, mm))
}
#[test]
fn quantized_mm() -> Result<()> {
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
@ -772,112 +611,6 @@ fn get_random_tensors(
Ok((lhs, rhs, mm))
}
#[macro_export]
macro_rules! quantized_matmul {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> {
if device.is_cuda() {
// TODO Enable Cuda GGML sometime maybe.
return Ok(());
}
test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(())
}
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
};
}
quantized_matmul!(
quantized_matmul_q4_0_bis,
quantized_matmul_q4_0_cpu,
quantized_matmul_q4_0_cuda,
quantized_matmul_q4_0_metal,
GgmlDType::Q4_0
);
quantized_matmul!(
quantized_matmul_q4_1_bis,
quantized_matmul_q4_1_cpu,
quantized_matmul_q4_1_cuda,
quantized_matmul_q4_1_metal,
GgmlDType::Q4_1
);
quantized_matmul!(
quantized_matmul_q5_0_bis,
quantized_matmul_q5_0_cpu,
quantized_matmul_q5_0_cuda,
quantized_matmul_q5_0_metal,
GgmlDType::Q5_0
);
quantized_matmul!(
quantized_matmul_q5_1_bis,
quantized_matmul_q5_1_cpu,
quantized_matmul_q5_1_cuda,
quantized_matmul_q5_1_metal,
GgmlDType::Q5_1
);
quantized_matmul!(
quantized_matmul_q8_0_bis,
quantized_matmul_q8_0_cpu,
quantized_matmul_q8_0_cuda,
quantized_matmul_q8_0_metal,
GgmlDType::Q8_0
);
// Not implemented in Ggml
// quantized_matmul!(
// quantized_matmul_q8_1_bis,
// quantized_matmul_q8_1_cpu,
// quantized_matmul_q8_1_cuda,
// quantized_matmul_q8_1_metal,
// GgmlDType::Q8_1
// );
// TODO This is bugged (also bugged in GGML
quantized_matmul!(
quantized_matmul_q2k_bis,
quantized_matmul_q2k_cpu,
quantized_matmul_q2k_cuda,
quantized_matmul_q2k_metal,
GgmlDType::Q2K
);
quantized_matmul!(
quantized_matmul_q3k_bis,
quantized_matmul_q3k_cpu,
quantized_matmul_q3k_cuda,
quantized_matmul_q3k_metal,
GgmlDType::Q3K
);
quantized_matmul!(
quantized_matmul_q4k_bis,
quantized_matmul_q4k_cpu,
quantized_matmul_q4k_cuda,
quantized_matmul_q4k_metal,
GgmlDType::Q4K
);
quantized_matmul!(
quantized_matmul_q5k_bis,
quantized_matmul_q5k_cpu,
quantized_matmul_q5k_cuda,
quantized_matmul_q5k_metal,
GgmlDType::Q5K
);
quantized_matmul!(
quantized_matmul_q6k_bis,
quantized_matmul_q6k_cpu,
quantized_matmul_q6k_cuda,
quantized_matmul_q6k_metal,
GgmlDType::Q6K
);
// Not implemented on metal
// quantized_matmul!(
// quantized_matmul_q8k_bis,
// quantized_matmul_q8k_cpu,
// quantized_matmul_q8k_cuda,
// quantized_matmul_q8k_metal,
// GgmlDType::Q8K
// );
#[test]
fn quantized_matmul_q2k() -> Result<()> {
use k_quants::BlockQ2K;
@ -890,7 +623,7 @@ fn quantized_matmul_q2k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@ -910,20 +643,30 @@ fn quantized_matmul_q3k() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
// assert_eq!(mm.dims(), [m, n]);
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
.sum_all()?
.to_scalar()?;
let error = error / (m * n) as f32;
// assert_eq!(qmm.dims(), [m, n]);
// let dst = qmm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
ggml_matmul_error_test::<BlockQ3K>()?;
@ -936,20 +679,30 @@ fn quantized_matmul_q4k() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
// assert_eq!(mm.dims(), [m, n]);
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
.sum_all()?
.to_scalar()?;
let error = error / (m * n) as f32;
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
// assert_eq!(mm.dims(), [m, n]);
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
ggml_matmul_error_test::<BlockQ4K>()?;
@ -968,7 +721,7 @@ fn quantized_matmul_q5k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@ -995,7 +748,7 @@ fn quantized_matmul_q6k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@ -1020,7 +773,7 @@ fn quantized_matmul_q8k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;

View File

@ -120,13 +120,6 @@ fn unary_op(device: &Device) -> Result<()> {
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
[
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
@ -1252,23 +1245,11 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
}
#[test]
fn log_sum_exp() -> Result<()> {
fn logsumexp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.log_sum_exp(D::Minus1)?;
let output = input.logsumexp(D::Minus1)?;
// The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
Ok(())
}
#[test]
fn pow() -> Result<()> {
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}

Binary file not shown.

Binary file not shown.

View File

@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
csv = "1.3.0"
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
hf-hub = { workspace = true, features = ["tokio"] }
hf-hub = { workspace = true, features=["tokio"]}
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
@ -30,9 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"] }
tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -45,6 +43,7 @@ rusttype = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
@ -62,7 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
[[example]]
name = "llama_multiprocess"
@ -79,7 +77,3 @@ required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]
[[example]]
name = "whisper-microphone"
required-features = ["microphone"]

View File

@ -27,5 +27,11 @@ fn main() -> Result<()> {
bindings.write(kdir.rust_target).unwrap()
}
}
#[cfg(not(feature = "cuda"))]
{
for kdir in KERNEL_DIRS.iter() {
let _file = std::fs::File::create(kdir.rust_target)?;
}
}
Ok(())
}

View File

@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
let config = blip::Config::image_captioning_large();
let device = candle_examples::device(args.cpu)?;
let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");

View File

@ -1,237 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::chatglm::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/chatglm3-6b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-chatglm".to_string())
.get("chatglm-tokenizer.json")?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::glm3_6b();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,23 +0,0 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 84.09%
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
maillot : 0.74%
crash helmet : 0.54%
unicycle, monocycle : 0.44%
```

View File

@ -1,126 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convnext;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Atto,
Femto,
Pico,
Nano,
Tiny,
Small,
Base,
Large,
AttoV2,
FemtoV2,
PicoV2,
NanoV2,
TinyV2,
BaseV2,
LargeV2,
XLarge,
Huge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Atto => "convnext_atto.d2_in1k",
Self::Femto => "convnext_femto.d1_in1k",
Self::Pico => "convnext_pico.d1_in1k",
Self::Nano => "convnext_nano.d1h_in1k",
Self::Tiny => "convnext_tiny.fb_in1k",
Self::Small => "convnext_small.fb_in1k",
Self::Base => "convnext_base.fb_in1k",
Self::Large => "convnext_large.fb_in1k",
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
};
format!("timm/{name}")
}
fn config(&self) -> convnext::Config {
match self {
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
Self::Small => convnext::Config::small(),
Self::Base | Self::BaseV2 => convnext::Config::base(),
Self::Large | Self::LargeV2 => convnext::Config::large(),
Self::XLarge => convnext::Config::xlarge(),
Self::Huge => convnext::Config::huge(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::Tiny)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -0,0 +1 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -57,7 +57,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)]
#[arg(long, default_value_t = 100)]
sample_len: usize,
/// Disable the key-value cache.
@ -143,6 +143,7 @@ fn main() -> Result<()> {
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@ -156,7 +157,6 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
print!("{prompt}");
@ -190,16 +190,18 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(

View File

@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
.dequantize(&device)?;
.dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
.dequantize(&device)?;
.dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.into_iter()
.collect(),
candle::DType::F32,
&device,
&candle::Device::Cpu,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now();
for index in 0.. {
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",

View File

@ -2,9 +2,6 @@
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
Compared to the mamba example, this version can handle training but is much
slower.
## Running the example
```bash

View File

@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation
Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
Compared to the mamba-minimal example, this version is far more efficient but
would only work for inference.
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
```

View File

@ -1,299 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mamba::{Config, Model, State};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let input = Tensor::new(&[next_token], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -152,7 +152,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long)]
@ -244,14 +244,13 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), device)
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {

View File

@ -143,7 +143,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]

View File

@ -1,22 +0,0 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%
```

View File

@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S0,
S1,
S2,
S3,
S4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::S0 => "s0",
Self::S1 => "s1",
Self::S2 => "s2",
Self::S3 => "s3",
Self::S4 => "s4",
};
format!("timm/mobileone_{}.apple_in1k", name)
}
fn config(&self) -> mobileone::Config {
match self {
Self::S0 => mobileone::Config::s0(),
Self::S1 => mobileone::Config::s1(),
Self::S2 => mobileone::Config::s2(),
Self::S3 => mobileone::Config::s3(),
Self::S4 => mobileone::Config::s4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,39 +1,10 @@
## Using ONNX models in Candle
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle.
This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small sequeezenet variant.
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf).
You can run the examples with following commands:
You can run the example with the following command:
```bash
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
Use the `--which` flag to specify explicitly which network to use, i.e.
```bash
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.21s
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 3, 224, 224; f32]
unicycle, monocycle : 83.23%
ballplayer, baseball player : 3.68%
bearskin, busby, shako : 1.54%
military uniform : 0.78%
cowboy hat, ten-gallon hat : 0.76%
```
```bash
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.20s
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 224, 224, 3; f32]
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
mountain bike, all-terrain bike, off-roader : 0.60%
unicycle, monocycle : 0.17%
crash helmet : 0.02%
alp : 0.02%
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -8,7 +8,6 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor};
@ -19,7 +18,6 @@ use tokenizers::Tokenizer;
enum Model {
MixFormer(MixFormer),
Phi(Phi),
Quantized(QMixFormer),
}
@ -86,7 +84,6 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::MixFormer(m) => m.forward(&input)?,
Model::Phi(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
@ -120,7 +117,7 @@ impl TextGeneration {
}
}
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
#[value(name = "1")]
V1,
@ -128,8 +125,6 @@ enum WhichModel {
V1_5,
#[value(name = "2")]
V2,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
PhiHermes,
}
@ -174,7 +169,7 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "2")]
#[arg(long, default_value = "1.5")]
model: WhichModel,
#[arg(long)]
@ -235,7 +230,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V2 => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -250,9 +245,8 @@ fn main() -> Result<()> {
"main".to_string()
} else {
match args.model {
WhichModel::V1 => "refs/pr/8".to_string(),
WhichModel::V1_5 => "refs/pr/73".to_string(),
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string()
}
@ -264,9 +258,7 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
repo.get("tokenizer.json")?
}
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@ -279,14 +271,14 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
WhichModel::V2 => candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?,
@ -300,44 +292,28 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = || match args.model {
let config = match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::V2 => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let device = candle_examples::device(args.cpu)?;
let model = if args.quantized {
let config = config();
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&filenames[0],
&device,
)?;
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
Model::Quantized(model)
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: PhiConfig = serde_json::from_str(&config)?;
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V2Old => {
let config = config();
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
}
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
}
let model = match args.model {
WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
_ => MixFormer::new(&config, vb)?,
};
(Model::MixFormer(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
@ -417,10 +393,6 @@ fn mmlu<P: AsRef<std::path::Path>>(
m.clear_kv_cache();
m.forward(&input)?
}
Model::Phi(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
Model::Quantized(m) => {
m.clear_kv_cache();
m.forward(&input)?

View File

@ -132,8 +132,7 @@ impl T5ModelBuilder {
}
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let device = Device::Cpu;
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
@ -361,7 +361,6 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
@ -370,7 +369,7 @@ fn main() -> anyhow::Result<()> {
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@ -378,16 +377,15 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file, &device)
.map_err(|e| e.with_path(model_path))?;
let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@ -488,7 +486,7 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
@ -509,7 +507,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {

View File

@ -1,281 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum WhichModel {
#[value(name = "0.5b")]
W0_5b,
#[value(name = "1.8b")]
W1_8b,
#[value(name = "4b")]
W4b,
#[value(name = "7b")]
W7b,
#[value(name = "14b")]
W14b,
#[value(name = "72b")]
W72b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "0.5b")]
model: WhichModel,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
};
format!("Qwen/Qwen1.5-{size}")
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -411,7 +411,7 @@ impl DDPG<'_> {
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
.forward(&state.detach().unsqueeze(0)?)?
.forward(&state.detach()?.unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?

View File

@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
loop {
let action = {
let action_probs: Vec<f32> =
softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
.squeeze(0)?
.to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64
@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)?
.detach();
.detach()?;
let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
.unwrap()
})
.collect();
Tensor::stack(&actions_mask, 0)?.detach()
Tensor::stack(&actions_mask, 0)?.detach()?
};
let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
Tensor::stack(&states, 0)?.detach()
Tensor::stack(&states, 0)?.detach()?
};
let log_probs = actions_mask

View File

@ -236,15 +236,16 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b();
let model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
Model::Q(Q::new(&config, vb.pp("transformer"))?)
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
Model::M(M::new(&config, vb.pp("transformer"))?)
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,22 +0,0 @@
# candle-repvgg
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
This candle implementation uses a pre-trained RepVGG network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 61.70%
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
unicycle, monocycle : 4.88%
crash helmet : 0.15%
moped : 0.04%
```

View File

@ -1,111 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::repvgg;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
A0,
A1,
A2,
B0,
B1,
B2,
B3,
B1G4,
B2G4,
B3G4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::A0 => "a0",
Self::A1 => "a1",
Self::A2 => "a2",
Self::B0 => "b0",
Self::B1 => "b1",
Self::B2 => "b2",
Self::B3 => "b3",
Self::B1G4 => "b1g4",
Self::B2G4 => "b2g4",
Self::B3G4 => "b3g4",
};
format!("timm/repvgg_{}.rvgg_in1k", name)
}
fn config(&self) -> repvgg::Config {
match self {
Self::A0 => repvgg::Config::a0(),
Self::A1 => repvgg::Config::a1(),
Self::A2 => repvgg::Config::a2(),
Self::B0 => repvgg::Config::b0(),
Self::B1 => repvgg::Config::b1(),
Self::B2 => repvgg::Config::b2(),
Self::B3 => repvgg::Config::b3(),
Self::B1G4 => repvgg::Config::b1g4(),
Self::B2G4 => repvgg::Config::b2g4(),
Self::B3G4 => repvgg::Config::b3g4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::A0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,17 +0,0 @@
## candle-rwkv
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 version and can be used with Eagle 7B([blog
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ(2) = 2.
The smallest composite is ϕ(3) = 3.
The smallest perfect number is ϕ(5) = 5.
The smallest perfect square is ϕ(4) = 4.
The smallest perfect cube is ϕ(6) = 6.
```

View File

@ -1,265 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
let mut tokens = self.tokenizer.encode(prompt)?;
let mut generated_tokens = 0usize;
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[[t]], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
print!("{}", self.tokenizer.decode(&[t])?)
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?;
let input = Tensor::new(&[[next_token]], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Eagle7b,
World1b5,
World3b,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
Self::World3b => "RWKV/rwkv-5-world-3b",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Eagle7b => "refs/pr/1",
Self::World1b5 | Self::World3b => "refs/pr/2",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "world1b5")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-rwkv".to_string())
.get("rwkv_vocab_v20230424.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::new(tokenizer)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -8,13 +8,6 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated so you will have to request access on the Hub in
order to be able to use it.
Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.
## Running some example
```bash

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use clap::Parser;
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@ -122,16 +122,6 @@ impl TextGeneration {
}
}
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
V1Orig,
V1,
V1Zephyr,
V2,
V2Zephyr,
Code,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -162,18 +152,15 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 1000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "v2")]
which: Which,
#[arg(long)]
tokenizer_file: Option<String>,
@ -220,88 +207,40 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
Which::Code => "stabilityai/stable-code-3b".to_string(),
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => match args.which {
Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
repo.get("tokenizer.json")?
}
Which::V2 | Which::V2Zephyr => api
.model("lmz/candle-stablelm".to_string())
.get("tokenizer-gpt4.json")?,
},
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match (args.which, args.quantized) {
(Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
(Which::V2, true) => {
let gguf = api
.model("lmz/candle-stablelm".to_string())
.get("stablelm-2-1_6b-q4k.gguf")?;
vec![gguf]
}
(Which::V2Zephyr, true) => {
let gguf = api
.model("lmz/candle-stablelm".to_string())
.get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
vec![gguf]
}
(Which::V1Zephyr | Which::Code, true) => {
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
}
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
None => {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![repo.get("model.safetensors")?]
}
(Which::Code, false) => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.which {
Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
config.set_use_flash_attn(args.use_flash_attn);
config
}
};
let device = candle_examples::device(args.cpu)?;
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.5 KiB

View File

@ -10,36 +10,15 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::{trocr, vit};
use candle_transformers::models::trocr;
use tokenizers::Tokenizer;
mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
#[value(name = "base")]
BaseHandwritten,
#[value(name = "large")]
LargeHandwritten,
BasePrinted,
LargePrinted,
}
impl Which {
fn repo_and_branch_name(&self) -> (&str, &str) {
match self {
Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
}
}
}
#[derive(Debug, Clone, serde::Deserialize)]
struct Config {
encoder: vit::Config,
decoder: trocr::TrOCRConfig,
Base,
Large,
}
#[derive(Parser, Debug)]
@ -55,64 +34,63 @@ struct Args {
#[arg(long)]
cpu: bool,
/// The image file to be processed.
/// Text to be translated
#[arg(long)]
image: String,
/// Tokenization config.
#[arg(long)]
tokenizer: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let api = hf_hub::api::sync::Api::new()?;
let mut tokenizer_dec = {
let tokenizer_file = match args.tokenizer {
None => api
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?,
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
};
let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
TokenOutputStream::new(tokenizer)
let tokenizer_dec = {
let tokenizer = Api::new()?
.model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?;
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => {
let (repo, branch) = args.which.repo_and_branch_name();
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
branch.to_string(),
))
.get("model.safetensors")?
}
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
))
.get("model.safetensors")?,
Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
};
println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
};
let (encoder_config, decoder_config) = {
let (repo, branch) = args.which.repo_and_branch_name();
let config_filename = api
.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
branch.to_string(),
))
.get("config.json")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
(config.encoder, config.decoder)
let encoder_config = match args.which {
Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
Which::Large => {
candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
}
};
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let processor_config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&processor_config);
let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;

View File

@ -5,27 +5,12 @@ transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.
Supported models include:
- `--which base`: small handwritten OCR model.
- `--which large`: large handwritten OCR model.
- `--which base-printed`: small printed OCR model.
- `--which large-printed`: large printed OCR model.
## Running an example
```bash
cargo run --example trocr --release -- --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which large --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which base-printed --image candle-examples/examples/trocr/assets/noto.png
cargo run --example trocr --release -- --which large-printed --image candle-examples/examples/trocr/assets/noto.png
cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```
### Outputs
```
industry , Mr. Brown commented icily . " Let us have a
industry , " Mr. Brown commented icily . " Let us have a
THE QUICK BROWN FOR JUMPS OVER THE LAY DOG
THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG
<s> industry , Mr. Brown commented icily . " Let us have a</s>
```

View File

@ -1,673 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;
mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};
pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}
// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}
pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
text: String,
avg_logprob: f64,
no_speech_prob: f64,
temperature: f64,
compression_ratio: f64,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
start: f64,
duration: f64,
dr: DecodingResult,
}
struct Decoder {
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
verbose: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
translate_token: u32,
eot_token: u32,
no_speech_token: u32,
no_timestamps_token: u32,
language_token: Option<u32>,
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
verbose: bool,
) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if model.config().suppress_tokens.contains(&i)
|| timestamps && i == no_timestamps_token
{
f32::NEG_INFINITY
} else {
0f32
}
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer,
task,
timestamps,
verbose,
suppress_tokens,
sot_token,
transcribe_token,
translate_token,
eot_token,
no_speech_token,
language_token,
no_timestamps_token,
})
}
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
if let Some(language_token) = self.language_token {
tokens.push(language_token);
}
match self.task {
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
}
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
}
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {
tokens,
text,
avg_logprob,
no_speech_prob,
temperature: t,
compression_ratio: f64::NAN,
})
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t);
if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
Err(err) => {
println!("Error running at {t}: {err}")
}
}
}
unreachable!()
}
fn run(&mut self, mel: &Tensor, times: Option<(f64, f64)>) -> Result<Vec<Segment>> {
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
let start = std::time::Instant::now();
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
let segment = Segment {
start: time_offset,
duration: segment_duration,
dr,
};
if self.timestamps {
println!(
"{:.1}s -- {:.1}s",
segment.start,
segment.start + segment.duration,
);
let mut tokens_to_decode = vec![];
let mut prev_timestamp_s = 0f32;
for &token in segment.dr.tokens.iter() {
if token == self.sot_token || token == self.eot_token {
continue;
}
// The no_timestamp_token is the last before the timestamp ones.
if token > self.no_timestamps_token {
let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
if !tokens_to_decode.is_empty() {
let text = self
.tokenizer
.decode(&tokens_to_decode, true)
.map_err(E::msg)?;
println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
tokens_to_decode.clear()
}
prev_timestamp_s = timestamp_s;
} else {
tokens_to_decode.push(token)
}
}
if !tokens_to_decode.is_empty() {
let text = self
.tokenizer
.decode(&tokens_to_decode, true)
.map_err(E::msg)?;
if !text.is_empty() {
println!(" {:.1}s-...: {}", prev_timestamp_s, text);
}
tokens_to_decode.clear()
}
} else {
match times {
Some((start, end)) => {
println!("{:.1}s -- {:.1}s: {}", start, end, segment.dr.text)
}
None => {
println!(
"{:.1}s -- {:.1}s: {}",
segment.start,
segment.start + segment.duration,
segment.dr.text,
)
}
}
}
if self.verbose {
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
}
segments.push(segment)
}
Ok(segments)
}
fn set_language_token(&mut self, language_token: Option<u32>) {
self.language_token = language_token;
}
#[allow(dead_code)]
fn reset_kv_cache(&mut self) {
match &mut self.model {
Model::Normal(m) => m.reset_kv_cache(),
Model::Quantized(m) => m.reset_kv_cache(),
}
}
fn model(&mut self) -> &mut Model {
&mut self.model
}
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Task {
Transcribe,
Translate,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
TinyEn,
Base,
#[value(name = "base.en")]
BaseEn,
Small,
#[value(name = "small.en")]
SmallEn,
Medium,
#[value(name = "medium.en")]
MediumEn,
Large,
LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny
| Self::Base
| Self::Small
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
}
}
}
fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Base => ("openai/whisper-base", "refs/pr/22"),
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
Self::Small => ("openai/whisper-small", "main"),
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::Medium => ("openai/whisper-medium", "main"),
Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
model_id: Option<String>,
/// The model to use, check out available models:
/// https://huggingface.co/models?search=whisper
#[arg(long)]
revision: Option<String>,
/// The model to be used, can be tiny, small, medium.
#[arg(long, default_value = "tiny.en")]
model: WhichModel,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
quantized: bool,
/// Language.
#[arg(long)]
language: Option<String>,
/// Task, when no task is specified, the input tokens contain only the sot token which can
/// improve things when in no-timestamp mode.
#[arg(long)]
task: Option<Task>,
/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
timestamps: bool,
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
}
pub fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let (default_model, default_revision) = if args.quantized {
("lmz/candle-whisper", "main")
} else {
args.model.model_and_revision()
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let (config, tokenizer, model) = if args.quantized {
let ext = match args.model {
WhichModel::TinyEn => "tiny-en",
WhichModel::Tiny => "tiny",
_ => unimplemented!("no quantized support for {:?}", args.model),
};
(
repo.get(&format!("config-{ext}.json"))?,
repo.get(&format!("tokenizer-{ext}.json"))?,
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
let config = repo.get("config.json")?;
let tokenizer = repo.get("tokenizer.json")?;
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
};
(config, tokenizer, model)
};
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&weights_filename,
&device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config.clone())?)
} else {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let language_token = None;
let mut dc = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
language_token,
args.task,
args.timestamps,
args.verbose,
)?;
let mel_bytes = match config.num_mel_bins {
80 => include_bytes!("../whisper/melfilters.bytes").as_slice(),
128 => include_bytes!("../whisper/melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
}
.expect("failed to find input device");
let _config = _device
.default_input_config()
.expect("Failed to get default input config");
let channel_count = _config.channels() as usize;
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();
std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
let max_len = data.len() * 16;
let data_len = data.len();
let len = audio_ring_buffer.lock().unwrap().len();
if len > max_len {
let mut data = audio_ring_buffer.lock().unwrap();
let new_data = data[data_len..].to_vec();
*data = new_data;
}
});
// loop to process the audio data forever (until the user stops the program)
println!("Transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() {
std::thread::sleep(std::time::Duration::from_millis(1000));
let data = audio_ring_buffer_2.lock().unwrap().clone();
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
// on the first iteration, we detect the language and set the language token.
if i == 0 {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
},
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
println!("language_token: {:?}", language_token);
dc.set_language_token(language_token);
}
dc.run(
&mel,
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
}
Ok(())
}
fn record_audio(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
milliseconds: u64,
) -> Result<Vec<i16>> {
let writer = Arc::new(Mutex::new(Vec::new()));
let writer_2 = writer.clone();
let stream = device.build_input_stream(
&config.config(),
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let processed = data
.iter()
.map(|v| (v * 32768.0) as i16)
.collect::<Vec<i16>>();
writer_2.lock().unwrap().extend_from_slice(&processed);
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
drop(stream);
let data = writer.lock().unwrap().clone();
let step = 3;
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
Ok(data)
}

View File

@ -1,137 +0,0 @@
use crate::{token_id, Model};
use candle::{IndexOp, Result, Tensor, D};
use candle_transformers::models::whisper::{self as m};
use tokenizers::Tokenizer;
const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
Ok(language)
}

View File

@ -18,8 +18,6 @@ use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
mod multilingual;
mod pcm_decode;
use candle_transformers::models::whisper::{self as m, audio, Config};
pub enum Model {
@ -537,10 +535,17 @@ fn main() -> Result<()> {
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
if sample_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
@ -552,10 +557,8 @@ fn main() -> Result<()> {
println!("loaded mel: {:?}", mel.dims());
let mut model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&weights_filename,
&device,
)?;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb =

View File

@ -1,74 +0,0 @@
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
// Open the media source.
let src = std::fs::File::open(path)?;
// Create the media source stream.
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
// Create a probe hint using the file's extension. [Optional]
let hint = symphonia::core::probe::Hint::new();
// Use the default options for metadata and format readers.
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
// Probe the media source.
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
// Get the instantiated format reader.
let mut format = probed.format;
// Find the first audio track with a known (decodeable) codec.
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.expect("no supported audio tracks");
// Use the default options for the decoder.
let dec_opts: DecoderOptions = Default::default();
// Create a decoder for the track.
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
// The decode loop.
while let Ok(packet) = format.next_packet() {
// Consume any new metadata that has been read since the last packet.
while !format.metadata().is_latest() {
format.metadata().pop();
}
// If the packet does not belong to the selected track, skip over it.
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}

View File

@ -104,7 +104,6 @@ impl TextGeneration {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
let t = t.replace("<|im_end|>", "\n");
print!("{t}");
std::io::stdout().flush()?;
}

View File

@ -216,7 +216,7 @@ fn detect(
xs: &Tensor,
image_height: usize,
classes: usize,
anchors: &[(usize, usize)],
anchors: &Vec<(usize, usize)>,
) -> Result<Tensor> {
let (bsize, _channels, height, _width) = xs.dims4()?;
let stride = image_height / height;

View File

@ -40,7 +40,7 @@ impl TokenOutputStream {
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.0"
version = "0.3.3"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.0"
version = "0.3.3"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -71,6 +71,7 @@ __device__ void im2col1d(
}
const size_t *src_dims = info;
const size_t *src_s = info + 3;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
@ -119,6 +120,7 @@ __device__ void im2col(
}
const size_t *src_dims = info;
const size_t *src_s = info + 4;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
@ -223,60 +225,6 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d);
}
// Naive implementation of conv_transpose1d.
template <typename T, typename A>
__device__ void conv_transpose1d(
const size_t src_numel,
const size_t l_out,
const size_t stride,
const size_t padding,
const size_t out_padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, c_in, l_in)
// k: (c_in, c_out, l_k)
const size_t *src_dims = info;
const size_t *src_s = info + 3;
const size_t *k_dims = info + 6;
const size_t *k_s = info + 9;
const size_t l_k = k_dims[2];
const size_t c_out = k_dims[1];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
if (dst_i >= src_dims[0] * c_out * l_out) {
return;
}
// TODO
const size_t b_idx = dst_i / (l_out * c_out);
const size_t dst_c_idx = (dst_i / l_out) % c_out;
// NCL layout.
const size_t out_x = dst_i % l_out;
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (int k_x = 0; k_x < (int)l_k; ++k_x) {
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
int inp_x = inp_x_stride / stride;
if (inp_x >= l_in) continue;
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2];
const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2];
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
}
}
dst[dst_i] = static_cast<T>(d);
}
// Naive implementation of conv_transpose2d.
template <typename T, typename A>
__device__ void conv_transpose2d(
@ -559,22 +507,6 @@ extern "C" __global__ void FN_NAME( \
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \
#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
const size_t l_out, \
const size_t stride, \
const size_t padding, \
const size_t out_padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv_transpose1d<TYPENAME, TYPEACC>(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@ -636,7 +568,6 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16)
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
@ -648,7 +579,6 @@ IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
CONVT1D_OP(__half, float, conv_transpose1d_f16)
CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
@ -667,11 +597,6 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
CONVT1D_OP(float, float, conv_transpose1d_f32)
CONVT1D_OP(double, double, conv_transpose1d_f64)
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(double, double, conv_transpose2d_f64)
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)

View File

@ -55,11 +55,6 @@ __device__ __forceinline__ T relu_fwd(T x) {
return maxg(x, zero);
}
template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
return x / (static_cast<T>(1) + expg(-x));
}
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@ -108,7 +103,6 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
#endif
@ -133,7 +127,6 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
#endif
@ -180,7 +173,5 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))

View File

@ -1,2 +0,0 @@
src/compiled/

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.0"
version = "0.3.3"
edition = "2021"
description = "Metal kernels for Candle"
@ -9,17 +9,12 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
metal = { version = "0.27.0", features = ["mps"]}
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"

View File

@ -1,45 +0,0 @@
use std::path::Path;
use std::process::Command;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
for file in files {
let file = file?;
let path = file.path();
if let Some(extension) = path.extension() {
if extension == "metal" {
build_kernel(&path)?;
}
println!("cargo:warning=output {:?}", path.file_stem());
}
}
Ok(())
}
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
let stem = path
.file_stem()
.expect("expect real filename")
.to_str()
.expect("expect real stem");
Command::new("xcrun")
.args([
"metal",
"-c",
path.as_os_str().to_str().expect("Expect a real filename"),
"-I",
"src/",
"-o",
&format!("src/compiled/{stem}.air"),
])
.output()?;
Command::new("xcrun")
.args([
"metallib",
&format!("src/compiled/{stem}.air"),
"-o",
&format!("src/compiled/{stem}.metallib"),
])
.output()?;
Ok(())
}

View File

@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
using namespace metal;
#define AFFINE(FN_NAME, T) \
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const T *input, \
device T *output, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = T(fma(float(input[id]), mul, add)); \
output[id] = TYPENAME(float(input[id]) * mul + add); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
constant size_t *strides, \
constant float &mul, \
constant float &add, \
device const T *input, \
device T *output, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
}
#define POWF(FN_NAME, TYPENAME) \
@ -117,7 +117,7 @@ ELU(elu_f32, float)
ELU(elu_f16, half)
#if defined(__HAVE_BFLOAT__)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);

View File

@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y)
#endif
#if defined(__HAVE_BFLOAT__)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)

View File

@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -42,38 +42,10 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
@ -86,14 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
#if __METAL_VERSION__ >= 310
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif

View File

@ -173,10 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#if __METAL_VERSION__ >= 310
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)

View File

@ -1,22 +1,20 @@
use metal::{
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@ -63,12 +61,8 @@ macro_rules! primitive {
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
impl<T> EncoderParam for &[T] {
@ -123,8 +117,6 @@ pub enum Source {
Reduce,
Mfa,
Conv,
Random,
Quantized,
}
macro_rules! ops{
@ -182,8 +174,8 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
recip
);
}
pub mod binary {
@ -223,19 +215,21 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
pub struct Kernels {
libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>,
fence: metal::Fence,
}
impl Kernels {
pub fn new() -> Self {
pub fn new(fence: metal::Fence) -> Self {
let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new());
Self {
libraries,
pipelines,
fence,
}
}
fn get_library_source(&self, source: Source) -> &'static [u8] {
fn get_library_source(&self, source: Source) -> &'static str {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
@ -245,9 +239,7 @@ impl Kernels {
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => MFA,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -262,12 +254,22 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let source_data = self.get_library_source(source);
let lib = device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?;
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
};
libraries.insert(source, lib.clone());
Ok(lib)
}
@ -343,6 +345,7 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
@ -351,6 +354,7 @@ pub fn call_unary_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -372,6 +376,7 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -393,6 +398,7 @@ pub fn call_unary_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -411,6 +417,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
@ -421,6 +428,7 @@ pub fn call_binary_contiguous(
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -445,6 +453,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -469,6 +478,7 @@ pub fn call_binary_strided(
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -487,6 +497,7 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
@ -495,6 +506,7 @@ pub fn call_cast_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -514,6 +526,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@ -535,6 +548,7 @@ pub fn call_cast_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -554,6 +568,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -582,6 +597,7 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -603,6 +619,7 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -638,6 +655,7 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -656,6 +674,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -686,6 +705,7 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -705,6 +725,7 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
@ -713,6 +734,7 @@ pub fn call_affine(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -735,6 +757,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -755,6 +778,7 @@ pub fn call_affine_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -773,6 +797,7 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
@ -781,6 +806,7 @@ pub fn call_powf(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -802,6 +828,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -821,6 +848,7 @@ pub fn call_powf_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -839,6 +867,7 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
@ -847,6 +876,7 @@ pub fn call_elu(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -868,6 +898,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -887,6 +918,7 @@ pub fn call_elu_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -908,6 +940,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@ -936,6 +969,7 @@ pub fn call_where_cond_strided(
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -962,6 +996,7 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -984,6 +1019,7 @@ pub fn call_index_select(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1012,6 +1048,7 @@ pub fn call_gather(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1034,6 +1071,7 @@ pub fn call_gather(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1062,6 +1100,7 @@ pub fn call_scatter_add(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1084,6 +1123,7 @@ pub fn call_scatter_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1113,6 +1153,7 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -1136,6 +1177,7 @@ pub fn call_index_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -1339,6 +1381,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1354,12 +1397,13 @@ pub fn call_gemm(
// TODO byte_stride_d
let byte_stride_d = 0;
let buffer: Vec<u64> = vec![
byte_stride_a as _,
byte_stride_b as _,
byte_stride_c as _,
byte_stride_d as _,
];
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
for i in 0..b {
buffer.push((i * byte_stride_a) as u64);
buffer.push((i * byte_stride_b) as u64);
buffer.push((i * byte_stride_c) as u64);
buffer.push((i * byte_stride_d) as u64);
}
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
@ -1377,10 +1421,12 @@ pub fn call_gemm(
height: 1,
depth: 1,
};
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1405,6 +1451,7 @@ pub fn call_im2col1d_strided(
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1424,6 +1471,7 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1453,6 +1501,7 @@ pub fn call_im2col_strided(
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1474,6 +1523,7 @@ pub fn call_im2col_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@ -1499,6 +1549,7 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@ -1516,243 +1567,7 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
min: f32,
max: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
return Err(MetalKernelError::LoadLibraryError(
"min must be less than max".to_string(),
));
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, min, max, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
mean: f32,
stddev: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, mean, stddev, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
F16,
F32,
}
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
dtype: GgmlDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let nb10 = 0i64;
let nb11 = 0i64;
let nb12 = 0i64;
let ne0 = n as i64;
let ne1 = m as i64;
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;
let (nth0, nth1, align) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
// Original implem uses rows
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
};
let thread_groups_count = MTLSize {
width: divide(ne01 as usize, align),
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
let name = match dtype {
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
rhs,
(lhs, lhs_offset),
output,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())

File diff suppressed because it is too large Load Diff

View File

@ -1,206 +0,0 @@
#include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>
using namespace metal;
// Constants
// 2^32 and 1/2^32. Useful for converting between float and uint.
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
// 2 * pi
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17};
// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = {
0x9E3779B97F4A7C15,
0xF39CC0605CEDC834,
0x1082276BF3A27251,
0xF86C6A11D0C18E95,
0x2767F0B153D27B7F,
0x0347045B5BF1827F,
0x01886F0928403002,
0xC1D64BA40F335E36,
0xF06AD7AE9717877E,
0x85839D6EFFBD7DC6,
0x64D325D1C5371682,
0xCADD0CCCFDFFBBE1,
0x626E33B8D04B4331,
0xBBF73C790D94F79D,
0x471C4AB3ED3D82A5,
0xFEC507705E4AE6E5,
};
// Combined Tausworthe and LCG Random Number Generator.
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
struct HybridTaus {
float state;
HybridTaus() thread = default;
HybridTaus() threadgroup = default;
HybridTaus() device = default;
HybridTaus() constant = default;
// Generate seeds for each thread.
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
}
// Tausworthe generator.
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
uint b = (((z << s.x) ^ z) >> s.y);
return (((z & M) << s.z) ^ b);
}
// LCG generator.
METAL_FUNC static uint lcg(const uint z) {
return (1664525 * z + 1013904223UL);
}
// Initialize the RNG state.
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
uint4 seed = seed_per_thread(seeds);
// Seed #1
uint z1 = taus(seed.x, S1, 4294967294UL);
uint z2 = taus(seed.y, S2, 4294967288UL);
uint z3 = taus(seed.z, S3, 4294967280UL);
uint z4 = lcg(seed.x);
// Seed #2
uint r1 = (z1^z2^z3^z4^seed.y);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
// Seed #3
r1 = (z1^z2^z3^z4^seed.z);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
// Seed #4
r1 = (z1^z2^z3^z4^seed.w);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
HybridTaus rng;
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
return rng;
}
METAL_FUNC float rand() {
uint seed = this->state * UNIF01_NORM32;
uint z1 = taus(seed, S1, 429496729UL);
uint z2 = taus(seed, S2, 4294967288UL);
uint z3 = taus(seed, S3, 429496280UL);
uint z4 = lcg(seed);
thread float result = this->state;
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
return result;
}
};
template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size,
constant float &min,
constant float &max,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
float diff = abs(min - max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
}
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}
// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/BoxMuller_transform
template<typename T> METAL_FUNC void normal(
constant size_t &size,
constant float &mean,
constant float &stddev,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
float u1 = rng.rand();
float u2 = rng.rand();
float cosval;
float sinval = sincos(TWO_PI * u2, cosval);
float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
}
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(z1);
}
#define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \
constant size_t &size, \
constant float &min, \
constant float &max, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
rand_uniform<T>(size, min, max, seed, out, tid); \
} \
#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
constant size_t &size, \
constant float &mean, \
constant float &stddev, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
normal<T>(size, mean, stddev, seed, out, tid); \
} \
#define RANDOM_OPS(NAME, T) \
UNIFORM_OP(NAME, T) \
NORMAL_OP(NAME, T) \
RANDOM_OPS(f32, float)
RANDOM_OPS(f16, half)
#if __METAL_VERSION__ >= 310
RANDOM_OPS(bf16, bfloat)
#endif

View File

@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
#if defined(__HAVE_BFLOAT__)
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)

View File

@ -17,45 +17,29 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
template<typename T, typename ID>
METAL_FUNC void where_cond(
constant size_t &numel,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t *strides_t,
constant size_t *strides_f,
device const ID *ids,
device const T *t,
device const T *f,
device T *out,
uint i [[ thread_position_in_grid ]]
) {
if (i >= numel){
return;
}
uint strided_i = get_strided_index(i, num_dims, dims, strides);
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
}
#define WHERE_OP(T, ID, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID *ids, \
device const T *t, \
device const T *f, \
device T *out, \
uint i [[ thread_position_in_grid ]] \
) { \
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
} \
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &numel, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t *strides_t, \
constant size_t *strides_f, \
device const ID_TYPENAME *ids, \
device const TYPENAME *t, \
device const TYPENAME *f, \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
if (i >= numel){ \
return; \
} \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
@ -70,14 +54,10 @@ kernel void FN_NAME(
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(half, uint8_t, where_u8_f16)
// WHERE_OP(double, uint8_t, where_u8_f64)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
#endif
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
#endif

View File

@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
use metal::{Buffer, Device, MTLResourceOptions};
use metal::{Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void;
let ptr = data.as_ptr() as *const core::ffi::c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@ -37,7 +37,8 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -59,7 +60,8 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -94,7 +96,8 @@ fn run_strided<T: Clone>(
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_unary_strided(
&device,
command_buffer,
@ -231,25 +234,6 @@ fn gelu_f32() {
assert_eq!(approx(results, 3), expected);
}
#[test]
fn silu_f16() {
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
let results = run(&v, unary::contiguous::silu::HALF);
assert_eq!(approx_f16(results, 2), expected);
}
#[test]
fn silu_f32() {
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
let results = run(&v, unary::contiguous::silu::FLOAT);
assert_eq!(approx(results, 3), expected);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
@ -264,37 +248,10 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
#[test]
fn binary_ops_bf16() {
let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
.into_iter()
.map(bf16::from_f32)
.collect();
macro_rules! binary_op {
($opname:ident, $opexpr:expr) => {{
let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
let expected: Vec<bf16> = lhs
.iter()
.zip(rhs.iter())
.map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
.collect();
assert_eq!(results, expected);
}};
}
binary_op!(add, |x, y| x + y);
binary_op!(sub, |x, y| x - y);
binary_op!(mul, |x, y| x * y);
binary_op!(div, |x, y| x / y);
binary_op!(min, |x: bf16, y| x.min(y));
binary_op!(max, |x: bf16, y| x.max(y));
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -339,92 +296,10 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]);
}
#[test]
fn it_cast_bf16_u32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_f32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_u8_bf16() {
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
let expected: Vec<bf16> = input
.iter()
.map(|v| bf16::from_f32(*v as f32))
.collect::<Vec<_>>();
assert_eq!(output, expected);
}
#[test]
fn it_cast_u32_bf16() {
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f32_bf16() {
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_u8() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_f16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f16_bf16() {
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -459,7 +334,8 @@ fn run_affine_strided<T: Clone>(
add: f64,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -520,14 +396,14 @@ fn index_select() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@ -543,46 +419,20 @@ fn index_select_f16() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_is_u32_bf16() {
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
assert_eq!(
approx_bf16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_is_u8_bf16() {
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
let shape = [5, 2];
let ids = [0u8, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
assert_eq!(
approx_bf16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@ -594,7 +444,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize],
ids: &[I],
dim: usize,
name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
@ -608,7 +457,14 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let kernels = Kernels::new();
let name = match core::mem::size_of::<T>() {
4 => "is_u32_f32",
2 => "is_u32_f16",
_ => unimplemented!(),
};
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(
&device,
&command_buffer,
@ -643,7 +499,8 @@ fn cos_f16() {
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -673,7 +530,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -732,6 +590,7 @@ fn softmax() {
}
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4);
println!("{results:?}");
assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n
@ -791,7 +650,8 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -867,7 +727,8 @@ fn run_gemm<T: Clone>(
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -945,124 +806,3 @@ fn gemm() {
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
let seed = device.new_buffer_with_data(
&seed as *const u32 as *const core::ffi::c_void,
std::mem::size_of::<u32>() as NSUInteger,
options,
);
if name.starts_with("rand_uniform") {
call_random_uniform(
&device,
command_buffer,
&kernels,
name,
a,
b,
length,
&seed,
&output,
)
.unwrap();
} else {
call_random_normal(
&device,
command_buffer,
&kernels,
name,
a,
b,
length,
&seed,
&output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let count = data.len();
assert!(count > 0);
sum / count as f32
}
fn calc_stddev(data: &[f32]) -> f32 {
let mean = calc_mean(data);
let count = data.len();
assert!(count > 0);
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
diff * diff
})
.sum::<f32>()
/ count as f32;
variance.sqrt()
}
let shape = vec![1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
let min = -30.0;
let max = 30.0;
let mean = 100.0;
let stddev = 50.0;
macro_rules! validate_random {
($type:ty) => {
let results: Vec<f32> = run_random::<$type>(
concat!("rand_uniform_", stringify!($type)),
seed,
length,
min,
max,
)
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| {
assert!(*v >= min && *v <= max);
});
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(
concat!("rand_normal_", stringify!($type)),
seed,
length,
mean,
stddev,
)
.into_iter()
.map(f32::from)
.collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
}
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
}

View File

@ -58,15 +58,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
template <typename T> METAL_FUNC T relu(T in){
if (in < 0) {
return 0;
}
return in;
}
template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -111,7 +102,6 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(silu)
UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
@ -120,7 +110,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -130,7 +120,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif
#if defined(__HAVE_BFLOAT__)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
@ -139,8 +129,6 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(silu)
BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
@ -148,7 +136,6 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif

View File

@ -222,10 +222,7 @@ impl Benchmark for QMatMul {
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle::quantized::QTensor::new(
candle::quantized::QStorage::Cpu(Box::new(zeros)),
(4096, 11008),
)?;
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))

View File

@ -6,7 +6,6 @@ use serde::Deserialize;
pub enum Activation {
#[default]
Gelu,
#[serde(alias = "gelu_new")]
NewGelu,
Relu,
Relu2,
@ -30,7 +29,7 @@ impl super::Module for Activation {
Self::Relu => xs.relu(),
Self::Relu2 => xs.relu()?.sqr(),
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => xs.silu(),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
Self::Swiglu => crate::ops::swiglu(xs),

View File

@ -262,19 +262,9 @@ impl BatchNorm {
let target_shape = target_shape.as_slice();
let x = x
.broadcast_sub(
&self
.running_mean
.as_detached_tensor()
.reshape(target_shape)?,
)?
.broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
.broadcast_div(
&(self
.running_var
.as_detached_tensor()
.reshape(target_shape)?
+ self.eps)?
.sqrt()?,
&(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
)?;
match &self.weight_and_bias {

View File

@ -302,22 +302,6 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: crate::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
)?;
Ok(Conv1d::new(ws, None, cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,

Some files were not shown because too many files have changed in this diff Show More