Compare commits

..

4 Commits

SHA1 Message Date
c2261d0222 Merge. 2024-01-07 20:27:33 +01:00
06d186355b Change more consistently the test. 2024-01-06 15:20:55 +01:00
2bbd544832 Non random for better quantization quality 2024-01-06 15:16:01 +01:00
504d0b9ac7 Potential bug on q4k. 2024-01-05 14:15:47 +01:00
78 changed files with 1029 additions and 10295 deletions

View File

@ -1,7 +0,0 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5

View File

@ -41,8 +41,8 @@ candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers" } candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false } criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] } cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0" hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] } image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -50,20 +50,20 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" } libc = { version = "0.2.147" }
log = "0.4" log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0" num_cpus = "1.15.0"
num-traits = "0.2.15" num-traits = "0.2.15"
parquet = { version = "50.0.0" } parquet = { version = "45.0.0" }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
rayon = "1.7.0" rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false } rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1" safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2" serde_plain = "1.0.2"
serde_json = "1.0.99" serde_json = "1.0.99"
thiserror = "1" thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false } tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37" tracing = "0.1.37"
tracing-chrome = "0.7.1" tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7" tracing-subscriber = "0.3.7"

View File

@ -66,7 +66,7 @@ We also provide a some command line based examples using state of the art models
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. pre-trained on 1T tokens of English and code datasets.
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal - [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
implementation of the Mamba state space model. implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28. better performance than all publicly available 13b models as of 2023-09-28.
@ -109,9 +109,6 @@ We also provide a some command line based examples using state of the art models
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation). evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image. generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
@ -207,7 +204,7 @@ If you have an addition to this list, please submit a pull request.
- Image to text. - Image to text.
- BLIP. - BLIP.
- Computer Vision Models. - Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8. - yolo-v3, yolo-v8.
- Segment-Anything Model (SAM). - Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files. - File formats: load models from safetensors, npz, ggml, or PyTorch files.

View File

@ -46,5 +46,6 @@ accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"] metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]] [[bench]]
name = "bench_main" name = "matmul"
harness = false harness = false

View File

@ -1,11 +0,0 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(
//benchmarks::affine::benches,
//benchmarks::matmul::benches,
//benchmarks::random::benches,
benchmarks::reduce::benches,
//benchmarks::where_cond::benches
);

View File

@ -1,43 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.affine(12.34, 56.78).unwrap();
}
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
}
}
criterion_group!(benches, criterion_benchmark);
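
For reference, the affine op benchmarked above is an elementwise scale-and-shift: a.affine(mul, add) computes a * mul + add. A tiny sketch with illustrative values (not taken from the diff):

use candle_core::{Device, Result, Tensor};

fn affine_example() -> Result<()> {
    let a = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu)?;
    // affine(mul, add) computes a * mul + add elementwise.
    let b = a.affine(2.0, 1.0)?;
    assert_eq!(b.to_vec1::<f32>()?, vec![3.0f32, 5.0, 7.0]);
    Ok(())
}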

View File

@ -1,67 +0,0 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod reduce;
pub(crate) mod where_cond;
use candle_core::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
fn sync(&self) -> Result<()> {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
}
}
}
fn bench_name<S: Into<String>>(&self, name: S) -> String {
match self {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
};
format!("{}_{}", cpu_type, name.into())
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
}
}
}
struct BenchDeviceHandler {
devices: Vec<Device>,
}
impl BenchDeviceHandler {
pub fn new() -> Result<Self> {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);
Ok(Self { devices })
}
}

View File

@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,239 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Storage, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::ops::Deref;
use std::time::Instant;
fn run_sum(a: &Tensor) {
a.sum(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
a.argmin(2).unwrap();
}
// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
use candle_core::{backend::BackendStorage, DType};
let (storage, layout) = a.storage_and_layout();
let device = a.device();
if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match a.dtype() {
DType::F32 => "softmax_f32",
DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bf16",
dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
};
let n = layout.stride().len();
if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
}
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
)
.unwrap();
}
Ok(())
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
let (lo, up) = (-1000.0f32, 1000.0f32);
for device in handler.devices {
run_softmax(c, &device, (lo, up));
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_arg_reduce(c, &device, (lo, up), false);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_reduce(c, &device, (lo, up), true);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
run_arg_reduce(c, &device, (lo, up), true);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
}
}
fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
if !device.is_metal() {
return;
}
let b = 1;
let m = 1024;
let k = 1024;
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => "softmax_f32",
DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bf16",
_ => "softmax",
};
softmax(&a).unwrap();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
softmax(black_box(&a)).unwrap();
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => {
if strided {
"reduce_f32_strided"
} else {
"reduce_f32"
}
}
DType::F16 => {
if strided {
"reduce_f16_strided"
} else {
"reduce_f16"
}
}
DType::BF16 => {
if strided {
"reduce_bf16_strided"
} else {
"reduce_bf16"
}
}
_ => "reduce",
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_sum(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn run_arg_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
let name = match T::DTYPE {
DType::F32 => {
if strided {
"arg_reduce_f32_strided"
} else {
"arg_reduce_f32"
}
}
DType::F16 => {
if strided {
"arg_reduce_f16_strided"
} else {
"arg_reduce_f16"
}
}
DType::BF16 => {
if strided {
"arg_reduce_bf16_strided"
} else {
"arg_reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_arg_min(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
criterion_group!(benches, criterion_benchmark);
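
The softmax helper above is marked as a temporary workaround because the softmax implementations live in candle-nn rather than candle-core. A minimal sketch of what the bench body could call instead, assuming candle-nn were available to this benchmark crate (an assumption, not part of this diff):

use candle_core::{Device, Result, Tensor};

fn softmax_last_dim_example() -> Result<()> {
    let a = Tensor::rand(-1000.0f32, 1000.0, (1, 1024, 1024), &Device::Cpu)?;
    // candle-nn dispatches softmax over the last dimension to the active backend.
    let _s = candle_nn::ops::softmax_last_dim(&a)?;
    Ok(())
}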

View File

@ -1,64 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
a.where_cond(b, c).unwrap();
}
const fn create_cond_arr<const N: usize>() -> [u8; N] {
let mut arr = [0u8; N];
let mut i = 0;
while i < N {
arr[i] = (i % 2) as u8;
i += 1;
}
arr
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box(&tensor),
black_box(&on_true),
black_box(&on_false),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
}
}
criterion_group!(benches, criterion_benchmark);
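
The throughput reported for this benchmark counts bytes rather than operations: per element the kernel touches the two dtype-sized branch tensors plus the one-byte condition mask, hence (2 * elements * dtype.size_in_bytes()) + elements. A rough worked example for the f32 case (illustration only):

// The `flops` value above is passed to Criterion as byte throughput.
fn where_cond_bytes_per_iter() -> usize {
    let elements: usize = 1024 * 1024; // B * M * K with B = 1
    let f32_size = std::mem::size_of::<f32>(); // 4 bytes
    // two f32 tensors (on_true, on_false) plus the u8 condition tensor
    2 * elements * f32_size + elements // 9_437_184 bytes, i.e. exactly 9 MiB
}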

View File

@ -1,25 +1,25 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor}; use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use std::time::Instant; use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) { fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap(); a.matmul(&b.t().unwrap()).unwrap();
} }
fn run_bench(c: &mut Criterion, device: &Device) { fn criterion_benchmark(c: &mut Criterion) {
let b = 1; let b = 1;
let m = 1; let m = 1;
let n = 2048; let n = 2048;
let k = 2048; let k = 2048;
let device = Device::new_metal(0).unwrap();
let dtype = DType::F32; let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap(); let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap(); let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
let flops = b * m * n * k; let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name("matmul")); let mut group = c.benchmark_group("matmul_metal");
group.throughput(Throughput::Bytes(flops as u64)); group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| { group.bench_function("iter", move |b| {
b.iter_custom(|iters| { b.iter_custom(|iters| {
@ -27,18 +27,16 @@ fn run_bench(c: &mut Criterion, device: &Device) {
for _i in 0..iters { for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs)); run(black_box(&lhs), black_box(&rhs));
} }
device.sync().unwrap(); if let Device::Metal(device) = &device {
device.wait_until_completed().unwrap();
} else {
panic!("Expected metal device");
}
start.elapsed() start.elapsed()
}) })
}); });
group.finish(); group.finish();
} }
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark); criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@ -1,5 +1,5 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor}; use candle_core::quantized::{gguf_file, k_quants, QTensor};
use candle_core::{Device, Result}; use candle_core::{Device, Result, Tensor};
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*; use rayon::prelude::*;
@ -11,7 +11,12 @@ enum QuantizationMode {
} }
impl QuantizationMode { impl QuantizationMode {
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> { fn quantize(
&self,
name: &str,
tensor: QTensor,
default: fn(&Tensor) -> Result<QTensor>,
) -> Result<QTensor> {
match self { match self {
Self::Llama => { Self::Llama => {
// Same behavior as the llama.cpp quantization. // Same behavior as the llama.cpp quantization.
@ -19,9 +24,9 @@ impl QuantizationMode {
if should_quantize { if should_quantize {
let tensor = tensor.dequantize(&Device::Cpu)?; let tensor = tensor.dequantize(&Device::Cpu)?;
if name == "output.weight" { if name == "output.weight" {
QTensor::quantize(&tensor, GgmlDType::Q6K) QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
} else { } else {
QTensor::quantize(&tensor, dtype) default(&tensor)
} }
} else { } else {
Ok(tensor) Ok(tensor)
@ -55,27 +60,6 @@ enum Quantization {
F32, F32,
} }
impl Quantization {
fn dtype(&self) -> GgmlDType {
match self {
Quantization::Q4_0 => GgmlDType::Q4_0,
Quantization::Q4_1 => GgmlDType::Q4_1,
Quantization::Q5_0 => GgmlDType::Q5_0,
Quantization::Q5_1 => GgmlDType::Q5_1,
Quantization::Q8_0 => GgmlDType::Q8_0,
Quantization::Q8_1 => GgmlDType::Q8_1,
Quantization::Q2k => GgmlDType::Q2K,
Quantization::Q3k => GgmlDType::Q3K,
Quantization::Q4k => GgmlDType::Q4K,
Quantization::Q5k => GgmlDType::Q5K,
Quantization::Q6k => GgmlDType::Q6K,
Quantization::Q8k => GgmlDType::Q8K,
Quantization::F16 => GgmlDType::F16,
Quantization::F32 => GgmlDType::F32,
}
}
}
#[derive(ValueEnum, Debug, Clone)] #[derive(ValueEnum, Debug, Clone)]
enum Format { enum Format {
Safetensors, Safetensors,
@ -118,7 +102,7 @@ enum Command {
}, },
Quantize { Quantize {
/// The input file(s), in safetensors format. /// The input file, in gguf format.
in_file: Vec<std::path::PathBuf>, in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format. /// The output file, in gguf format.
@ -133,15 +117,6 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)] #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode, mode: QuantizationMode,
}, },
Dequantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in safetensors format.
#[arg(long)]
out_file: std::path::PathBuf,
},
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
@ -150,12 +125,7 @@ struct Args {
command: Command, command: Command,
} }
fn run_ls( fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
file: &std::path::PathBuf,
format: Option<Format>,
verbose: bool,
device: &Device,
) -> Result<()> {
let format = match format { let format = match format {
Some(format) => format, Some(format) => format,
None => match Format::infer(file) { None => match Format::infer(file) {
@ -221,7 +191,7 @@ fn run_ls(
} }
Format::Ggml => { Format::Ggml => {
let mut file = std::fs::File::open(file)?; let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>(); let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0)); tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() { for (name, qtensor) in tensors.iter() {
@ -262,8 +232,37 @@ fn run_quantize_safetensors(
} }
println!("tensors: {}", tensors.len()); println!("tensors: {}", tensors.len());
let dtype = q.dtype(); let quantize_fn = match q {
let block_size = dtype.block_size(); Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let block_size = match q {
Quantization::Q4_0 => k_quants::QK4_0,
Quantization::Q4_1 => k_quants::QK4_1,
Quantization::Q5_0 => k_quants::QK5_0,
Quantization::Q5_1 => k_quants::QK5_1,
Quantization::Q8_0 => k_quants::QK8_0,
Quantization::Q8_1 => k_quants::QK8_1,
Quantization::Q2k
| Quantization::Q3k
| Quantization::Q4k
| Quantization::Q5k
| Quantization::Q6k
| Quantization::Q8k => k_quants::QK_K,
Quantization::F16 | Quantization::F32 => 1,
};
let qtensors = tensors let qtensors = tensors
.into_par_iter() .into_par_iter()
@ -271,9 +270,9 @@ fn run_quantize_safetensors(
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}"); println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize { let tensor = if should_quantize {
QTensor::quantize(&tensor, dtype)? quantize_fn(&tensor)?
} else { } else {
QTensor::quantize(&tensor, GgmlDType::F32)? QTensor::quantize::<f32>(&tensor)?
}; };
Ok((name, tensor)) Ok((name, tensor))
}) })
@ -286,29 +285,11 @@ fn run_quantize_safetensors(
Ok(()) Ok(())
} }
fn run_dequantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
device: &Device,
) -> Result<()> {
let mut in_file = std::fs::File::open(in_file)?;
let content = gguf_file::Content::read(&mut in_file)?;
let mut tensors = std::collections::HashMap::new();
for (tensor_name, _) in content.tensor_infos.iter() {
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
Ok(())
}
fn run_quantize( fn run_quantize(
in_files: &[std::path::PathBuf], in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf, out_file: std::path::PathBuf,
q: Quantization, q: Quantization,
qmode: QuantizationMode, qmode: QuantizationMode,
device: &Device,
) -> Result<()> { ) -> Result<()> {
if in_files.is_empty() { if in_files.is_empty() {
candle_core::bail!("no specified input files") candle_core::bail!("no specified input files")
@ -334,15 +315,31 @@ fn run_quantize(
let content = gguf_file::Content::read(&mut in_)?; let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len()); println!("tensors: {}", content.tensor_infos.len());
let dtype = q.dtype(); let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let qtensors = content let qtensors = content
.tensor_infos .tensor_infos
.par_iter() .par_iter()
.map(|(name, _)| { .map(|(name, _)| {
println!(" quantizing {name}"); println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?; let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name, device)?; let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, dtype)?; let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor)) Ok((name, tensor))
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -362,7 +359,6 @@ fn run_quantize(
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
let device = Device::Cpu;
match args.command { match args.command {
Command::Ls { Command::Ls {
files, files,
@ -374,7 +370,7 @@ fn main() -> anyhow::Result<()> {
if multiple_files { if multiple_files {
println!("--- {file:?} ---"); println!("--- {file:?} ---");
} }
run_ls(file, format.clone(), verbose, &device)? run_ls(file, format.clone(), verbose)?
} }
} }
Command::Quantize { Command::Quantize {
@ -382,8 +378,7 @@ fn main() -> anyhow::Result<()> {
out_file, out_file,
quantization, quantization,
mode, mode,
} => run_quantize(&in_file, out_file, quantization, mode, &device)?, } => run_quantize(&in_file, out_file, quantization, mode)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
} }
Ok(()) Ok(())
} }
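
The central API change in this file is that QTensor::quantize now takes an explicit GgmlDType instead of a generic block type, the block size comes from the dtype itself, and a Device is threaded through run_ls, run_quantize, and the new Dequantize subcommand. A minimal sketch of a quantize/dequantize round trip with the dtype-based API (shapes and values are illustrative assumptions, not taken from the diff):

use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn quantize_round_trip() -> Result<()> {
    let device = Device::Cpu;
    // Q4K uses 256-element blocks, so the last dimension must be a multiple of 256.
    let weight = Tensor::zeros((64, 256), DType::F32, &device)?;
    // Quantize with an explicit GGML dtype instead of a generic block type.
    let qweight = QTensor::quantize(&weight, GgmlDType::Q4K)?;
    // Dequantize back to a float tensor, as the new Dequantize subcommand does
    // for every tensor read from a gguf file.
    let restored = qweight.dequantize(&device)?;
    assert_eq!(restored.dims(), weight.dims());
    Ok(())
}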

View File

@ -72,7 +72,7 @@ pub mod utils;
mod variable; mod variable;
pub use cpu_backend::CpuStorage; pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation, NdArray}; pub use device::{Device, DeviceLocation};
pub use dtype::{DType, FloatDType, IntDType, WithDType}; pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use indexer::IndexOp; pub use indexer::IndexOp;

View File

@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, TryLockError}; use std::sync::{Arc, RwLock, TryLockError};
/// Simple way to catch lock error without /// Simple way to catch lock error without
/// depending on T /// depending on T
@ -85,8 +84,13 @@ pub struct MetalDevice {
command_buffer_index: Arc<RwLock<usize>>, command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize, compute_per_buffer: usize,
/// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
/// execution order to be linear.
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
/// compute graph.
fence: metal::Fence,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`] /// Heavily used by [`candle_metal_kernels`], both fences need to match
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
/// Simple allocator struct. /// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over. /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
@ -102,8 +106,6 @@ pub struct MetalDevice {
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1). /// (strong_count = 1).
buffers: AllocatedBuffers, buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<Buffer>>,
} }
impl std::fmt::Debug for MetalDevice { impl std::fmt::Debug for MetalDevice {
@ -219,8 +221,10 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?; let command_buffer = self.command_buffer()?;
command_buffer.set_label("with_data"); command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit"); blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
blit.end_encoding(); blit.end_encoding();
// This is necessary, for mmaped safetensors // This is necessary, for mmaped safetensors
@ -234,27 +238,6 @@ impl MetalDevice {
Ok(real) Ok(real)
} }
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
/// The critical allocator algorithm /// The critical allocator algorithm
fn allocate_buffer( fn allocate_buffer(
&self, &self,
@ -325,14 +308,35 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
let length = self.buffer.length() as usize;
let size = self.dtype.size_in_bytes();
if length % size != 0 {
crate::bail!(
"The Metal buffer length is not aligned with dtype {:?}",
self.dtype
);
}
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.wait_for_fence(&self.device.fence);
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.update_fence(&self.device.fence);
blit.end_encoding();
}
self.device.wait_until_completed()?;
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
} }
} }
@ -349,7 +353,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "affine_f32", DType::F32 => "affine_f32",
DType::F16 => "affine_f16", DType::F16 => "affine_f16",
DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_affine( candle_metal_kernels::call_affine(
@ -368,7 +371,6 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "affine_f32_strided", DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided", DType::F16 => "affine_f16_strided",
DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
}; };
candle_metal_kernels::call_affine_strided( candle_metal_kernels::call_affine_strided(
@ -489,7 +491,6 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone(); let device = self.device.clone();
let src_stride = layout.stride(); let src_stride = layout.stride();
let src_dims = layout.shape().dims(); let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end. // Source dims and strides with the sum dims at the end.
@ -503,69 +504,13 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]); stride.push(src_stride[dim_idx]);
} }
} }
if layout.is_contiguous() {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
}
for &dim_idx in sum_dims.iter() { for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]); dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]); stride.push(src_stride[dim_idx]);
} }
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) { let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false), (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false), (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
@ -597,7 +542,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false), (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true), (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true), (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"), (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
}; };
if check_empty && layout.shape().elem_count() == 0 { if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@ -645,26 +590,14 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::F32, DType::F16) => "cast_f32_f16", (DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32", (DType::F16, DType::F32) => "cast_f16_f32",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::BF16, DType::U8) => "cast_bf16_u8", (DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32", (DType::BF16, DType::F32) => "cast_bf16_f32",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
} }
@ -742,7 +675,6 @@ impl BackendStorage for MetalStorage {
("uround", DType::F32) => contiguous::round::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT, ("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF, ("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF, ("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF,
@ -759,7 +691,6 @@ impl BackendStorage for MetalStorage {
("uround", DType::F16) => contiguous::round::HALF, ("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF, ("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF,
("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
} }
@ -790,7 +721,6 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F32) => strided::abs::FLOAT, ("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT, ("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF, ("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF, ("usin", DType::F16) => strided::sin::HALF,
@ -805,7 +735,6 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F16) => strided::abs::HALF, ("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF, ("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF, ("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF, ("uround", DType::F16) => strided::round::HALF,
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented") crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
@ -861,7 +790,6 @@ impl BackendStorage for MetalStorage {
} }
let name = match (self.dtype, t.dtype()) { let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32", (DType::U8, DType::F32) => "where_u8_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::I64) => "where_u8_i64",
(DType::U8, DType::U32) => "where_u8_u32", (DType::U8, DType::U32) => "where_u8_u32",
@ -1199,12 +1127,8 @@ impl BackendStorage for MetalStorage {
let device = self.device(); let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) { let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
} }
@ -1317,7 +1241,7 @@ impl BackendStorage for MetalStorage {
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length); blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding(); blit.end_encoding();
} else { } else {
let src_shape = src_l.shape(); let src_shape = src_l.shape();
@ -1394,7 +1318,6 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype), ("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@ -1405,18 +1328,6 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
("add", DType::I64) => (contiguous::add::I64, self.dtype), ("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@ -1427,7 +1338,6 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
("add", DType::U32) => (contiguous::add::U32, self.dtype), ("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype), ("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype), ("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@ -1438,7 +1348,6 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8), ("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8), ("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8), ("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
("add", DType::U8) => (contiguous::add::U8, self.dtype), ("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype), ("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype), ("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@ -1449,7 +1358,6 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8), ("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8), ("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8), ("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
} }
@ -1483,7 +1391,6 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype), ("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype), ("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype), ("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@ -1496,20 +1403,6 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8),
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
("badd", DType::I64) => (strided::add::I64, self.dtype), ("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@ -1522,7 +1415,6 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8), ("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8),
("badd", DType::U32) => (strided::add::U32, self.dtype), ("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype), ("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype), ("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@ -1535,7 +1427,6 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8), ("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8), ("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8), ("gt", DType::U32) => (strided::gt::U32, DType::U8),
("badd", DType::U8) => (strided::add::U8, self.dtype), ("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype), ("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype), ("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@ -1548,7 +1439,6 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8), ("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8), ("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8), ("gt", DType::U8) => (strided::gt::U8, DType::U8),
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented") crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
} }
@ -1574,28 +1464,6 @@ impl MetalStorage {
command_buffer.set_label("binary"); command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
let length = self.buffer.length() as usize;
let size = self.dtype.size_in_bytes();
if length % size != 0 {
crate::bail!(
"The Metal buffer length is not aligned with dtype {:?}",
self.dtype
);
}
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
}
self.device.wait_until_completed()?;
Ok(read_to_vec(&buffer, length / size))
}
} }
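As an aside, the alignment check in to_cpu above can be read in isolation: the buffer's byte length has to be a whole multiple of the dtype size before it may be reinterpreted as a Vec of elements. A minimal standalone sketch (hypothetical helper name, not part of the diff):

// Byte length must divide evenly into elements of the target dtype before reinterpreting the buffer.
fn element_count(byte_len: usize, elem_size: usize) -> Result<usize, String> {
    if byte_len % elem_size != 0 {
        return Err(format!(
            "buffer length {byte_len} is not aligned with element size {elem_size}"
        ));
    }
    Ok(byte_len / elem_size)
}

fn main() {
    assert_eq!(element_count(16, 4), Ok(4)); // 16 bytes of f32 -> 4 elements
    assert!(element_count(10, 4).is_err());
}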
impl BackendDevice for MetalDevice { impl BackendDevice for MetalDevice {
@ -1608,29 +1476,29 @@ impl BackendDevice for MetalDevice {
command_buffer.enqueue(); command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer)); let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0)); let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new()); let fence = device.new_fence();
let kernels = Arc::new(Kernels::new(fence.clone()));
let buffers = Arc::new(RwLock::new(HashMap::new())); let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?, Ok(val) => val.parse()?,
_ => 10, _ => 20,
}; };
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self { Ok(Self {
device, device,
fence,
command_queue, command_queue,
command_buffer, command_buffer,
command_buffer_index, command_buffer_index,
compute_per_buffer, compute_per_buffer,
buffers, buffers,
kernels, kernels,
seed,
}) })
} }
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("Metal set_seed not implemented")
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal { crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize, gpu_id: self.registry_id() as usize,
@ -1642,8 +1510,21 @@ impl BackendDevice for MetalDevice {
} }
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> { fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let size = shape.elem_count() * dtype.size_in_bytes(); let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
let buffer = self.allocate_zeros(size)?; let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.update_fence(&self.fence);
blit.end_encoding();
Ok(MetalStorage::new(buffer, self.clone(), dtype)) Ok(MetalStorage::new(buffer, self.clone(), dtype))
} }
@ -1670,31 +1551,12 @@ impl BackendDevice for MetalDevice {
&self, &self,
shape: &Shape, shape: &Shape,
dtype: DType, dtype: DType,
min: f64, mean: f64,
max: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
let name = match dtype { // TODO is there a better way ?
DType::F32 => "rand_uniform_f32", let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
DType::F16 => "rand_uniform_f16", self.storage_from_cpu_storage(&cpu_storage)
DType::BF16 => "rand_uniform_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_uniform(
&self.device,
&command_buffer,
&self.kernels,
name,
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
} }
fn rand_normal( fn rand_normal(
@ -1704,43 +1566,9 @@ impl BackendDevice for MetalDevice {
mean: f64, mean: f64,
stddev: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
let name = match dtype { // TODO is there a better way ?
DType::F32 => "rand_normal_f32", let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
DType::F16 => "rand_normal_f16", self.storage_from_cpu_storage(&cpu_storage)
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
}
fn set_seed(&self, seed: u64) -> Result<()> {
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
Ok(())
} }
} }
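On the seeding side, the new set_seed writes a 32-bit value into the shared seed buffer, so a u64 seed first has to fit into a u32. A standalone sketch of that conversion (hypothetical function name, error message taken from the hunk above):

// Metal RNG kernels read a single u32 seed, mirroring the try_into in set_seed above.
fn to_metal_seed(seed: u64) -> Result<u32, String> {
    u32::try_from(seed)
        .map_err(|_| "Metal seed must be less than or equal to u32::MAX".to_string())
}

fn main() {
    assert_eq!(to_metal_seed(299792458), Ok(299792458)); // the default seed from the constructor
    assert!(to_metal_seed(u64::from(u32::MAX) + 1).is_err());
}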

View File

@ -703,7 +703,6 @@ impl PthTensors {
} }
pub fn get(&self, name: &str) -> Result<Option<Tensor>> { pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) { let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None), None => return Ok(None),
Some(tensor_info) => tensor_info, Some(tensor_info) => tensor_info,
@ -713,21 +712,14 @@ impl PthTensors {
let mut zip = zip::ZipArchive::new(zip_reader)?; let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?; let mut reader = zip.by_name(&tensor_info.path)?;
// Reading the data is a bit tricky as it can be strided, for now only support the basic // Reading the data is a bit tricky as it can be strided, use an offset, etc.
// case. // For now only support the basic case.
if !tensor_info.layout.is_contiguous() { if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
crate::bail!( crate::bail!(
"cannot retrieve non-contiguous tensors {:?}", "cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout tensor_info.layout
) )
} }
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader( let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(), tensor_info.layout.shape().clone(),
tensor_info.dtype, tensor_info.dtype,
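The new start_offset handling drains the leading bytes of the zip entry into std::io::sink because the entry reader cannot seek. A standalone sketch of the same pattern (hypothetical helper name):

use std::io::Read;

// Skip `start_offset` bytes from a non-seekable reader by copying them into a sink,
// as done above before handing the reader to Tensor::from_reader.
fn skip_bytes<R: Read>(reader: &mut R, start_offset: u64) -> std::io::Result<()> {
    if start_offset > 0 {
        std::io::copy(&mut reader.by_ref().take(start_offset), &mut std::io::sink())?;
    }
    Ok(())
}

fn main() -> std::io::Result<()> {
    let data = [1u8, 2, 3, 4, 5];
    let mut reader = &data[..];
    skip_bytes(&mut reader, 2)?;
    let mut rest = Vec::new();
    reader.read_to_end(&mut rest)?;
    assert_eq!(rest, vec![3, 4, 5]);
    Ok(())
}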

View File

@ -1,9 +1,7 @@
//! Support for the GGML file format. //! Support for the GGML file format.
#[cfg(feature = "metal")] use super::{k_quants, GgmlDType};
use super::metal::load_quantized_metal; use crate::Result;
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -123,22 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8], raw_data: &[u8],
size_in_bytes: usize, size_in_bytes: usize,
dims: Vec<usize>, dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr(); let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>(); let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device { super::QTensor::new(data.to_vec(), dims)
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
#[cfg(feature = "metal")]
Device::Metal(metal) => load_quantized_metal(metal, data)?,
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal backend requires `metal` feature")
}
device => unimplemented!("Implement quantized tensor for device {device:?}"),
};
super::QTensor::new(data, dims)
} }
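The from_raw_data step above reinterprets the raw GGML bytes as a slice of block structs before copying them into the chosen storage. A standalone sketch of that reinterpretation (hypothetical helper; only sound when the bytes really are contiguous, properly aligned blocks of T, as GGML block data is):

// Reinterpret `size_in_bytes` of raw data as blocks of T and copy them out,
// mirroring the unsafe cast in from_raw_data above.
fn blocks_from_bytes<T: Copy>(raw_data: &[u8], size_in_bytes: usize) -> Vec<T> {
    let n_blocks = size_in_bytes / std::mem::size_of::<T>();
    let blocks = unsafe { std::slice::from_raw_parts(raw_data.as_ptr() as *const T, n_blocks) };
    blocks.to_vec()
}

fn main() {
    let raw = [1u8, 2, 3, 4];
    let pairs: Vec<[u8; 2]> = blocks_from_bytes(&raw, raw.len());
    assert_eq!(pairs, vec![[1, 2], [3, 4]]);
}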
/// Creates a [Tensor] from a raw GGML tensor. /// Creates a [Tensor] from a raw GGML tensor.
@ -146,50 +133,29 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType, ggml_dtype: GgmlDType,
raw_data: &[u8], raw_data: &[u8],
dims: Vec<usize>, dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> { ) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>(); let tensor_elems = dims.iter().product::<usize>();
let block_size = ggml_dtype.block_size(); let blck_size = ggml_dtype.blck_size();
if tensor_elems % block_size != 0 { if tensor_elems % blck_size != 0 {
crate::bail!( crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}" "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
) )
} }
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size(); let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype { match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device), GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device), GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => { GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device) GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
} GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => { GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device) GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
} GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => { GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device) GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
} GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => { GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
} }
} }
@ -197,7 +163,6 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>( fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R, reader: &mut R,
magic: VersionedMagic, magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> { ) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?; let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?; let name_len = reader.read_u32::<LittleEndian>()?;
@ -218,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
} }
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>(); let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>(); let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size(); let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
// TODO: Mmap version to avoid copying the data around? // TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) { match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)), Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"), Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
} }
@ -236,10 +201,7 @@ pub struct Content {
} }
impl Content { impl Content {
pub fn read<R: std::io::Seek + std::io::Read>( pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?; let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?; reader.seek(std::io::SeekFrom::Start(0))?;
@ -249,7 +211,7 @@ impl Content {
let mut tensors = HashMap::new(); let mut tensors = HashMap::new();
while reader.stream_position()? != last_position { while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?; let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.insert(name, tensor); tensors.insert(name, tensor);
} }
Ok(Self { Ok(Self {

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor}; use super::{GgmlDType, QTensor};
use crate::{Device, Result}; use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
@ -59,25 +59,19 @@ impl TensorInfo {
&self, &self,
reader: &mut R, reader: &mut R,
tensor_data_offset: u64, tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count(); let tensor_elems = self.shape.elem_count();
let block_size = self.ggml_dtype.block_size(); let blck_size = self.ggml_dtype.blck_size();
if tensor_elems % block_size != 0 { if tensor_elems % blck_size != 0 {
crate::bail!( crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}" "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
) )
} }
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size(); let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes]; let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?; reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml( super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
} }
} }
@ -466,13 +460,12 @@ impl Content {
&self, &self,
reader: &mut R, reader: &mut R,
name: &str, name: &str,
device: &Device,
) -> Result<QTensor> { ) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) { let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info, Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"), None => crate::bail!("cannot find tensor info for {name}"),
}; };
tensor_info.read(reader, self.tensor_data_offset, device) tensor_info.read(reader, self.tensor_data_offset)
} }
} }
@ -524,9 +517,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}" "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
) )
} }
let data = tensor.data()?; let data_ptr = tensor.as_ptr();
let size_in_bytes = data.len(); let size_in_bytes = tensor.storage_size_in_bytes();
w.write_all(&data)?; let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
w.write_all(data)?;
let padding = 31 - (31 + size_in_bytes) % 32; let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?; w.write_all(&vec![0u8; padding])?;
} }
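The padding expression in the write path above rounds each tensor's data up to the next 32-byte boundary, adding nothing when it is already aligned. A standalone sketch with a couple of worked values:

// Same expression as above: bytes needed to round `size_in_bytes` up to a multiple of 32.
fn padding_to_32(size_in_bytes: usize) -> usize {
    31 - (31 + size_in_bytes) % 32
}

fn main() {
    assert_eq!(padding_to_32(32), 0); // already aligned
    assert_eq!(padding_to_32(33), 31); // 33 + 31 = 64
    assert_eq!(padding_to_32(18), 14); // one 18-byte block -> pad to 32
}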

View File

@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
let d2 = d * sc as f32; let d2 = d * sc as f32;
let m2 = min * m as f32; let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) { for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 }; let to_add = if qh & u1 != 0 { 16 } else { 1 };
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1; y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
ys_index += 1; ys_index += 1;
} }
for (ql, qh) in ql.iter().zip(qh) { for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 }; let to_add = if qh & u2 != 0 { 16 } else { 1 };
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2; y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
ys_index += 1; ys_index += 1;
} }
is += 2; is += 2;
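The fix above reconstructs each 5-bit Q5K value as the low nibble of ql plus 16 when the matching high bit in qh is set (not plus 1), and converts to f32 before the scale and offset are applied. A standalone sketch of that reconstruction:

// 5-bit value = low nibble (0..=15) plus 16 if the high bit is set, i.e. 0..=31.
fn q5_value(ql: u8, qh_bit_set: bool) -> f32 {
    let to_add = if qh_bit_set { 16f32 } else { 0f32 };
    (ql & 0xF) as f32 + to_add
}

fn main() {
    assert_eq!(q5_value(0x07, false), 7.0);
    assert_eq!(q5_value(0x07, true), 23.0);
    assert_eq!(q5_value(0xFF, true), 31.0); // only the low nibble of ql is used
}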

View File

@ -1,153 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::{DType, MetalDevice, MetalStorage, Result};
use metal::Buffer;
use std::sync::Arc;
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
buffer: Arc<Buffer>,
}
impl QMetalStorage {
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
Self {
device,
buffer,
dtype,
}
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
match self.dtype {
GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
f32::to_float(&vec, &mut out)?;
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
let buffer = self.device.new_buffer_with_data(&out)?;
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
}
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
// Quantization only happens on CPU for now.
let src = src.to_cpu::<f32>()?;
let elem_count = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
qcpu_storage.quantize(&src)?;
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
self.buffer = buffer;
Ok(())
}
}
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice,
data: &[T],
) -> Result<QStorage> {
let buffer = device.new_buffer_with_data(data)?;
let device = device.clone();
Ok(QStorage::Metal(QMetalStorage {
dtype: T::DTYPE,
device,
buffer,
}))
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}

View File

@ -1,125 +1,23 @@
#[cfg(feature = "metal")] use crate::{Device, Result, Shape, Tensor};
use crate::{backend::BackendStorage, DType};
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
#[cfg(target_feature = "avx")] #[cfg(target_feature = "avx")]
pub mod avx; pub mod avx;
pub mod ggml_file; pub mod ggml_file;
pub mod gguf_file; pub mod gguf_file;
pub mod k_quants; pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")] #[cfg(target_feature = "simd128")]
pub mod simd128; pub mod simd128;
pub mod utils; pub mod utils;
use half::f16;
pub use k_quants::GgmlType; pub use k_quants::GgmlType;
pub struct QTensor { pub struct QTensor {
storage: QStorage, data: Box<dyn QuantizedType>,
shape: Shape, shape: Shape,
} }
impl Device {
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
match self {
Device::Cpu => {
let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage))
}
#[cfg(feature = "metal")]
Device::Metal(metal) => {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = metal.allocate_zeros(size)?;
Ok(QStorage::Metal(metal::QMetalStorage::new(
buffer,
metal.clone(),
dtype,
)))
}
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal feature not activated");
}
Device::Cuda(_cuda) => {
crate::bail!("Cuda ggml quantization not supported");
}
}
}
}
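The qzeros size computation above allocates whole blocks: the byte size of a quantized buffer is the element count divided by the block size, times the per-block type size. A standalone sketch with Q4_0 as a worked example (32 elements per 18-byte block, an f16 scale plus 16 packed bytes):

// Bytes needed to hold `elem_count` quantized elements stored as `type_size`-byte blocks
// that each cover `block_size` elements.
fn quantized_size_in_bytes(elem_count: usize, block_size: usize, type_size: usize) -> usize {
    elem_count / block_size * type_size
}

fn main() {
    assert_eq!(quantized_size_in_bytes(256, 32, 18), 144); // 8 Q4_0 blocks of 18 bytes
}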
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
#[cfg(feature = "metal")]
Metal(metal::QMetalStorage),
}
impl QStorage {
fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.dtype().block_size(),
}
}
fn dtype(&self) -> GgmlDType {
match self {
QStorage::Cpu(storage) => storage.dtype(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.dtype(),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.buffer().length() as usize,
}
}
fn quantize(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
#[cfg(feature = "metal")]
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
}
Ok(())
}
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
}
}
fn data(&self) -> Result<Cow<[u8]>> {
match self {
QStorage::Cpu(storage) => {
let data_ptr = storage.as_ptr();
let size_in_bytes = storage.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
#[cfg(feature = "metal")]
QStorage::Metal(_storage) => {
crate::bail!("not implemented");
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType { pub enum GgmlDType {
F32, F32,
@ -179,25 +77,6 @@ impl GgmlDType {
} }
} }
/// The block dtype
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
}
}
/// The type size for blocks in bytes. /// The type size for blocks in bytes.
pub fn type_size(&self) -> usize { pub fn type_size(&self) -> usize {
use k_quants::*; use k_quants::*;
@ -221,7 +100,7 @@ impl GgmlDType {
} }
/// The block size, i.e. the number of elements stored in each block. /// The block size, i.e. the number of elements stored in each block.
pub fn block_size(&self) -> usize { pub fn blck_size(&self) -> usize {
match self { match self {
Self::F32 => 1, Self::F32 => 1,
Self::F16 => 1, Self::F16 => 1,
@ -240,13 +119,9 @@ impl GgmlDType {
pub trait QuantizedType: Send + Sync { pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType; fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>; fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>; fn to_float(&self, ys: &mut [f32]) -> Result<()>;
fn storage_size_in_bytes(&self) -> usize; fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8; fn as_ptr(&self) -> *const u8;
fn block_size(&self) -> usize;
#[allow(clippy::wrong_self_convention)]
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
fn size(&self) -> usize;
} }
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> { impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@ -254,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst) k_quants::matmul(mkn, lhs, self.as_slice(), dst)
} }
fn size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
T::from_float(xs, self)
}
fn dtype(&self) -> GgmlDType { fn dtype(&self) -> GgmlDType {
T::DTYPE T::DTYPE
} }
fn block_size(&self) -> usize { fn to_float(&self, ys: &mut [f32]) -> Result<()> {
T::BLCK_SIZE T::to_float(self.as_slice(), ys)
}
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
let mut ys = vec![0.0f32; elem_count];
T::to_float(self.as_slice(), &mut ys)?;
Ok(CpuStorage::F32(ys))
} }
fn storage_size_in_bytes(&self) -> usize { fn storage_size_in_bytes(&self) -> usize {
@ -291,49 +152,56 @@ impl std::fmt::Debug for QTensor {
} }
} }
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> { fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
let dims = shape.dims(); let dims = shape.dims();
if dims.is_empty() { if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}") crate::bail!("scalar tensor cannot be quantized {shape:?}")
} }
if dims[dims.len() - 1] % block_size != 0 { if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
crate::bail!( crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}", "quantized tensor must have their last dim divisible by block size {shape:?} {}",
block_size T::BLCK_SIZE
) )
} }
Ok(()) Ok(())
} }
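check_shape above only constrains the innermost dimension: it must be a multiple of the block size of the target dtype, and scalar shapes are rejected outright. A standalone sketch of that rule:

// Only the last dimension needs to be divisible by the block size; scalars are rejected.
fn shape_ok(dims: &[usize], block_size: usize) -> bool {
    match dims.last() {
        None => false,
        Some(d) => d % block_size == 0,
    }
}

fn main() {
    assert!(shape_ok(&[3, 64], 32)); // e.g. Q4_0 with 32-element blocks
    assert!(!shape_ok(&[3, 48], 32));
    assert!(!shape_ok(&[], 32)); // scalar tensors cannot be quantized
}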
impl QTensor { impl QTensor {
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> { pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
) -> Result<Self> {
let shape = shape.into(); let shape = shape.into();
check_shape(&shape, storage.block_size())?; check_shape::<T>(&shape)?;
Ok(Self { storage, shape }) Ok(Self {
data: Box::new(data),
shape,
})
} }
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> { pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape(); let shape = src.shape();
let block_size = dtype.block_size(); check_shape::<T>(shape)?;
check_shape(shape, block_size)?; let src = src
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; .to_dtype(crate::DType::F32)?
let elem_count = shape.elem_count(); .flatten_all()?
if elem_count % block_size != 0 { .to_vec1::<f32>()?;
if src.len() % T::BLCK_SIZE != 0 {
crate::bail!( crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}", "tensor size ({shape:?}) is not divisible by block size {}",
block_size T::BLCK_SIZE
) )
} }
let mut storage = src.device().qzeros(elem_count, dtype)?; let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
storage.quantize(&src.storage())?; T::from_float(&src, &mut data)?;
Ok(Self { Ok(Self {
storage, data: Box::new(data),
shape: shape.clone(), shape: shape.clone(),
}) })
} }
pub fn dtype(&self) -> GgmlDType { pub fn dtype(&self) -> GgmlDType {
self.storage.dtype() self.data.dtype()
} }
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
@ -345,19 +213,21 @@ impl QTensor {
} }
pub fn dequantize(&self, device: &Device) -> Result<Tensor> { pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?; let mut f32_data = vec![0f32; self.shape.elem_count()];
let none = crate::op::BackpropOp::none(); self.data.to_float(&mut f32_data)?;
let is_variable = false; Tensor::from_vec(f32_data, &self.shape, device)
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable) }
.to_device(device)
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
self.data.matmul_t(mkn, lhs, dst)
} }
pub fn storage_size_in_bytes(&self) -> usize { pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes() self.data.storage_size_in_bytes()
} }
pub fn data(&self) -> Result<Cow<'_, [u8]>> { pub fn as_ptr(&self) -> *const u8 {
self.storage.data() self.data.as_ptr()
} }
} }
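A short usage sketch of the reworked API above, assuming the candle_core paths used by the tests further down (quantization itself still runs on the CPU, and the dequantized tensor comes back on the requested device):

use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device, Result, Tensor};

// Quantize an f32 tensor to Q4_0 and dequantize it again on the same device.
fn roundtrip(device: &Device) -> Result<Tensor> {
    let src: Vec<f32> = (0..256).map(|v| v as f32).collect();
    let src = Tensor::from_slice(&src, (256,), device)?;
    let q = QTensor::quantize(&src, GgmlDType::Q4_0)?; // dtype chosen at runtime now
    q.dequantize(device) // back to f32
}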
@ -424,93 +294,17 @@ impl crate::CustomOp1 for QTensor {
} }
dst_shape.push(n); dst_shape.push(n);
let dst_shape = Shape::from(dst_shape); let dst_shape = Shape::from(dst_shape);
#[allow(clippy::infallible_destructuring_match)] let storage = storage.as_slice::<f32>()?;
let self_storage = match &self.storage { let storage =
QStorage::Cpu(storage) => storage, &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
#[cfg(feature = "metal")]
_ => crate::bail!("Invalid storage"),
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()]; let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?; self.matmul_t(
(dst_shape.elem_count() / n, k, n),
storage,
&mut dst_storage,
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
} }
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self.shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let (buffer, dtype) = match &self.storage {
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
};
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
Ok((dst_storage, dst_shape))
}
}
#[cfg(feature = "metal")]
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
} }
impl crate::Module for QMatMul { impl crate::Module for QMatMul {

View File

@ -2578,21 +2578,11 @@ impl Tensor {
} }
/// Returns log(sum(exp(tensor), dim)). /// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?; let exp = self.exp()?;
let sum = exp.sum(sum_dims)?; let sum = exp.sum(sum_dims)?;
sum.log() sum.log()
} }
/// Pointwise pow operation.
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.mul(&self.log()?)?.exp()
}
/// Broadcasting version of `pow`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
} }
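The new pow above relies on the identity x^y = exp(y * ln x), which is why it is spelled with mul/log/exp and only matches the usual power for positive bases. A standalone scalar sketch:

// Scalar version of the identity used by Tensor::pow above.
fn pow_via_exp_log(x: f32, y: f32) -> f32 {
    (y * x.ln()).exp()
}

fn main() {
    assert!((pow_via_exp_log(2.0, 3.0) - 8.0).abs() < 1e-3);
    // Negative bases go through ln(x) = NaN, unlike (-2.0f32).powf(2.0) == 4.0.
    assert!(pow_via_exp_log(-2.0, 2.0).is_nan());
}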
macro_rules! bin_trait { macro_rules! bin_trait {

View File

@ -1,7 +1,6 @@
use candle_core::{ use candle_core::{
bail, bail,
quantized::{self, GgmlDType}, quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round, test_utils::to_vec2_round,
Device, Module, Result, Tensor, Device, Module, Result, Tensor,
}; };
@ -15,48 +14,16 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040; const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02; const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
fn test_matmul( #[test]
device: &Device, fn quantized_matmul() -> Result<()> {
(b, m, n, k): (usize, usize, usize, usize), let cpu = &Device::Cpu;
dtype: GgmlDType,
) -> Result<()> {
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
let mm = lhs.matmul(&rhs)?;
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&lhs)?;
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
.sum_all()?
.to_scalar()?;
let error = error / (b * m * n) as f32;
assert!(
error <= 0.02,
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
);
Ok(())
}
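test_matmul above scores the quantized matmul with a mean relative error over all output elements and accepts anything up to 0.02. A standalone sketch of that metric on plain slices (hypothetical helper name):

// Mean of |expected - got| / |expected| over all elements, as computed by test_matmul above.
fn mean_relative_error(expected: &[f32], got: &[f32]) -> f32 {
    let sum: f32 = expected
        .iter()
        .zip(got)
        .map(|(e, g)| (e - g).abs() / e.abs())
        .sum();
    sum / expected.len() as f32
}

fn main() {
    let err = mean_relative_error(&[2.0, 4.0], &[2.1, 3.8]);
    assert!((err - 0.05).abs() < 1e-4); // (0.05 + 0.05) / 2
}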
fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4); let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>(); let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4]; let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>(); let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!( assert_eq!(
@ -66,7 +33,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
341876.0, 994283.0, 1655709.0, 2301518.0 341876.0, 994283.0, 1655709.0, 2301518.0
] ]
); );
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?; let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!( assert_eq!(
mm.to_vec2::<f32>()?, mm.to_vec2::<f32>()?,
@ -77,49 +43,35 @@ fn quantized_matmul(device: &Device) -> Result<()> {
] ]
); );
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?; let res = matmul.forward(&tensor_lhs)?;
match device { assert_eq!(
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84946.0, 214126.0, 344757.0, 473798.0],
[213458.0, 604350.0, 1000469.0, 1387990.0],
[341970.0, 994574.0, 1656181.0, 2302182.0]
]
),
_ => assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
&[ &[
[85120.0, 214562.0, 345455.0, 474748.0], [85120.0, 214562.0, 345455.0, 474748.0],
[213475.0, 604465.0, 1000686.0, 1388317.0], [213475.0, 604465.0, 1000686.0, 1388317.0],
[341876.0, 994283.0, 1655709.0, 2301518.0] [341876.0, 994283.0, 1655709.0, 2301518.0]
] ]
), );
}
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(()) Ok(())
} }
fn quantized_matmul_neg(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantized_matmul_neg() -> Result<()> {
if device.is_cuda() { let cpu = &Device::Cpu;
return Ok(());
}
let (m, k, n) = (3, 64, 4); let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)) let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0) .map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4]; let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n) let rhs = (0..k * n)
.map(|v| v as f32 - (k * n) as f32 / 3.0) .map(|v| v as f32 - (k * n) as f32 / 3.0)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!( assert_eq!(
@ -139,56 +91,32 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
] ]
); );
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?; let res = matmul.forward(&tensor_lhs)?;
match device { assert_eq!(
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243666.0, -19714.0, -285433.0, -550453.0],
[23782.0, 21654.0, 19400.0, 18369.0],
[-196102.0, 63022.0, 324233.0, 587191.0]
]
),
_ => assert_eq!(
to_vec2_round(&res, 0)?, to_vec2_round(&res, 0)?,
&[ &[
[243524.0, -19596.0, -285051.0, -549815.0], [243524.0, -19596.0, -285051.0, -549815.0],
[23777.0, 21651.0, 19398.0, 18367.0], [23777.0, 21651.0, 19398.0, 18367.0],
[-196472.0, 63012.0, 324585.0, 587902.0] [-196472.0, 63012.0, 324585.0, 587902.0]
] ]
), );
}
Ok(()) Ok(())
} }
test_device!( #[test]
quantized_matmul, fn quantize_q4_0() -> Result<()> {
quantized_matmul_cpu, use k_quants::BlockQ4_0;
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
fn quantize_q4_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let mut dst = vec![0f32; 32 * 4];
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut quant = vec![BlockQ4_0::zeros(); 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; BlockQ4_0::from_float(&src, &mut quant)?;
let dst = quant.dequantize(device)?; BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
dst.to_vec1::<f32>()?, dst,
&[ &[
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625, -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25, 11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
@ -204,21 +132,21 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
127.0, 127.0 127.0, 127.0
] ]
); );
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q4_1(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q4_1() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ4_1;
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut dst = vec![0f32; 32 * 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; let mut quant = vec![BlockQ4_1::zeros(); 4];
let dst = quant.dequantize(device)?; BlockQ4_1::from_float(&src, &mut quant)?;
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst),
&[ &[
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332, 0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73, 12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
@ -234,21 +162,21 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996 118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
] ]
); );
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q5_0(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q5_0() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ5_0;
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut dst = vec![0f32; 32 * 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; let mut quant = vec![BlockQ5_0::zeros(); 4];
let dst = quant.dequantize(device)?; BlockQ5_0::from_float(&src, &mut quant)?;
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), round_vector(&dst),
&[ &[
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625, -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313, 11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
@ -264,21 +192,21 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0 119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
] ]
); );
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q5_1(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q5_1() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ5_1;
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let mut dst = vec![0f32; 32 * 4];
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; let mut quant = vec![BlockQ5_1::zeros(); 4];
let dst = quant.dequantize(device)?; BlockQ5_1::from_float(&src, &mut quant)?;
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
assert_eq!( assert_eq!(
round_vector(&dst.to_vec1::<f32>()?), dst,
&[ &[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
@ -292,11 +220,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
124.0, 125.0, 126.0, 127.0 124.0, 125.0, 126.0, 127.0
] ]
); );
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> { /// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
assert!( assert!(
size % crate::quantized::k_quants::QK_K == 0, size % crate::quantized::k_quants::QK_K == 0,
"size must be a multiple of {}", "size must be a multiple of {}",
@ -306,8 +236,10 @@ fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor>
let src = (0..size) let src = (0..size)
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.)) .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let dst = vec![0f32; size];
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]); assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
Tensor::from_vec(src, (size,), device) (src, dst)
} }
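get_test_vector2 above builds a deterministic ramp rather than random data: `size` evenly spaced values covering [-bound, bound), so the first element is -bound and the middle one is 0. A standalone sketch:

// Same ramp as get_test_vector2 above.
fn ramp(bound: f32, size: usize) -> Vec<f32> {
    (0..size)
        .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
        .collect()
}

fn main() {
    let v = ramp(0.5, 1024);
    assert_eq!([v[0], v[512]], [-0.5, 0.0]);
}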
/// Round a vector /// Round a vector
@ -356,12 +288,11 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
/// Similar to the GGML quantization unit test: /// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 /// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> { fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0); let src = create_ggml_like_vector(0.0);
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; let mut dst = vec![0.0; GGML_TEST_SIZE];
let quant = quantized::QTensor::quantize(&src, dtype)?; let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let dst = quant.dequantize(device)?; let error = calculate_rmse(src.as_slice(), dst.as_slice());
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
if error > max_error { if error > max_error {
bail!( bail!(
"Quantization error {} exceeds max error {}", "Quantization error {} exceeds max error {}",
@ -372,19 +303,19 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {

Ok(()) Ok(())
} }
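ggml_quantization_error_test above compares the round-tripped data to the source against a per-dtype RMSE threshold. A standalone sketch of the metric, assumed to match the calculate_rmse helper referenced in the hunk header:

// Root mean squared error between two equally sized slices.
fn rmse(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
    (sum_sq / a.len() as f32).sqrt()
}

fn main() {
    assert_eq!(rmse(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
    assert!((rmse(&[0.0, 0.0], &[3.0, 4.0]) - (12.5f32).sqrt()).abs() < 1e-6);
}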
fn quantize_q2k(device: &Device) -> Result<()> { fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
// TODO Enable this later when we enable cuda. let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
if device.is_cuda() { T::from_float(src, &mut quant)?;
return Ok(()); T::to_float(&quant, dst)?;
Ok(quant)
} }
let dtype = GgmlDType::Q2K;
let src = get_test_vector2(0.5, 1024, device)?; #[test]
let quant = quantized::QTensor::quantize(&src, dtype)?; fn quantize_q2k() -> Result<()> {
let dst = quant.dequantize(device)?; use k_quants::BlockQ2K;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.1); compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
// Test some specific values // Test some specific values
@ -398,30 +329,20 @@ fn quantize_q2k(device: &Device) -> Result<()> {
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492] [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?; ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
Ok(()) Ok(())
} }
fn quantize_q3k(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q3k() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ3K;
return Ok(());
}
let dtype = GgmlDType::Q3K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.03); compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
// Test some specific values // Test some specific values
@ -435,30 +356,20 @@ fn quantize_q3k(device: &Device) -> Result<()> {
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492] [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?; ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
Ok(()) Ok(())
} }
fn quantize_q4k(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q4k() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ4K;
return Ok(());
}
let dtype = GgmlDType::Q4K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.017); compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
// Test some specific values // Test some specific values
@ -472,31 +383,21 @@ fn quantize_q4k(device: &Device) -> Result<()> {
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q5k(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q5k() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ5K;
return Ok(());
}
let dtype = GgmlDType::Q5K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.009); compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values // Test some specific values
assert_eq!( assert_eq!(
@ -506,33 +407,24 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let dst = round_vector(&dst); let dst = round_vector(&dst);
assert_eq!( assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q6k(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q6k() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ6K;
return Ok(());
}
let dtype = GgmlDType::Q6K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008); compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values // Test some specific values
@ -546,31 +438,22 @@ fn quantize_q6k(device: &Device) -> Result<()> {
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5] [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
fn quantize_q8k(device: &Device) -> Result<()> { #[test]
// TODO Enable this later when we enable cuda. fn quantize_q8k() -> Result<()> {
if device.is_cuda() { use k_quants::BlockQ8K;
return Ok(());
}
let dtype = GgmlDType::Q8K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let src = src.to_vec1::<f32>()?; let (src, mut dst) = get_test_vector(0.5, 1024);
let dst = dst.to_vec1::<f32>()?; let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008); compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
// Test some specific values // Test some specific values
assert_eq!( assert_eq!(
@ -583,79 +466,15 @@ fn quantize_q8k(device: &Device) -> Result<()> {
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499] [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
); );
let src_big = get_test_vector2(128.0, 1024, device)?; let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
let dst_big = quant_big.dequantize(device)?;
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6); compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(()) Ok(())
} }
test_device!(
quantize_q4_0,
quantize_q4_0_cpu,
quantize_q4_0_cuda,
quantize_q4_0_metal
);
test_device!(
quantize_q4_1,
quantize_q4_1_cpu,
quantize_q4_1_cuda,
quantize_q4_1_metal
);
test_device!(
quantize_q5_0,
quantize_q5_0_cpu,
quantize_q5_0_cuda,
quantize_q5_0_metal
);
test_device!(
quantize_q5_1,
quantize_q5_1_cpu,
quantize_q5_1_cuda,
quantize_q5_1_metal
);
test_device!(
quantize_q2k,
quantize_q2k_cpu,
quantize_q2k_cuda,
quantize_q2k_metal
);
test_device!(
quantize_q3k,
quantize_q3k_cpu,
quantize_q3k_cuda,
quantize_q3k_metal
);
test_device!(
quantize_q4k,
quantize_q4k_cpu,
quantize_q4k_cuda,
quantize_q4k_metal
);
test_device!(
quantize_q5k,
quantize_q5k_cpu,
quantize_q5k_cuda,
quantize_q5k_metal
);
test_device!(
quantize_q6k,
quantize_q6k_cpu,
quantize_q6k_cuda,
quantize_q6k_metal
);
test_device!(
quantize_q8k,
quantize_q8k_cpu,
quantize_q8k_cuda,
quantize_q8k_metal
);
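
The hunks above replace the per-block-type roundtrip helpers with a dtype-driven path through `quantized::QTensor`, and register each test for cpu/cuda/metal via `test_device!` (with cuda short-circuited by the TODO guards). A minimal sketch of that roundtrip, assuming the core crate is in scope as `candle` and using only the calls that appear in the hunks:

```rust
use candle::quantized::{self, GgmlDType};
use candle::{Device, Result, Tensor};

// Hypothetical helper mirroring the tests above: quantize to the requested
// GGML dtype, dequantize on the same device, and hand back both sides as f32
// vectors for a tolerance check such as `compare_with_error`.
fn roundtrip(src: &Tensor, dtype: GgmlDType, device: &Device) -> Result<(Vec<f32>, Vec<f32>)> {
    let quant = quantized::QTensor::quantize(src, dtype)?;
    let dst = quant.dequantize(device)?;
    Ok((src.to_vec1::<f32>()?, dst.to_vec1::<f32>()?))
}
```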
/// Very simple dot product implementation /// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum() a.iter().zip(b).map(|(a, b)| a * b).sum()
@ -739,6 +558,26 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
Ok(()) Ok(())
} }
fn get_small_tensors(
m: usize,
k: usize,
n: usize,
device: &Device,
) -> Result<(Tensor, Tensor, Tensor)> {
let lhs = (0..m * k)
.map(|i| i as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..n * k)
.map(|i| i as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
let mm = lhs.matmul(&rhs.t()?)?;
Ok((lhs, rhs, mm))
}
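
`get_small_tensors` swaps the random inputs of `get_random_tensors` for deterministic ramps in [0, 1), which keeps the quantization error small and reproducible. A short usage sketch with the shapes used by the matmul tests below:

```rust
// Deterministic inputs make the mean relative error of the quantized matmul
// stable across runs, so the tests below can assert it stays under ~1%.
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(lhs.dims(), [m, k]);
assert_eq!(rhs.dims(), [n, k]);
assert_eq!(mm.dims(), [m, n]);
```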
#[test] #[test]
fn quantized_mm() -> Result<()> { fn quantized_mm() -> Result<()> {
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?; ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
@ -772,112 +611,6 @@ fn get_random_tensors(
Ok((lhs, rhs, mm)) Ok((lhs, rhs, mm))
} }
#[macro_export]
macro_rules! quantized_matmul {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> {
if device.is_cuda() {
// TODO Enable Cuda GGML sometime maybe.
return Ok(());
}
test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(())
}
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
};
}
quantized_matmul!(
quantized_matmul_q4_0_bis,
quantized_matmul_q4_0_cpu,
quantized_matmul_q4_0_cuda,
quantized_matmul_q4_0_metal,
GgmlDType::Q4_0
);
quantized_matmul!(
quantized_matmul_q4_1_bis,
quantized_matmul_q4_1_cpu,
quantized_matmul_q4_1_cuda,
quantized_matmul_q4_1_metal,
GgmlDType::Q4_1
);
quantized_matmul!(
quantized_matmul_q5_0_bis,
quantized_matmul_q5_0_cpu,
quantized_matmul_q5_0_cuda,
quantized_matmul_q5_0_metal,
GgmlDType::Q5_0
);
quantized_matmul!(
quantized_matmul_q5_1_bis,
quantized_matmul_q5_1_cpu,
quantized_matmul_q5_1_cuda,
quantized_matmul_q5_1_metal,
GgmlDType::Q5_1
);
quantized_matmul!(
quantized_matmul_q8_0_bis,
quantized_matmul_q8_0_cpu,
quantized_matmul_q8_0_cuda,
quantized_matmul_q8_0_metal,
GgmlDType::Q8_0
);
// Not implemented in Ggml
// quantized_matmul!(
// quantized_matmul_q8_1_bis,
// quantized_matmul_q8_1_cpu,
// quantized_matmul_q8_1_cuda,
// quantized_matmul_q8_1_metal,
// GgmlDType::Q8_1
// );
// TODO This is bugged (also bugged in GGML)
quantized_matmul!(
quantized_matmul_q2k_bis,
quantized_matmul_q2k_cpu,
quantized_matmul_q2k_cuda,
quantized_matmul_q2k_metal,
GgmlDType::Q2K
);
quantized_matmul!(
quantized_matmul_q3k_bis,
quantized_matmul_q3k_cpu,
quantized_matmul_q3k_cuda,
quantized_matmul_q3k_metal,
GgmlDType::Q3K
);
quantized_matmul!(
quantized_matmul_q4k_bis,
quantized_matmul_q4k_cpu,
quantized_matmul_q4k_cuda,
quantized_matmul_q4k_metal,
GgmlDType::Q4K
);
quantized_matmul!(
quantized_matmul_q5k_bis,
quantized_matmul_q5k_cpu,
quantized_matmul_q5k_cuda,
quantized_matmul_q5k_metal,
GgmlDType::Q5K
);
quantized_matmul!(
quantized_matmul_q6k_bis,
quantized_matmul_q6k_cpu,
quantized_matmul_q6k_cuda,
quantized_matmul_q6k_metal,
GgmlDType::Q6K
);
// Not implemented on metal
// quantized_matmul!(
// quantized_matmul_q8k_bis,
// quantized_matmul_q8k_cpu,
// quantized_matmul_q8k_cuda,
// quantized_matmul_q8k_metal,
// GgmlDType::Q8K
// );
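
The `quantized_matmul!` macro pairs a per-device test body with its `test_device!` registration; the four names are spelled out because `concat_idents` is still unstable. As a sketch, a single invocation expands to roughly the following (not a literal macro expansion):

```rust
fn quantized_matmul_q4_0_bis(device: &Device) -> Result<()> {
    if device.is_cuda() {
        // Cuda GGML matmul is not wired up yet, so the test is a no-op there.
        return Ok(());
    }
    test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
    Ok(())
}
test_device!(
    quantized_matmul_q4_0_bis,
    quantized_matmul_q4_0_cpu,
    quantized_matmul_q4_0_cuda,
    quantized_matmul_q4_0_metal
);
```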
#[test] #[test]
fn quantized_matmul_q2k() -> Result<()> { fn quantized_matmul_q2k() -> Result<()> {
use k_quants::BlockQ2K; use k_quants::BlockQ2K;
@ -890,7 +623,7 @@ fn quantized_matmul_q2k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?; let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
@ -910,20 +643,30 @@ fn quantized_matmul_q3k() -> Result<()> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21); let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]); // assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?; // let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); // let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); // assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?; let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
let dst = mm.flatten_all()?.to_vec1::<f32>()?; .sum_all()?
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); .to_scalar()?;
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]); let error = error / (m * n) as f32;
// assert_eq!(qmm.dims(), [m, n]);
// let dst = qmm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
ggml_matmul_error_test::<BlockQ3K>()?; ggml_matmul_error_test::<BlockQ3K>()?;
@ -936,20 +679,30 @@ fn quantized_matmul_q4k() -> Result<()> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21); let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]); // assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?; // let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); // let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); // assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?; let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
let dst = mm.flatten_all()?.to_vec1::<f32>()?; .sum_all()?
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); .to_scalar()?;
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]); let error = error / (m * n) as f32;
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
// assert_eq!(mm.dims(), [m, n]);
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
ggml_matmul_error_test::<BlockQ4K>()?; ggml_matmul_error_test::<BlockQ4K>()?;
@ -968,7 +721,7 @@ fn quantized_matmul_q5k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?; let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
@ -995,7 +748,7 @@ fn quantized_matmul_q6k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?; let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;
@ -1020,7 +773,7 @@ fn quantized_matmul_q8k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?; let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let mm = rhs.forward(&lhs)?;

View File

@ -1245,23 +1245,11 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
} }
#[test] #[test]
fn log_sum_exp() -> Result<()> { fn logsumexp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.log_sum_exp(D::Minus1)?; let output = input.logsumexp(D::Minus1)?;
// The expectations obtained from pytorch. // The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?; assert_close(&output, &expected, 0.00001)?;
Ok(()) Ok(())
} }
#[test]
fn pow() -> Result<()> {
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}
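
The removed `pow` test takes elementwise `lhs^(lhs - 2)`, so the expected row follows from plain arithmetic; the `1296.0001` in the last entry is just `6^4` computed in f32 (typically as `exp((x - 2) * ln(x))`). A quick host-side check, purely as an illustration:

```rust
// Elementwise x^(x - 2) for the test input; the slight excess over 1296.0 in
// the original expectation comes from evaluating powf in single precision.
let xs = [1f32, 2., 3., 4., 5., 6.];
let ys: Vec<f32> = xs.iter().map(|&x| x.powf(x - 2.)).collect();
println!("{ys:?}"); // approximately [1.0, 1.0, 3.0, 16.0, 125.0, 1296.0]
```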

View File

@ -27,5 +27,11 @@ fn main() -> Result<()> {
bindings.write(kdir.rust_target).unwrap() bindings.write(kdir.rust_target).unwrap()
} }
} }
#[cfg(not(feature = "cuda"))]
{
for kdir in KERNEL_DIRS.iter() {
let _file = std::fs::File::create(kdir.rust_target)?;
}
}
Ok(()) Ok(())
} }
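
The added non-cuda branch creates each `rust_target` as an empty file, so whatever includes those generated paths still compiles when the `cuda` feature is off. The idea, as a standalone sketch (paths are illustrative, not the crate's actual layout):

```rust
// Sketch: write empty placeholders for generated kernel bindings when the
// cuda feature is disabled, so include!/include_str! targets always exist.
#[cfg(not(feature = "cuda"))]
fn write_stub_targets(rust_targets: &[std::path::PathBuf]) -> std::io::Result<()> {
    for target in rust_targets {
        std::fs::File::create(target)?;
    }
    Ok(())
}
```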

View File

@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
let config = blip::Config::image_captioning_large(); let config = blip::Config::image_captioning_large();
let device = candle_examples::device(args.cpu)?;
let (image_embeds, device, mut model) = if args.quantized { let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu; let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?; let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?; let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?; let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?; let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model)) (image_embeds, device, Model::Q(model))
} else { } else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?; let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
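
This hunk (and the similar ones in the mistral, phi, replit, stablelm, whisper and t5 examples below) is about where the device is chosen: one side resolves it once with `candle_examples::device(args.cpu)?` and passes it into the gguf loader, the other pins the quantized path to `Device::Cpu` and loads the gguf without a device argument. The device-aware shape of the quantized branch, as a sketch built from the calls in the hunk:

```rust
// Quantized path with an explicit device, assuming the gguf VarBuilder takes
// a device argument as shown above.
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
```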

View File

@ -0,0 +1 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension() .extension()
.map_or(false, |v| v == "safetensors"); .map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf { let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")? .get_no_shape("model.embed_tokens.weight")?
.shape() .shape()
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
(config.seq_len, config.head_size() / 2), (config.seq_len, config.head_size() / 2),
"rot.freq_cis_real", "rot.freq_cis_real",
)? )?
.dequantize(&device)?; .dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb let freq_cis_imag = vb
.get( .get(
(config.seq_len, config.head_size() / 2), (config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag", "rot.freq_cis_imag",
)? )?
.dequantize(&device)?; .dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors( let fake_vb = candle_nn::VarBuilder::from_tensors(
[ [
@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.into_iter() .into_iter()
.collect(), .collect(),
candle::DType::F32, candle::DType::F32,
&device, &candle::Device::Cpu,
); );
let cache = model::Cache::new(true, &config, fake_vb)?; let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
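
Beyond the loader signature, this hunk moves the dequantized rope tables onto the selected device: individual gguf tensors are dequantized to f32 and wrapped in a plain `candle_nn::VarBuilder` so the non-quantized cache code can consume them. Reduced to a sketch of one tensor:

```rust
// Dequantize a single gguf tensor on the target device and expose it through
// an f32 VarBuilder for code that expects regular weights.
let freq_cis_real = vb
    .get((config.seq_len, config.head_size() / 2), "rot.freq_cis_real")?
    .dequantize(&device)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
    [("rot.freq_cis_real".to_string(), freq_cis_real)]
        .into_iter()
        .collect(),
    candle::DType::F32,
    &device,
);
```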

View File

@ -244,14 +244,13 @@ fn main() -> Result<()> {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn); let config = Config::config_7b_v0_1(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized { let (model, device) = if args.quantized {
let filename = &filenames[0]; let filename = &filenames[0];
let vb = let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?; let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), device) (Model::Quantized(model), Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() { let dtype = if device.is_cuda() {
DType::BF16 DType::BF16
} else { } else {

View File

@ -1,22 +0,0 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%
```

View File

@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S0,
S1,
S2,
S3,
S4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::S0 => "s0",
Self::S1 => "s1",
Self::S2 => "s2",
Self::S3 => "s3",
Self::S4 => "s4",
};
format!("timm/mobileone_{}.apple_in1k", name)
}
fn config(&self) -> mobileone::Config {
match self {
Self::S0 => mobileone::Config::s0(),
Self::S1 => mobileone::Config::s1(),
Self::S2 => mobileone::Config::s2(),
Self::S3 => mobileone::Config::s3(),
Self::S4 => mobileone::Config::s4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -8,7 +8,6 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer}; use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer; use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
@ -19,7 +18,6 @@ use tokenizers::Tokenizer;
enum Model { enum Model {
MixFormer(MixFormer), MixFormer(MixFormer),
Phi(Phi),
Quantized(QMixFormer), Quantized(QMixFormer),
} }
@ -86,7 +84,6 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model { let logits = match &mut self.model {
Model::MixFormer(m) => m.forward(&input)?, Model::MixFormer(m) => m.forward(&input)?,
Model::Phi(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?, Model::Quantized(m) => m.forward(&input)?,
}; };
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
@ -120,7 +117,7 @@ impl TextGeneration {
} }
} }
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel { enum WhichModel {
#[value(name = "1")] #[value(name = "1")]
V1, V1,
@ -128,8 +125,6 @@ enum WhichModel {
V1_5, V1_5,
#[value(name = "2")] #[value(name = "2")]
V2, V2,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2, PuffinPhiV2,
PhiHermes, PhiHermes,
} }
@ -174,7 +169,7 @@ struct Args {
#[arg(long)] #[arg(long)]
model_id: Option<String>, model_id: Option<String>,
#[arg(long, default_value = "2")] #[arg(long, default_value = "1.5")]
model: WhichModel, model: WhichModel,
#[arg(long)] #[arg(long)]
@ -235,7 +230,7 @@ fn main() -> Result<()> {
match args.model { match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(), WhichModel::V2 => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string() "lmz/candle-quantized-phi".to_string()
} }
@ -250,9 +245,8 @@ fn main() -> Result<()> {
"main".to_string() "main".to_string()
} else { } else {
match args.model { match args.model {
WhichModel::V1 => "refs/pr/8".to_string(), WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/73".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string() "main".to_string()
} }
@ -264,9 +258,7 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer { let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file), Some(file) => std::path::PathBuf::from(file),
None => match args.model { None => match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => { WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
repo.get("tokenizer.json")?
}
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")? repo.get("tokenizer-puffin-phi-v2.json")?
} }
@ -279,14 +271,14 @@ fn main() -> Result<()> {
match args.model { match args.model {
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
} }
} else { } else {
match args.model { match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors( WhichModel::V2 => candle_examples::hub_load_safetensors(
&repo, &repo,
"model.safetensors.index.json", "model.safetensors.index.json",
)?, )?,
@ -300,44 +292,28 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = || match args.model { let config = match args.model {
WhichModel::V1 => Config::v1(), WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(), WhichModel::V1_5 => Config::v1_5(),
WhichModel::V2 | WhichModel::V2Old => Config::v2(), WhichModel::V2 => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
}; };
let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized {
let model = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
let config = config();
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&filenames[0],
&device,
)?;
let model = match args.model { let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?, _ => QMixFormer::new(&config, vb)?,
}; };
Model::Quantized(model) (Model::Quantized(model), Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
match args.model { let model = match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
let config_filename = repo.get("config.json")?; _ => MixFormer::new(&config, vb)?,
let config = std::fs::read_to_string(config_filename)?; };
let config: PhiConfig = serde_json::from_str(&config)?; (Model::MixFormer(model), device)
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V2Old => {
let config = config();
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
}
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
}
}; };
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
@ -417,10 +393,6 @@ fn mmlu<P: AsRef<std::path::Path>>(
m.clear_kv_cache(); m.clear_kv_cache();
m.forward(&input)? m.forward(&input)?
} }
Model::Phi(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
Model::Quantized(m) => { Model::Quantized(m) => {
m.clear_kv_cache(); m.clear_kv_cache();
m.forward(&input)? m.forward(&input)?
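
The larger side of this example also adds a non-quantized `Phi` path that reads `config.json` from the hub repo and deserializes it with serde_json instead of relying on the built-in `Config::v1/v1_5/v2` constructors. That load path, as a sketch (the `repo` and `vb` bindings are the ones the example already has in scope):

```rust
// Build the f32 Phi model from the hub's config.json rather than a hard-coded
// config constructor.
let config_filename = repo.get("config.json")?;
let config: PhiConfig = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let phi = Phi::new(&config, vb)?;
```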

View File

@ -132,8 +132,7 @@ impl T5ModelBuilder {
} }
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> { pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let device = Device::Cpu; let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
} }

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file}; use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor; use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
@ -361,7 +361,6 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?; let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?; let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) { let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => { Some("gguf") => {
@ -370,7 +369,7 @@ fn main() -> anyhow::Result<()> {
for (_, tensor) in model.tensor_infos.iter() { for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count(); let elem_count = tensor.shape.elem_count();
total_size_in_bytes += total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
} }
println!( println!(
"loaded {:?} tensors ({}) in {:.2}s", "loaded {:?} tensors ({}) in {:.2}s",
@ -378,16 +377,15 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes), &format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(), start.elapsed().as_secs_f32(),
); );
ModelWeights::from_gguf(model, &mut file, &device)? ModelWeights::from_gguf(model, &mut file)?
} }
Some("ggml" | "bin") | Some(_) | None => { Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file, &device) let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
.map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0; let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() { for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count(); let elem_count = tensor.shape().elem_count();
total_size_in_bytes += total_size_in_bytes +=
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size(); elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
} }
println!( println!(
"loaded {:?} tensors ({}) in {:.2}s", "loaded {:?} tensors ({}) in {:.2}s",
@ -488,7 +486,7 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now(); let start_prompt_processing = std::time::Instant::now();
let mut next_token = { let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?; let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)? logits_processor.sample(&logits)?
@ -509,7 +507,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now(); let start_post_prompt = std::time::Instant::now();
let mut sampled = 0; let mut sampled = 0;
for index in 0..to_sample { for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?; let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
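
Both sides of this hunk size the loaded tensors as `elem_count * type_size / block_size` (the accessor is merely renamed from `blck_size`). For block-quantized dtypes that ratio is the per-block byte cost spread over the block's elements; as a worked example under the usual GGML Q4_0 layout of 32 elements per 18-byte block (stated here as an assumption):

```rust
// Worked example of the size computation above for a 4096x4096 Q4_0 tensor.
let elem_count: usize = 4096 * 4096;
let (type_size, block_size) = (18usize, 32usize); // assumed Q4_0 block layout
let total_bytes = elem_count * type_size / block_size; // 9_437_184 bytes (~9 MiB)
```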

View File

@ -236,15 +236,16 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b(); let config = Config::replit_code_v1_5_3b();
let model = if args.quantized { let (model, device) = if args.quantized {
let vb = let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?; let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
Model::Q(Q::new(&config, vb.pp("transformer"))?) (model, Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
Model::M(M::new(&config, vb.pp("transformer"))?) let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
}; };
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,22 +0,0 @@
# candle-repvgg
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
This candle implementation uses a pre-trained RepVGG network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 61.70%
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
unicycle, monocycle : 4.88%
crash helmet : 0.15%
moped : 0.04%
```

View File

@ -1,111 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::repvgg;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
A0,
A1,
A2,
B0,
B1,
B2,
B3,
B1G4,
B2G4,
B3G4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::A0 => "a0",
Self::A1 => "a1",
Self::A2 => "a2",
Self::B0 => "b0",
Self::B1 => "b1",
Self::B2 => "b2",
Self::B3 => "b3",
Self::B1G4 => "b1g4",
Self::B2G4 => "b2g4",
Self::B3G4 => "b3g4",
};
format!("timm/repvgg_{}.rvgg_in1k", name)
}
fn config(&self) -> repvgg::Config {
match self {
Self::A0 => repvgg::Config::a0(),
Self::A1 => repvgg::Config::a1(),
Self::A2 => repvgg::Config::a2(),
Self::B0 => repvgg::Config::b0(),
Self::B1 => repvgg::Config::b1(),
Self::B2 => repvgg::Config::b2(),
Self::B3 => repvgg::Config::b3(),
Self::B1G4 => repvgg::Config::b1g4(),
Self::B2G4 => repvgg::Config::b2g4(),
Self::B3G4 => repvgg::Config::b3g4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::A0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -234,14 +234,13 @@ fn main() -> Result<()> {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn); let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized { let (model, device) = if args.quantized {
let filename = &filenames[0]; let filename = &filenames[0];
let vb = let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?; let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu) (Model::Quantized(model), Device::Cpu)
} else { } else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() { let dtype = if device.is_cuda() {
DType::BF16 DType::BF16
} else { } else {

View File

@ -557,10 +557,8 @@ fn main() -> Result<()> {
println!("loaded mel: {:?}", mel.dims()); println!("loaded mel: {:?}", mel.dims());
let mut model = if args.quantized { let mut model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( let vb =
&weights_filename, candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
&device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else { } else {
let vb = let vb =

View File

@ -9,7 +9,6 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"] categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
metal = { version = "0.27.0", features = ["mps"]} metal = { version = "0.27.0", features = ["mps"]}
once_cell = "1.18.0" once_cell = "1.18.0"
@ -17,9 +16,5 @@ thiserror = "1"
tracing = "0.1.37" tracing = "0.1.37"
[dev-dependencies] [dev-dependencies]
half = { version = "2.3.1", features = [ half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
"num-traits",
"use-intrinsics",
"rand_distr",
] }
rand = "0.8.5" rand = "0.8.5"

View File

@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
using namespace metal; using namespace metal;
#define AFFINE(FN_NAME, T) \ #define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \ kernel void FN_NAME( \
constant size_t &dim, \ constant size_t &dim, \
constant float &mul, \ constant float &mul, \
constant float &add, \ constant float &add, \
device const T *input, \ device const TYPENAME *input, \
device T *output, \ device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \ uint id [[ thread_position_in_grid ]] \
) { \ ) { \
if (id >= dim) { \ if (id >= dim) { \
return; \ return; \
} \ } \
output[id] = T(fma(float(input[id]), mul, add)); \ output[id] = TYPENAME(float(input[id]) * mul + add); \
} \ } \
kernel void FN_NAME##_strided( \ kernel void FN_NAME##_strided( \
constant size_t &dim, \ constant size_t &dim, \
@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
constant size_t *strides, \ constant size_t *strides, \
constant float &mul, \ constant float &mul, \
constant float &add, \ constant float &add, \
device const T *input, \ device const TYPENAME *input, \
device T *output, \ device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \ uint id [[ thread_position_in_grid ]] \
) { \ ) { \
if (id >= dim) { \ if (id >= dim) { \
return; \ return; \
} \ } \
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \ output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
} }
#define POWF(FN_NAME, TYPENAME) \ #define POWF(FN_NAME, TYPENAME) \
@ -117,7 +117,7 @@ ELU(elu_f32, float)
ELU(elu_f16, half) ELU(elu_f16, half)
#if defined(__HAVE_BFLOAT__) #if __METAL_VERSION__ >= 310
AFFINE(affine_bf16, bfloat); AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat); POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat); ELU(elu_bf16, bfloat);
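
One side of the AFFINE macro writes the transform with `fma`, which evaluates `x * mul + add` with a single rounding rather than rounding the product and the sum separately. The same distinction on the host, shown only to illustrate the numerical difference (the kernels themselves stay in Metal):

```rust
// A fused multiply-add rounds once; multiply-then-add rounds twice, so the
// two results can differ by an ulp.
let (x, mul, add) = (0.1f32, 3.0f32, 1.0f32);
let fused = x.mul_add(mul, add); // round(x * mul + add)
let split = x * mul + add;       // round(round(x * mul) + add)
println!("fused = {fused}, split = {split}");
```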

View File

@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \ #define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided); BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
BINARY_OP(x + y, add) BINARY_OP(x + y, add)
BINARY_OP(x - y, sub) BINARY_OP(x - y, sub)
@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y) INT64_BINARY_OP_OUT(gt, x > y)
#endif #endif
#if defined(__HAVE_BFLOAT__) #if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul) BFLOAT_BINARY_OP(x * y, mul)

View File

@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \ if (tid >= dim) { \
return; \ return; \
} \ } \
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \ output[tid] = RIGHT_TYPENAME(input[tid]); \
} \ } \
kernel void FN_NAME_STRIDED( \ kernel void FN_NAME_STRIDED( \
constant size_t &dim, \ constant size_t &dim, \
@ -42,34 +42,7 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \ if (tid >= dim) { \
return; \ return; \
} \ } \
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \ output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \ } \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
@ -85,14 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif #endif
#if defined(__HAVE_BFLOAT__) #if __METAL_VERSION__ >= 310
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif #endif
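
One side defines a `CAST_THROUGH` macro that handles type pairs with no direct conversion by routing through an intermediate type, e.g. `bfloat -> float -> half`. The same two-step idea on the host with the `half` crate, purely as an illustration:

```rust
// Host-side mirror of CAST_THROUGH's bfloat -> float -> half path.
use half::{bf16, f16};

let x = bf16::from_f32(1.5);
let y = f16::from_f32(x.to_f32()); // cast through f32
assert_eq!(y.to_f32(), 1.5);
```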

View File

@ -173,10 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half) SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__) #if __METAL_VERSION__ >= 310
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)

View File

@ -12,11 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal"); const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal"); const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal"); const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal"); const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal"); const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply similarly across the tensors /// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@ -63,12 +61,8 @@ macro_rules! primitive {
} }
}; };
} }
primitive!(bool);
primitive!(usize); primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32); primitive!(u32);
primitive!(u64);
primitive!(f32); primitive!(f32);
impl<T> EncoderParam for &[T] { impl<T> EncoderParam for &[T] {
@ -123,8 +117,6 @@ pub enum Source {
Reduce, Reduce,
Mfa, Mfa,
Conv, Conv,
Random,
Quantized,
} }
macro_rules! ops{ macro_rules! ops{
@ -182,8 +174,8 @@ macro_rules! ops{
pub mod unary { pub mod unary {
ops!( ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
tanh, recip recip
); );
} }
pub mod binary { pub mod binary {
@ -223,15 +215,17 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
pub struct Kernels { pub struct Kernels {
libraries: RwLock<Libraries>, libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>, pipelines: RwLock<Pipelines>,
fence: metal::Fence,
} }
impl Kernels { impl Kernels {
pub fn new() -> Self { pub fn new(fence: metal::Fence) -> Self {
let libraries = RwLock::new(Libraries::new()); let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new()); let pipelines = RwLock::new(Pipelines::new());
Self { Self {
libraries, libraries,
pipelines, pipelines,
fence,
} }
} }
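
On the side that keeps it, `Kernels` carries a shared `metal::Fence` and each launch below brackets its compute encoder with `wait_for_fence` / `update_fence`, ordering kernels that touch the same buffers within a command buffer; the other side drops the field and the bracketing. The pattern, reduced to a sketch built from the calls in the hunks below (the encoder, pipeline and dispatch sizes come from the surrounding `call_*` function):

```rust
// Fence bracketing around a kernel launch, as used by the encoder hunks below.
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence); // wait on previously encoded work
encoder.set_compute_pipeline_state(&pipeline);
// ... set_params! and use_resource calls elided ...
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence); // signal completion to later encoders
encoder.end_encoding();
```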
@ -245,8 +239,6 @@ impl Kernels {
Source::Cast => CAST, Source::Cast => CAST,
Source::Reduce => REDUCE, Source::Reduce => REDUCE,
Source::Conv => CONV, Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"), Source::Mfa => panic!("Invalid lib"),
} }
} }
@ -353,6 +345,7 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output)); set_params!(encoder, (length, input, output));
@ -361,6 +354,7 @@ pub fn call_unary_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -382,6 +376,7 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -403,6 +398,7 @@ pub fn call_unary_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -421,6 +417,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output)); set_params!(encoder, (length, left, right, output));
@ -431,6 +428,7 @@ pub fn call_binary_contiguous(
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -455,6 +453,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product(); let width: usize = shape.iter().product();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -479,6 +478,7 @@ pub fn call_binary_strided(
encoder.use_resource(right_input, metal::MTLResourceUsage::Read); encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -497,6 +497,7 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output)); set_params!(encoder, (length, (input, input_offset), output));
@ -505,6 +506,7 @@ pub fn call_cast_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -524,6 +526,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -545,6 +548,7 @@ pub fn call_cast_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -564,6 +568,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -592,6 +597,7 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -613,6 +619,7 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -623,8 +630,7 @@ pub fn call_reduce_strided(
strides, strides,
elements_to_sum, elements_to_sum,
(input, input_offset), (input, input_offset),
output, output
out_length
) )
); );
@ -649,6 +655,7 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -667,6 +674,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -697,6 +705,7 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -716,6 +725,7 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output)); set_params!(encoder, (size, mul, add, input, output));
@ -724,6 +734,7 @@ pub fn call_affine(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -746,6 +757,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -766,6 +778,7 @@ pub fn call_affine_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -784,6 +797,7 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
@ -792,6 +806,7 @@ pub fn call_powf(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -813,6 +828,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -832,6 +848,7 @@ pub fn call_powf_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -850,6 +867,7 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
@ -858,6 +876,7 @@ pub fn call_elu(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -879,6 +898,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -898,6 +918,7 @@ pub fn call_elu_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -919,6 +940,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
@ -947,6 +969,7 @@ pub fn call_where_cond_strided(
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -973,6 +996,7 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -995,6 +1019,7 @@ pub fn call_index_select(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1023,6 +1048,7 @@ pub fn call_gather(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1045,6 +1071,7 @@ pub fn call_gather(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1073,6 +1100,7 @@ pub fn call_scatter_add(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1095,6 +1123,7 @@ pub fn call_scatter_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1124,6 +1153,7 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1147,6 +1177,7 @@ pub fn call_index_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1350,6 +1381,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes; let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1389,10 +1421,12 @@ pub fn call_gemm(
height: 1, height: 1,
depth: 1, depth: 1,
}; };
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size); encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1417,6 +1451,7 @@ pub fn call_im2col1d_strided(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1436,6 +1471,7 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1465,6 +1501,7 @@ pub fn call_im2col_strided(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1486,6 +1523,7 @@ pub fn call_im2col_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
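The hunks above derive their dispatch geometry from `linear_split(&pipeline, dst_el)`. The helper itself is outside this diff; a plausible sketch, assuming it splits a 1-D workload by the pipeline's maximum threadgroup width and covers the remainder with a ceiling division:
use metal::{ComputePipelineState, MTLSize};
// Plausible 1-D splitter: cap the threadgroup width at what the pipeline
// supports, then use enough threadgroups to cover `length` elements.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = pipeline.max_total_threads_per_threadgroup().min(size).max(1);
    let count = (size + width - 1) / width; // ceiling division
    let thread_group_count = MTLSize { width: count, height: 1, depth: 1 };
    let thread_group_size = MTLSize { width, height: 1, depth: 1 };
    (thread_group_count, thread_group_size)
}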
@ -1511,6 +1549,7 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32; let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1528,243 +1567,7 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding(); encoder.update_fence(&kernels.fence);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
min: f32,
max: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
return Err(MetalKernelError::LoadLibraryError(
"min must be less than max".to_string(),
));
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, min, max, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
mean: f32,
stddev: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, mean, stddev, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
F16,
F32,
}
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
dtype: GgmlDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let nb10 = 0i64;
let nb11 = 0i64;
let nb12 = 0i64;
let ne0 = n as i64;
let ne1 = m as i64;
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;
let (nth0, nth1, align) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
// Original implementation uses rows
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
};
let thread_groups_count = MTLSize {
width: divide(ne01 as usize, align),
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
let name = match dtype {
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
rhs,
(lhs, lhs_offset),
output,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
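The quantized matmul dispatch above sizes its grid from the reversed GGML-style dimensions: `divide(ne01, align)` threadgroups along the output columns, `ne11` along the rows, `ne12 * ne13` along the batch, with `nth0 x nth1` threads per group. `divide` is not shown in this diff; it is presumably a ceiling division along the lines of:
// Presumed ceiling division used when sizing the threadgroup grid,
// e.g. divide(ne01 as usize, align) groups along the n dimension.
fn divide(n: usize, divisor: usize) -> u64 {
    ((n + divisor - 1) / divisor) as u64
}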

File diff suppressed because it is too large.

View File

@ -1,206 +0,0 @@
#include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>
using namespace metal;
// Constants
// 2^32 and 1/2^32. Useful for converting between float and uint.
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
// 2 * pi
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17};
// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = {
0x9E3779B97F4A7C15,
0xF39CC0605CEDC834,
0x1082276BF3A27251,
0xF86C6A11D0C18E95,
0x2767F0B153D27B7F,
0x0347045B5BF1827F,
0x01886F0928403002,
0xC1D64BA40F335E36,
0xF06AD7AE9717877E,
0x85839D6EFFBD7DC6,
0x64D325D1C5371682,
0xCADD0CCCFDFFBBE1,
0x626E33B8D04B4331,
0xBBF73C790D94F79D,
0x471C4AB3ED3D82A5,
0xFEC507705E4AE6E5,
};
// Combined Tausworthe and LCG Random Number Generator.
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
struct HybridTaus {
float state;
HybridTaus() thread = default;
HybridTaus() threadgroup = default;
HybridTaus() device = default;
HybridTaus() constant = default;
// Generate seeds for each thread.
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
}
// Tausworthe generator.
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
uint b = (((z << s.x) ^ z) >> s.y);
return (((z & M) << s.z) ^ b);
}
// LCG generator.
METAL_FUNC static uint lcg(const uint z) {
return (1664525 * z + 1013904223UL);
}
// Initialize the RNG state.
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
uint4 seed = seed_per_thread(seeds);
// Seed #1
uint z1 = taus(seed.x, S1, 4294967294UL);
uint z2 = taus(seed.y, S2, 4294967288UL);
uint z3 = taus(seed.z, S3, 4294967280UL);
uint z4 = lcg(seed.x);
// Seed #2
uint r1 = (z1^z2^z3^z4^seed.y);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
// Seed #3
r1 = (z1^z2^z3^z4^seed.z);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
// Seed #4
r1 = (z1^z2^z3^z4^seed.w);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
HybridTaus rng;
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
return rng;
}
METAL_FUNC float rand() {
uint seed = this->state * UNIF01_NORM32;
uint z1 = taus(seed, S1, 429496729UL);
uint z2 = taus(seed, S2, 4294967288UL);
uint z3 = taus(seed, S3, 429496280UL);
uint z4 = lcg(seed);
thread float result = this->state;
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
return result;
}
};
template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size,
constant float &min,
constant float &max,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
float diff = abs(min - max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
}
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}
// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
template<typename T> METAL_FUNC void normal(
constant size_t &size,
constant float &mean,
constant float &stddev,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
float u1 = rng.rand();
float u2 = rng.rand();
float cosval;
float sinval = sincos(TWO_PI * u2, cosval);
float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
}
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(z1);
}
#define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \
constant size_t &size, \
constant float &min, \
constant float &max, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
rand_uniform<T>(size, min, max, seed, out, tid); \
} \
#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
constant size_t &size, \
constant float &mean, \
constant float &stddev, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
normal<T>(size, mean, stddev, seed, out, tid); \
} \
#define RANDOM_OPS(NAME, T) \
UNIFORM_OP(NAME, T) \
NORMAL_OP(NAME, T) \
RANDOM_OPS(f32, float)
RANDOM_OPS(f16, half)
#if __METAL_VERSION__ >= 310
RANDOM_OPS(bf16, bfloat)
#endif
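The removed random.metal draws two uniforms per thread from the HybridTaus generator and turns them into two normal samples with the Box-Muller transform (`mag = stddev * sqrt(-2 * log(u1))`, then the cosine and sine of `2 * pi * u2`). A CPU-side reference of the same transform, with the uniform source left to the caller rather than the Metal RNG:
use std::f32::consts::PI;
// Box-Muller: turn two uniform samples u1 in (0, 1] and u2 in [0, 1)
// into two independent normal samples with the given mean and stddev.
fn box_muller(u1: f32, u2: f32, mean: f32, stddev: f32) -> (f32, f32) {
    let mag = stddev * (-2.0 * u1.ln()).sqrt();
    let (sinval, cosval) = (2.0 * PI * u2).sin_cos();
    (mag * cosval + mean, mag * sinval + mean)
}
fn main() {
    // Any uniform source works here; the kernel uses one HybridTaus state per thread.
    let (z0, z1) = box_muller(0.37, 0.82, 0.0, 1.0);
    println!("z0 = {z0}, z1 = {z1}");
}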

View File

@ -1,18 +1,16 @@
#include <metal_stdlib> #include <metal_stdlib>
#include <metal_limits>
using namespace metal; using namespace metal;
// TODO: Load multiple values per thread to improve memory bandwidth utilization #define MAX(x, y) ((x) > (y) ? (x) : (y))
// static constant constexpr uint VALUES_PER_THREAD = 1; #define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index( METAL_FUNC uint get_strided_index(
uint idx, uint idx,
constant const size_t &num_dims, constant size_t &num_dims,
constant const size_t *dims, constant size_t *dims,
constant const size_t *strides constant size_t *strides
) { ) {
uint strided_i = 0; uint strided_i = 0;
#pragma clang loop unroll(full)
for (uint d = 0; d < num_dims; d++) { for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d; uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@ -21,637 +19,288 @@ METAL_FUNC uint get_strided_index(
return strided_i; return strided_i;
} }
template <typename V> constant int THREADGROUP_SIZE = 2048;
struct Indexed {
uint i;
V val;
typedef V type;
constexpr Indexed<V>() thread = default;
constexpr Indexed<V>() threadgroup = default;
constexpr Indexed<V>() device = default;
constexpr Indexed<V>() constant = default;
constexpr Indexed<V>(uint _i, V _val) : i(_i), val(_val) {}
template <typename U, typename = typename enable_if<is_convertible_v<U, V>>::type>
constexpr Indexed<V>(uint _i, U _val) : i(_i), val(static_cast<U>(_val)) {}
template <typename U>
constexpr Indexed<V>(const thread Indexed<U> &iv): Indexed<V>(iv.i, iv.val) {}
template <typename U>
constexpr Indexed<V>(const threadgroup Indexed<V> &iv): Indexed<V>(iv.i, iv.val) {}
Indexed<V> operator=(const thread Indexed<V> &iv) thread {
this->i = iv.i;
this->val = iv.val;
return *this;
}
Indexed<V> operator=(const thread Indexed<V> &iv) threadgroup {
this->i = iv.i;
this->val = iv.val;
return *this;
}
};
template<typename V>
constexpr METAL_FUNC bool operator<(Indexed<V> lhs, Indexed<V> rhs) {
return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i);
}
template<typename V>
constexpr METAL_FUNC bool operator>(Indexed<V> lhs, Indexed<V> rhs) {
return lhs.val > rhs.val || (lhs.val == rhs.val && lhs.i > rhs.i);
}
template<typename T>
struct _numeric_limits_impl<Indexed<T>> {
static constexpr Indexed<T> lowest() {
return Indexed<T>(0, numeric_limits<T>::lowest());
}
static constexpr Indexed<T> max() {
return Indexed<T>(0, numeric_limits<T>::max());
}
};
#if defined(__HAVE_BFLOAT__)
// Metal does not have simd_shuffle_down for bfloat16
// TODO: Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat
bfloat simd_shuffle_down(bfloat value, ushort delta) {
return static_cast<bfloat>(__metal_simd_shuffle_down(static_cast<float>(value), delta));
}
#endif
template <typename V>
Indexed<V> simd_shuffle_down(Indexed<V> iv, ushort delta) {
return Indexed<V>(
simd_shuffle_down(iv.i, delta),
simd_shuffle_down(iv.val, delta)
);
}
#define impl_reduction_op_helper(name, op, init_val, __result_type__) \
template<typename T, typename R = __result_type__> \
struct name { \
static constexpr T init() { \
return init_val; \
} \
METAL_FUNC R operator()(T a, T b) { \
return op; \
} \
METAL_FUNC R operator()(thread const T& a, thread const T& b) const { \
return op; \
} \
METAL_FUNC R operator()(threadgroup const T& a, threadgroup const T& b) const { \
return op; \
} \
} \
#define impl_reduction_op(name, op, init_val) \
impl_reduction_op_helper(name, op, init_val, T);
#define impl_arg_reduction_op(name, op, init_val) \
impl_reduction_op_helper(name, op, init_val, tuple<bool, Indexed<T>>);
impl_reduction_op(Sum, a + b, 0);
impl_reduction_op(Mul, a * b, 1);
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::lowest());
#undef impl_reduction_op
// These are used when loading elements from global memory into shared memory.
// They let us use the same code for both indexed and non-indexed types.
template<typename Op, typename T, typename U>
METAL_FUNC T apply_operator(Op op, size_t _idx, T a, U b) {
return op(a, static_cast<T>(b));
}
template<typename Op, typename T, typename U>
METAL_FUNC Indexed<T> apply_operator(Op op, size_t idx, Indexed<T> a, U b) {
return op(a, Indexed<T>(idx, b));
}
// Load elements from global memory into shared memory.
// Handles both indexed and non-indexed types by using apply_operator.
template<
typename T,
typename R,
typename ReductionOp,
ushort BLOCKSIZE,
bool STRIDED = false
>
METAL_FUNC R load_from_global(
R value,
constant size_t &num_elements,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
const device T *src,
const ushort offset,
threadgroup R shared[BLOCKSIZE],
const ushort tid
) {
ReductionOp op;
size_t stop_idx = offset + el_to_sum_per_block;
size_t idx = offset + tid;
while (idx < stop_idx) {
if (STRIDED) {
idx = get_strided_index(idx, num_dims, dims, strides);
}
value = apply_operator(op, idx, value, src[idx]);
idx += BLOCKSIZE;
}
return value;
}
// Convenience function for when we don't need to sum over multiple dimensions. #define ARGMIN(NAME, T, MAXVALUE) \
template<
typename T,
typename R,
typename ReductionOp,
ushort BLOCKSIZE
>
METAL_FUNC R load_from_global(
R value,
constant size_t &num_elements,
constant size_t &el_to_sum_per_block,
const device T *src,
const size_t offset,
threadgroup R shared[BLOCKSIZE],
const ushort tid
) {
return load_from_global<T, R, ReductionOp, BLOCKSIZE, false>(
value,
num_elements,
// Dummy values for num_dims, dims, and strides
num_elements,
nullptr,
nullptr,
// end dummy values
el_to_sum_per_block,
src,
offset,
shared,
tid
);
}
// Since we are using simd_shuffle_down with a BLOCKSIZE guard we don't need any barriers.
template<typename ReductionOp, ushort BLOCKSIZE, typename T>
METAL_FUNC T simdgroup_reduce(T value) {
ReductionOp op;
if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16));
if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value, 8));
if (BLOCKSIZE >= 8) value = op(value, simd_shuffle_down(value, 4));
if (BLOCKSIZE >= 4) value = op(value, simd_shuffle_down(value, 2));
if (BLOCKSIZE >= 2) value = op(value, simd_shuffle_down(value, 1));
return value;
}
template<
typename ReductionOp,
ushort BLOCKSIZE,
typename T
>
METAL_FUNC T threadgroup_reduce(
threadgroup T shared[BLOCKSIZE],
ushort tid [[ thread_index_in_threadgroup ]]
) {
ReductionOp op;
// Fully unrolled reduction loop from BLOCKSIZE down to 64.
#pragma clang loop unroll(full)
for (uint s = BLOCKSIZE / 2; s >= 64; s >>= 1) {
if (tid < s) {
shared[tid] = op(shared[tid], shared[tid + s]);
}
threadgroup_barrier(mem_flags::mem_none);
}
if (tid < 32) {
// Last shared memory reduce can be done without tid < s check.
if (BLOCKSIZE >= 64) {
shared[tid] = op(shared[tid], shared[tid + 32]);
simdgroup_barrier(mem_flags::mem_none);
}
// Remaining 32 threads can be reduced with simdgroup_reduce.
shared[tid] = simdgroup_reduce<ReductionOp, BLOCKSIZE>(shared[tid]);
}
return shared[tid];
}
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
template<
typename T,
typename R,
typename ReductionOp,
ushort BLOCKSIZE,
bool STRIDED = false
>
METAL_FUNC void reduce(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device R *dst,
constant size_t &num_elements,
threadgroup T shared[BLOCKSIZE],
ushort tid [[ thread_index_in_threadgroup ]],
ushort dst_id [[ threadgroup_position_in_grid ]]
) {
// Initialize shared memory for current thread to correct value for reduction operation
shared[tid] = ReductionOp::init();
// Calculate offset for the threadgroup of the current thread
ushort offset = dst_id * el_to_sum_per_block;
R initial = ReductionOp::init();
// Load with reduction from global memory into shared memory
shared[tid] = load_from_global<T, R, ReductionOp, BLOCKSIZE, STRIDED>(
initial,
num_elements,
num_dims,
dims,
strides,
el_to_sum_per_block,
src,
offset,
shared,
tid
);
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
threadgroup_barrier(mem_flags::mem_none);
// Complete reduction
R value = threadgroup_reduce<ReductionOp, BLOCKSIZE>(shared, tid);
if (tid == 0) dst[dst_id] = value;
}
#define reduce_case(OP, T, R, N) \
case N: { \
threadgroup R shared[N]; \
reduce<T, R, OP<R>, N, STRIDED>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
num_elements, \
shared, \
tid, \
dst_id); \
break; \
}
#define impl_reduce(OP, NAME, T) \
kernel void NAME( \ kernel void NAME( \
constant size_t &num_dims, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
constant size_t &num_elements, \
ushort tid [[ thread_index_in_threadgroup ]], \
ushort dst_id [[ threadgroup_position_in_grid ]], \
ushort block_dim [[ threads_per_threadgroup ]] \
) { \
constant size_t *dims = {}; \
constant size_t *strides = {}; \
const bool STRIDED = false; \
switch (block_dim) { \
reduce_case(OP, T, T, 2048); \
reduce_case(OP, T, T, 1024); \
reduce_case(OP, T, T, 512); \
reduce_case(OP, T, T, 256); \
reduce_case(OP, T, T, 128); \
reduce_case(OP, T, T, 64); \
reduce_case(OP, T, T, 32); \
reduce_case(OP, T, T, 16); \
reduce_case(OP, T, T, 8); \
reduce_case(OP, T, T, 4); \
reduce_case(OP, T, T, 2); \
reduce_case(OP, T, T, 1); \
} \
} \
kernel void NAME##_strided( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
constant size_t &num_elements, \
ushort tid [[ thread_index_in_threadgroup ]], \
ushort dst_id [[ threadgroup_position_in_grid ]], \
ushort block_dim [[ threads_per_threadgroup ]] \
) { \
const bool STRIDED = true; \
switch (block_dim) { \
reduce_case(OP, T, T, 2048); \
reduce_case(OP, T, T, 1024); \
reduce_case(OP, T, T, 512); \
reduce_case(OP, T, T, 256); \
reduce_case(OP, T, T, 128); \
reduce_case(OP, T, T, 64); \
reduce_case(OP, T, T, 32); \
reduce_case(OP, T, T, 16); \
reduce_case(OP, T, T, 8); \
reduce_case(OP, T, T, 4); \
reduce_case(OP, T, T, 2); \
reduce_case(OP, T, T, 1); \
} \
}
template<
typename T,
typename ReductionOp,
ushort BLOCKSIZE,
bool STRIDED
>
METAL_FUNC void reduce(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device uint *dst,
constant size_t &num_elements,
threadgroup Indexed<T> shared[BLOCKSIZE],
ushort tid [[ thread_index_in_threadgroup ]],
ushort dst_id [[ threadgroup_position_in_grid ]]
) {
// Initialize shared memory for current thread to correct value for reduction operation
shared[tid] = ReductionOp::init();
// Calculate offset for the threadgroup of the current thread
ushort offset = dst_id * el_to_sum_per_block;
Indexed<T> initial = ReductionOp::init();
// Load with reduction from global memory into shared memory
shared[tid] = load_from_global<T, Indexed<T>, ReductionOp, BLOCKSIZE, STRIDED>(
initial,
num_elements,
num_dims,
dims,
strides,
el_to_sum_per_block,
src,
offset,
shared,
tid
);
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
threadgroup_barrier(mem_flags::mem_none);
// Complete reduction
Indexed<T> value = threadgroup_reduce<ReductionOp, BLOCKSIZE, Indexed<T>>(shared, tid);
// Return index of reduce result
if (tid == 0) dst[dst_id] = value.i;
}
#define arg_reduce_case(OP, T, N) \
case N: { \
threadgroup Indexed<T> shared[N]; \
reduce<T, OP<Indexed<T>>, N, STRIDED>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
num_elements, \
shared, \
tid, \
dst_id); \
break; \
}
#define impl_arg_reduce(OP, NAME, T) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
constant size_t &num_elements, \
ushort tid [[ thread_index_in_threadgroup ]], \
ushort dst_id [[ threadgroup_position_in_grid ]], \
ushort block_dim [[ threads_per_threadgroup ]] \
) { \
constant size_t *dims = {}; \
constant size_t *strides = {}; \
const bool STRIDED = false; \
switch (block_dim) { \
arg_reduce_case(OP, T, 2048); \
arg_reduce_case(OP, T, 1024); \
arg_reduce_case(OP, T, 512); \
arg_reduce_case(OP, T, 256); \
arg_reduce_case(OP, T, 128); \
arg_reduce_case(OP, T, 64); \
arg_reduce_case(OP, T, 32); \
arg_reduce_case(OP, T, 16); \
arg_reduce_case(OP, T, 8); \
arg_reduce_case(OP, T, 4); \
arg_reduce_case(OP, T, 2); \
arg_reduce_case(OP, T, 1); \
} \
} \
kernel void NAME##_strided( \
constant size_t &num_dims, \ constant size_t &num_dims, \
constant size_t *dims, \ constant size_t *dims, \
constant size_t *strides, \ constant size_t *strides, \
constant size_t &el_to_sum_per_block, \ constant size_t &el_to_sum_per_block, \
device const T *src, \ device const T *src, \
device uint *dst, \ device uint *dst, \
constant size_t &num_elements, \ uint id [[ thread_position_in_grid ]], \
ushort tid [[ thread_index_in_threadgroup ]], \ uint tid [[ thread_index_in_threadgroup ]], \
ushort dst_id [[ threadgroup_position_in_grid ]], \ uint dst_id [[ threadgroup_position_in_grid ]], \
ushort block_dim [[ threads_per_threadgroup ]] \ uint block_dim [[ threads_per_threadgroup ]] \
) { \ ) { \
const bool STRIDED = true; \ \
switch (block_dim) { \ threadgroup T shared_memory[THREADGROUP_SIZE]; \
arg_reduce_case(OP, T, 2048); \ threadgroup uint shared_indices[THREADGROUP_SIZE]; \
arg_reduce_case(OP, T, 1024); \ \
arg_reduce_case(OP, T, 512); \ shared_memory[tid] = MAXVALUE; \
arg_reduce_case(OP, T, 256); \ shared_indices[tid] = 0xFFFFFFFF; \
arg_reduce_case(OP, T, 128); \ bool notset = true; \
arg_reduce_case(OP, T, 64); \ /* \
arg_reduce_case(OP, T, 32); \ // Elements summed in this block range from dst_id * el_to_sum_per_block \
arg_reduce_case(OP, T, 16); \ // to (dst_id + 1) * el_to_sum_per_block. \
arg_reduce_case(OP, T, 8); \ */ \
arg_reduce_case(OP, T, 4); \ size_t start_idx = dst_id * el_to_sum_per_block; \
arg_reduce_case(OP, T, 2); \ size_t stop_idx = start_idx + el_to_sum_per_block; \
arg_reduce_case(OP, T, 1); \ size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \ } \
}
template<
typename T,
typename ACC = float,
ushort BLOCKSIZE
>
METAL_FUNC void softmax(
constant size_t &src_numel,
constant size_t &el_to_sum_per_block,
const device T *src,
device T *dst,
threadgroup ACC shared[BLOCKSIZE],
ushort tid [[ thread_index_in_threadgroup ]], #define ARGMAX(NAME, T, MINVALUE) \
ushort dst_id [[ threadgroup_position_in_grid ]] kernel void NAME( \
) { constant size_t &num_dims, \
// Initialize shared memory for current thread to lowest value constant size_t *dims, \
shared[tid] = numeric_limits<ACC>::lowest(); constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
// Calculate offset for the threadgroup of the current thread #define REDUCE(FN, NAME, T, START) \
size_t offset = dst_id * el_to_sum_per_block; kernel void NAME( \
ACC initial = numeric_limits<ACC>::lowest(); constant size_t &num_dims, \
// Load with reduction from global memory into shared memory constant size_t *dims, \
shared[tid] = load_from_global<T, ACC, Max<ACC>, BLOCKSIZE>( constant size_t *strides, \
initial, constant size_t &el_to_sum_per_block, \
src_numel, device const T *src, \
el_to_sum_per_block, device T *dst, \
src, uint id [[ thread_position_in_grid ]], \
offset, uint tid [[ thread_index_in_threadgroup ]], \
shared, uint dst_id [[ threadgroup_position_in_grid ]], \
tid uint block_dim [[ threads_per_threadgroup ]] \
); ) { \
// Threadgroup barrier is needed to ensure that all threads have written to shared memory \
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers. threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup_barrier(mem_flags::mem_none); \
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
// Reduce shared memory to find max value
threadgroup_reduce<Max<ACC>, BLOCKSIZE>(shared, tid);
ACC max_result = shared[0];
// Ensure all threads have max_result = shared[0] before we set shared[0] = 0. #define SOFTMAX(NAME, T) \
threadgroup_barrier(mem_flags::mem_none);
shared[tid] = 0;
// Calculate softmax values
size_t stop_idx = min(offset + el_to_sum_per_block, src_numel);
size_t idx = offset + tid;
while (idx < stop_idx) {
const ACC val = exp(ACC(src[idx]) - max_result);
dst[idx] = T(val);
shared[tid] += val;
idx += BLOCKSIZE;
}
threadgroup_barrier(mem_flags::mem_none);
threadgroup_reduce<Sum<ACC>, BLOCKSIZE>(shared, tid);
threadgroup_barrier(mem_flags::mem_none);
const T inv_acc = T(1.0/shared[0]);
idx = offset + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += BLOCKSIZE;
}
}
#define softmax_case(T, ACC, N) \
case N: { \
threadgroup ACC shared[N]; \
softmax<T, ACC, N>( \
src_numel, \
el_to_sum_per_block, \
src, \
dst, \
shared, \
tid, \
dst_id); \
break; \
}
#define impl_softmax(NAME, T, ACC) \
kernel void NAME( \ kernel void NAME( \
constant size_t &src_numel, \ constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \ constant size_t &el_to_sum_per_block, \
device const T *src, \ device const T *src, \
device T *dst, \ device T *dst, \
\ \
ushort tid [[ thread_index_in_threadgroup ]], \ uint id [[ thread_position_in_grid ]], \
ushort dst_id [[ threadgroup_position_in_grid ]], \ uint tid [[ thread_index_in_threadgroup ]], \
ushort block_dim [[ threads_per_threadgroup ]] \ uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \ ) { \
switch (block_dim) { \ threadgroup float shared_memory[THREADGROUP_SIZE]; \
softmax_case(T, ACC, 2048); \ shared_memory[tid] = -INFINITY; \
softmax_case(T, ACC, 1024); \ size_t start_idx = dst_id * el_to_sum_per_block; \
softmax_case(T, ACC, 512); \ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
softmax_case(T, ACC, 256); \ size_t idx = start_idx + tid; \
softmax_case(T, ACC, 128); \ \
softmax_case(T, ACC, 64); \ \
softmax_case(T, ACC, 32); \ float tmp = -INFINITY; \
softmax_case(T, ACC, 16); \ while (idx < stop_idx) { \
softmax_case(T, ACC, 8); \ tmp = MAX(tmp, float(src[idx])); \
softmax_case(T, ACC, 4); \ idx += block_dim; \
softmax_case(T, ACC, 2); \ } \
softmax_case(T, ACC, 1); \ shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \ } \
}
impl_reduce(Sum, fast_sum_f32, float) REDUCE(x + y, fast_sum_f32_strided, float, 0)
impl_reduce(Sum, fast_sum_u32, uint) REDUCE(x + y, fast_sum_u32_strided, uint, 0)
impl_reduce(Sum, fast_sum_f16, half) REDUCE(x + y, fast_sum_f16_strided, half, 0)
impl_reduce(Sum, fast_sum_u8, uint8_t) REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
impl_reduce(Mul, fast_mul_f32, float) SOFTMAX(softmax_f32, float)
impl_reduce(Mul, fast_mul_u32, uint) SOFTMAX(softmax_f16, half)
impl_reduce(Mul, fast_mul_f16, half)
impl_reduce(Mul, fast_mul_u8, uint8_t)
impl_reduce(Max, fast_max_f32, float)
impl_reduce(Max, fast_max_u32, uint)
impl_reduce(Max, fast_max_f16, half)
impl_reduce(Max, fast_max_u8, uint8_t)
impl_reduce(Min, fast_min_f32, float)
impl_reduce(Min, fast_min_u32, uint)
impl_reduce(Min, fast_min_f16, half)
impl_reduce(Min, fast_min_u8, uint8_t)
impl_arg_reduce(Min, fast_argmin_f32, float)
impl_arg_reduce(Min, fast_argmin_f16, half)
impl_arg_reduce(Min, fast_argmin_u32, uint)
impl_arg_reduce(Min, fast_argmin_u8, uint8_t)
impl_arg_reduce(Max, fast_argmax_f32, float)
impl_arg_reduce(Max, fast_argmax_f16, half)
impl_arg_reduce(Max, fast_argmax_u32, uint)
impl_arg_reduce(Max, fast_argmax_u8, uint8_t)
impl_softmax(softmax_f32, float, float)
impl_softmax(softmax_f16, half, float)
#if __METAL_VERSION__ >= 220 #if __METAL_VERSION__ >= 220
impl_reduce(Sum, fast_sum_i64, int64_t) REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
impl_reduce(Mul, fast_mul_i64, int64_t) REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
impl_reduce(Min, fast_min_i64, int64_t) REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
impl_reduce(Max, fast_max_i64, int64_t) ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
impl_arg_reduce(Min, fast_argmin_i64, int64_t)
impl_arg_reduce(Max, fast_argmax_i64, int64_t)
#endif #endif
#if defined(__HAVE_BFLOAT__) #if __METAL_VERSION__ >= 310
impl_reduce(Sum, fast_sum_bf16, bfloat) REDUCE(x + y, fast_sum_bf16, bfloat, 0)
impl_reduce(Mul, fast_mul_bf16, bfloat) REDUCE(x * y, fast_mul_bf16, bfloat, 1)
impl_reduce(Max, fast_max_bf16, bfloat) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
impl_reduce(Min, fast_min_bf16, bfloat) REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
impl_arg_reduce(Min, fast_argmin_bf16, bfloat) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
impl_arg_reduce(Max, fast_argmax_bf16, bfloat) SOFTMAX(softmax_bf16, bfloat)
impl_softmax(softmax_bf16, bfloat, float)
#endif #endif
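Both versions of this file follow the same two-phase shape: each thread strides over its block's slice of the input accumulating a partial value, then the threadgroup repeatedly halves the active range (`for (uint s = block_dim / 2; s > 0; s >>= 1)`) until element 0 of shared memory holds the result. A sequential Rust sketch of that second phase; on the GPU every `tid < s` runs its step in parallel and a threadgroup barrier separates the iterations:
// Tree reduction over a "shared memory" slice, mirroring the
// `for (uint s = block_dim / 2; s > 0; s >>= 1)` loop in the kernels.
// block_dim is assumed to be a power of two, as for the dispatched kernels.
fn threadgroup_reduce(shared: &mut [f32], block_dim: usize, op: impl Fn(f32, f32) -> f32) -> f32 {
    let mut s = block_dim / 2;
    while s > 0 {
        for tid in 0..s {
            shared[tid] = op(shared[tid], shared[tid + s]);
        }
        // On the GPU a threadgroup_barrier goes here before halving `s`.
        s >>= 1;
    }
    shared[0]
}
fn main() {
    let mut shared = vec![1.0f32; 8]; // pretend block_dim == 8 partial values
    let sum = threadgroup_reduce(&mut shared, 8, |x, y| x + y);
    assert_eq!(sum, 8.0);
}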

View File

@ -1,346 +0,0 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
constant int THREADGROUP_SIZE = 2048;
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
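
Note: REDUCE works in two phases: each thread first folds a strided slice of the block into its own shared-memory slot, then the threadgroup runs a tree reduction over shared memory, halving the active width at each step. A sequential Rust sketch of the same pattern for the sum case (illustrative only; block_dim is assumed to be a power of two, as the tree loop implicitly assumes):

// Sequential stand-in for the two-phase block reduction above (sum case).
fn block_sum(src: &[f32], block_dim: usize) -> f32 {
    // Phase 1: "thread" tid accumulates elements tid, tid + block_dim, tid + 2 * block_dim, ...
    let mut shared = vec![0.0f32; block_dim];
    for (i, &x) in src.iter().enumerate() {
        shared[i % block_dim] += x;
    }
    // Phase 2: tree reduction, mirroring `for (s = block_dim / 2; s > 0; s >>= 1)`.
    let mut s = block_dim / 2;
    while s > 0 {
        for tid in 0..s {
            shared[tid] += shared[tid + s];
        }
        s /= 2;
    }
    shared[0]
}
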
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = -INFINITY; \
while (idx < stop_idx) { \
tmp = MAX(tmp, float(src[idx])); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
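
Note: SOFTMAX is the numerically stable three-pass formulation: reduce the row maximum, exponentiate the shifted values while accumulating their sum, then scale by the reciprocal of that sum. A sequential Rust sketch of the same three passes (illustrative only, not part of the diff):

// Numerically stable softmax over one row, mirroring the three passes above.
fn softmax_row(src: &[f32], dst: &mut [f32]) {
    // Pass 1: row maximum (the kernel reduces this across the threadgroup).
    let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Pass 2: exponentiate shifted values and accumulate their sum.
    let mut sum = 0.0f32;
    for (d, &x) in dst.iter_mut().zip(src) {
        *d = (x - max).exp();
        sum += *d;
    }
    // Pass 3: scale by 1 / sum (the kernel stores inv_acc and multiplies in place).
    let inv = 1.0 / sum;
    for d in dst.iter_mut() {
        *d *= inv;
    }
}
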
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
REDUCE(x + y, fast_sum_f32, float, 0)
REDUCE(x + y, fast_sum_u32, uint, 0)
REDUCE(x + y, fast_sum_f16, half, 0)
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
REDUCE(x * y, fast_mul_f32, float, 1)
REDUCE(x * y, fast_mul_u32, uint, 1)
REDUCE(x * y, fast_mul_f16, half, 1)
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32, uint, 0)
ARGMAX(fast_argmax_u8, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
REDUCE(x + y, fast_sum_i64, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
#endif
#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif

View File

@ -17,30 +17,8 @@ METAL_FUNC uint get_strided_index(
return strided_i; return strided_i;
} }
template<typename T, typename ID>
METAL_FUNC void where_cond(
constant size_t &numel,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t *strides_t,
constant size_t *strides_f,
device const ID *ids,
device const T *t,
device const T *f,
device T *out,
uint i [[ thread_position_in_grid ]]
) {
if (i >= numel){
return;
}
uint strided_i = get_strided_index(i, num_dims, dims, strides);
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
}
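
Note: where_cond is an element-wise select: output i takes the value from t when the id at i is non-zero and the value from f otherwise, with each of the three inputs addressed through its own strides. A Rust sketch of the per-element semantics with the striding left out (illustrative only):

// Per-element semantics of where_cond, striding omitted for clarity.
fn where_cond(ids: &[u8], t: &[f32], f: &[f32]) -> Vec<f32> {
    ids.iter()
        .zip(t.iter().zip(f))
        .map(|(&id, (&tv, &fv))| if id != 0 { tv } else { fv })
        .collect()
}
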
#define WHERE_OP(T, ID, FN_NAME) \ #define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
kernel void FN_NAME( \ kernel void FN_NAME( \
constant size_t &numel, \ constant size_t &numel, \
constant size_t &num_dims, \ constant size_t &num_dims, \
@ -48,13 +26,19 @@ kernel void FN_NAME(
constant size_t *strides, \ constant size_t *strides, \
constant size_t *strides_t, \ constant size_t *strides_t, \
constant size_t *strides_f, \ constant size_t *strides_f, \
device const ID *ids, \ device const ID_TYPENAME *ids, \
device const T *t, \ device const TYPENAME *t, \
device const T *f, \ device const TYPENAME *f, \
device T *out, \ device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \ uint i [[ thread_position_in_grid ]] \
) { \ ) { \
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \ if (i >= numel){ \
return; \
} \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
} \ } \
// WHERE_OP(float, int64_t, where_i64_f32) // WHERE_OP(float, int64_t, where_i64_f32)
@ -70,14 +54,10 @@ kernel void FN_NAME(
// WHERE_OP(int64_t, uint32_t, where_u32_i64) // WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(half, uint8_t, where_u8_f16) // WHERE_OP(double, uint8_t, where_u8_f64)
WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32) WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220 #if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64) WHERE_OP(int64_t, uint8_t, where_u8_i64)
#endif #endif
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
#endif

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use half::{bf16, f16}; use half::{bf16, f16};
use metal::{Buffer, Device, MTLResourceOptions}; use metal::{Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T; let ptr = buffer.contents() as *const T;
@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void; let ptr = data.as_ptr() as *const core::ffi::c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64; let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options) device.new_buffer_with_data(ptr, size, options)
} }
@ -37,7 +37,8 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v); let input = new_buffer(&device, v);
@ -59,7 +60,8 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> { fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
@ -94,7 +96,8 @@ fn run_strided<T: Clone>(
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v); let input = new_buffer(&device, v);
let output = new_buffer(&device, v); let output = new_buffer(&device, v);
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_unary_strided( call_unary_strided(
&device, &device,
command_buffer, command_buffer,
@ -245,37 +248,10 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
} }
#[test]
fn binary_ops_bf16() {
let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
.into_iter()
.map(bf16::from_f32)
.collect();
macro_rules! binary_op {
($opname:ident, $opexpr:expr) => {{
let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
let expected: Vec<bf16> = lhs
.iter()
.zip(rhs.iter())
.map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
.collect();
assert_eq!(results, expected);
}};
}
binary_op!(add, |x, y| x + y);
binary_op!(sub, |x, y| x - y);
binary_op!(mul, |x, y| x * y);
binary_op!(div, |x, y| x / y);
binary_op!(min, |x: bf16, y| x.min(y));
binary_op!(max, |x: bf16, y| x.max(y));
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v); let input = new_buffer(&device, v);
@ -320,92 +296,10 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]); assert_eq!(results, vec![1.0f32; 10_000]);
} }
#[test]
fn it_cast_bf16_u32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_f32() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_u8_bf16() {
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
let expected: Vec<bf16> = input
.iter()
.map(|v| bf16::from_f32(*v as f32))
.collect::<Vec<_>>();
assert_eq!(output, expected);
}
#[test]
fn it_cast_u32_bf16() {
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f32_bf16() {
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_u8() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_bf16_f16() {
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
#[test]
fn it_cast_f16_bf16() {
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
assert_eq!(output, expected);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
@ -440,7 +334,8 @@ fn run_affine_strided<T: Clone>(
add: f64, add: f64,
) -> Vec<T> { ) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
@ -501,14 +396,14 @@ fn index_select() {
let shape = [5, 2]; let shape = [5, 2];
let ids = [0u32, 4, 2]; let ids = [0u32, 4, 2];
let dim = 0; let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5]; let shape = [2, 5];
let ids = [0u32, 1, 0]; let ids = [0u32, 1, 0];
let dim = 0; let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!( assert_eq!(
result, result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@ -524,46 +419,20 @@ fn index_select_f16() {
let shape = [5, 2]; let shape = [5, 2];
let ids = [0u32, 4, 2]; let ids = [0u32, 4, 2];
let dim = 0; let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!( assert_eq!(
approx_f16(result, 4), approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
); );
} }
#[test]
fn index_select_is_u32_bf16() {
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
assert_eq!(
approx_bf16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test]
fn index_select_is_u8_bf16() {
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
let shape = [5, 2];
let ids = [0u8, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
assert_eq!(
approx_bf16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
#[test] #[test]
fn index_select_dim1() { fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2]; let shape = [5, 2];
let ids = [0u32, 1, 0]; let ids = [0u32, 1, 0];
let dim = 1; let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!( assert_eq!(
result, result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@ -575,7 +444,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize], shape: &[usize],
ids: &[I], ids: &[I],
dim: usize, dim: usize,
name: &'static str,
) -> Vec<T> { ) -> Vec<T> {
let device = Device::system_default().expect("no device found"); let device = Device::system_default().expect("no device found");
@ -589,7 +457,14 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size; let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let kernels = Kernels::new(); let name = match core::mem::size_of::<T>() {
4 => "is_u32_f32",
2 => "is_u32_f16",
_ => unimplemented!(),
};
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select( call_index_select(
&device, &device,
&command_buffer, &command_buffer,
@ -622,18 +497,19 @@ fn cos_f16() {
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
} }
fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> { fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v); let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options); let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let dims = vec![v.len()]; let dims = vec![v.len()];
let strides = vec![1]; let strides = vec![1];
match call_reduce_strided( call_reduce_strided(
&device, &device,
command_buffer, command_buffer,
&kernels, &kernels,
@ -644,13 +520,8 @@ fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Ve
&input, &input,
0, 0,
&output, &output,
) { )
Ok(_) => {} .unwrap();
Err(e) => {
println!("Error: {}", e);
panic!();
}
}
command_buffer.commit(); command_buffer.commit();
command_buffer.wait_until_completed(); command_buffer.wait_until_completed();
@ -659,7 +530,8 @@ fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Ve
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> { fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v); let input = new_buffer(&device, v);
@ -682,114 +554,22 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
read_to_vec(&output, v.len()) read_to_vec(&output, v.len())
} }
const fn create_array<const N: usize>() -> [f32; N] {
let mut array: [f32; N] = [0.0; N];
let mut i = 1;
while i <= N {
array[i - 1] = i as f32;
i += 1;
}
array
}
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
let mut sum = 0;
let mut results: [f32; D] = [0.0; D];
let mut i = 1;
let mut j = 1;
while i <= N {
sum += i;
i += 1;
if i > j * N / D {
results[j - 1] = sum as f32;
j += 1;
sum = 0;
}
}
results
}
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
let mut max = 0.0;
let mut max_index: u32 = 0;
let mut results: [u32; D] = [0; D];
let mut i = 0;
let mut j = 1;
while i <= N {
if i >= (j * N / D) {
results[j - 1] = max_index;
max = 0.0;
max_index = 0;
j += 1;
}
if i == N {
break;
}
if arr[i] > max {
max = arr[i];
max_index = i as u32;
}
i += 1;
}
results
}
fn reduce_sum_case<const N: usize, const D: usize>() {
let v = create_array::<N>();
let results = run_reduce(&v, D, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), correct_sum::<N, D>());
}
fn reduce_argmax_case<const N: usize, const D: usize>() {
let v = create_array::<N>();
let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
assert_eq!(results, correct_argmax::<N, D>(v));
}
#[test] #[test]
fn reduce_sum() { fn reduce_sum() {
reduce_sum_case::<6, 1>(); let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
reduce_sum_case::<10, 1>(); let out_length = 1;
reduce_sum_case::<64, 1>();
reduce_sum_case::<128, 1>();
reduce_sum_case::<256, 1>();
reduce_sum_case::<512, 1>();
reduce_sum_case::<1024, 1>();
reduce_sum_case::<2048, 1>();
reduce_sum_case::<4096, 1>();
reduce_sum_case::<6, 2>(); let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
reduce_sum_case::<10, 2>(); assert_eq!(approx(results, 4), vec![21.0]);
reduce_sum_case::<64, 2>();
reduce_sum_case::<128, 2>();
reduce_sum_case::<256, 2>();
reduce_sum_case::<512, 2>();
reduce_sum_case::<1024, 2>();
reduce_sum_case::<2048, 2>();
reduce_sum_case::<4096, 2>();
} }
#[test] #[test]
fn reduce_argmax() { fn reduce_sum2() {
reduce_argmax_case::<6, 1>(); let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
reduce_argmax_case::<10, 1>(); let out_length = 2;
reduce_argmax_case::<64, 1>();
reduce_argmax_case::<128, 1>();
reduce_argmax_case::<256, 1>();
reduce_argmax_case::<512, 1>();
reduce_argmax_case::<1024, 1>();
reduce_argmax_case::<2048, 1>();
reduce_argmax_case::<4096, 1>();
reduce_argmax_case::<6, 2>(); let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
reduce_argmax_case::<10, 2>(); assert_eq!(approx(results, 4), vec![6.0, 15.0]);
reduce_argmax_case::<64, 2>();
reduce_argmax_case::<128, 2>();
reduce_argmax_case::<256, 2>();
reduce_argmax_case::<512, 2>();
reduce_argmax_case::<1024, 2>();
reduce_argmax_case::<2048, 2>();
reduce_argmax_case::<4096, 2>();
} }
#[test] #[test]
@ -810,6 +590,7 @@ fn softmax() {
} }
let results = run_softmax(&v, last_dim, "softmax_f32"); let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4); let results = approx(results, 4);
println!("{results:?}");
assert_eq!( assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(), results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n n
@ -869,7 +650,8 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str, name: &'static str,
) -> Vec<T> { ) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
@ -945,7 +727,8 @@ fn run_gemm<T: Clone>(
rhs_offset: usize, rhs_offset: usize,
) -> Vec<T> { ) -> Vec<T> {
let device = device(); let device = device();
let kernels = Kernels::new(); let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
@ -1023,124 +806,3 @@ fn gemm() {
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
); );
} }
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
let seed = device.new_buffer_with_data(
&seed as *const u32 as *const core::ffi::c_void,
std::mem::size_of::<u32>() as NSUInteger,
options,
);
if name.starts_with("rand_uniform") {
call_random_uniform(
&device,
command_buffer,
&kernels,
name,
a,
b,
length,
&seed,
&output,
)
.unwrap();
} else {
call_random_normal(
&device,
command_buffer,
&kernels,
name,
a,
b,
length,
&seed,
&output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let count = data.len();
assert!(count > 0);
sum / count as f32
}
fn calc_stddev(data: &[f32]) -> f32 {
let mean = calc_mean(data);
let count = data.len();
assert!(count > 0);
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
diff * diff
})
.sum::<f32>()
/ count as f32;
variance.sqrt()
}
let shape = vec![1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
let min = -30.0;
let max = 30.0;
let mean = 100.0;
let stddev = 50.0;
macro_rules! validate_random {
($type:ty) => {
let results: Vec<f32> = run_random::<$type>(
concat!("rand_uniform_", stringify!($type)),
seed,
length,
min,
max,
)
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| {
assert!(*v >= min && *v <= max);
});
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(
concat!("rand_normal_", stringify!($type)),
seed,
length,
mean,
stddev,
)
.into_iter()
.map(f32::from)
.collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
}
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
}

View File

@ -58,12 +58,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
} }
template <typename T> METAL_FUNC T relu(T in){
if (in < 0) {
return 0;
}
return in;
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \ kernel void FN_NAME( \
@ -116,7 +110,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf) UNARY_OP(erf)
UNARY_OP(tanh) UNARY_OP(tanh)
UNARY_OP(recip) UNARY_OP(recip)
UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -126,7 +120,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
UNARY(id, int64_t, copy_i64, copy_i64_strided) UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif #endif
#if defined(__HAVE_BFLOAT__) #if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin) BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr) BFLOAT_UNARY_OP(sqr)
@ -135,7 +129,6 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp) BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log) BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu) BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil) BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(round)
@ -143,7 +136,6 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip) BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided) UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif #endif

View File

@ -222,10 +222,7 @@ impl Benchmark for QMatMul {
type RunResult = Tensor; type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> { fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32]; let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
let mm = candle::quantized::QTensor::new( let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
candle::quantized::QStorage::Cpu(Box::new(zeros)),
(4096, 11008),
)?;
let mm = candle::quantized::QMatMul::from_qtensor(mm)?; let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?; let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg)) Ok((mm, arg))

View File

@ -6,7 +6,6 @@ use serde::Deserialize;
pub enum Activation { pub enum Activation {
#[default] #[default]
Gelu, Gelu,
#[serde(alias = "gelu_new")]
NewGelu, NewGelu,
Relu, Relu,
Relu2, Relu2,

View File

@ -254,12 +254,6 @@ pub fn simple_eval(
let output = input0.broadcast_div(input1)?; let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);
} }
"Pow" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
"Equal" => { "Equal" => {
let input0 = get(&node.input[0])?; let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?; let input1 = get(&node.input[1])?;

View File

@ -33,9 +33,7 @@ def has_mkl() -> bool:
pass pass
@staticmethod @staticmethod
def load_ggml( def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
""" """
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@ -43,9 +41,7 @@ def load_ggml(
pass pass
@staticmethod @staticmethod
def load_gguf( def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
""" """
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values. and the second maps metadata keys to metadata values.

View File

@ -1074,20 +1074,20 @@ impl PyTensor {
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> { fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized; use ::candle::quantized;
let res = match quantized_dtype.to_lowercase().as_str() { let res = match quantized_dtype.to_lowercase().as_str() {
"q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K), "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
"q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K), "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
"q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0), "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
"q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1), "q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
"q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K), "q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
"q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0), "q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
"q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1), "q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
"q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K), "q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
"q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K), "q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
"q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0), "q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
"q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1), "q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
"q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K), "q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
"f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16), "f16" => quantized::QTensor::quantize::<f16>(self),
"f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32), "f32" => quantized::QTensor::quantize::<f32>(self),
dt => { dt => {
return Err(PyErr::new::<PyValueError, _>(format!( return Err(PyErr::new::<PyValueError, _>(format!(
"unknown quantized-dtype {dt}" "unknown quantized-dtype {dt}"
@ -1278,19 +1278,13 @@ fn save_safetensors(
} }
#[pyfunction] #[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] #[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
fn load_ggml( fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
path: &str,
device: Option<PyDevice>,
py: Python<'_>,
) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?; let mut file = std::fs::File::open(path)?;
let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
let ggml =
::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
let tensors = ggml let tensors = ggml
.tensors .tensors
.into_iter() .into_iter()
@ -1319,16 +1313,11 @@ fn load_ggml(
} }
#[pyfunction] #[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] #[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, /// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values. /// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]] /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
fn load_gguf( fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
path: &str,
device: Option<PyDevice>,
py: Python<'_>,
) -> PyResult<(PyObject, PyObject)> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
use ::candle::quantized::gguf_file; use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> { fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
let v: PyObject = match v { let v: PyObject = match v {
@ -1360,7 +1349,7 @@ fn load_gguf(
.tensor_infos .tensor_infos
.keys() .keys()
.map(|key| { .map(|key| {
let qtensor = gguf.tensor(&mut file, key, &device)?; let qtensor = gguf.tensor(&mut file, key)?;
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))) Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
}) })
.collect::<::candle::Result<Vec<_>>>() .collect::<::candle::Result<Vec<_>>>()

View File

@ -1,6 +1,6 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor}; use candle::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder}; use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
pub const DTYPE: DType = DType::F32; pub const DTYPE: DType = DType::F32;
@ -112,6 +112,11 @@ impl Config {
} }
} }
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
struct Dropout { struct Dropout {
#[allow(dead_code)] #[allow(dead_code)]
pr: f64, pr: f64,

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?; let weight = vb.get((size2, size1), "weight")?;
@ -11,6 +11,11 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
Ok(Linear::new(weight, bias)) Ok(Linear::new(weight, bias))
} }
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?; let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?; let bias = vb.get(size, "bias")?;

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor, D}; use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000; const MAX_SEQ_LEN: usize = 5000;
@ -27,6 +27,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps)) Ok(LayerNorm::new(weight, bias, eps))
} }
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
#[derive(Debug)] #[derive(Debug)]
pub struct Config { pub struct Config {

View File

@ -1,6 +1,6 @@
use super::with_tracing::{linear_no_bias as linear, Linear}; use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder}; use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@ -136,6 +136,11 @@ impl Cache {
} }
} }
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
struct RmsNorm { struct RmsNorm {
inner: candle_nn::RmsNorm, inner: candle_nn::RmsNorm,
span: tracing::Span, span: tracing::Span,
@ -404,7 +409,7 @@ impl Llama {
} }
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers) let blocks: Vec<_> = (0..cfg.num_hidden_layers)

View File

@ -1,333 +0,0 @@
//! MobileOne inference implementation based on timm and candle-repvgg
//!
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
//! https://arxiv.org/abs/2206.04040
use candle::{DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
Func, VarBuilder,
};
struct StageConfig {
blocks: usize,
channels: usize,
}
// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
// by concatenating the 5th stage (starts with stride 1) to the previous one.
const STAGES: [StageConfig; 5] = [
StageConfig {
blocks: 1,
channels: 64,
},
StageConfig {
blocks: 2,
channels: 64,
},
StageConfig {
blocks: 8,
channels: 128,
},
StageConfig {
blocks: 10,
channels: 256,
},
StageConfig {
blocks: 1,
channels: 512,
},
];
#[derive(Clone)]
pub struct Config {
/// overparameterization factor
k: usize,
/// per-stage channel number multipliers
alphas: [f32; 5],
}
impl Config {
pub fn s0() -> Self {
Self {
k: 4,
alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
}
}
pub fn s1() -> Self {
Self {
k: 1,
alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
}
}
pub fn s2() -> Self {
Self {
k: 1,
alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
}
}
pub fn s3() -> Self {
Self {
k: 1,
alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
}
}
pub fn s4() -> Self {
Self {
k: 1,
alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
}
}
}
// SE blocks are used in the last stages of the s4 variant.
fn squeeze_and_excitation(
in_channels: usize,
squeeze_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
residual.broadcast_mul(&xs)
}))
}
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
let bias = (beta - mu * &gps)?;
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
Ok((weights, bias))
}
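
Note: fuse_conv_bn folds the batch-norm statistics into the convolution that precedes it. With per-channel BN scale gamma, shift beta, running mean mu, running variance sigma^2 and epsilon eps, the fused parameters computed above are, written out in LaTeX:

W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}
\qquad
b' = \beta - \mu \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}

where the per-channel scale is broadcast over each output channel of W, matching the reshape to ((), 1, 1, 1) in the code.
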
// A mobileone block has a different training time and inference time architecture.
// The latter is a simple and efficient equivalent transformation of the former
// realized by a structural reparameterization technique, where convolutions
// along with identity branches and batchnorm layers are fused into a single convolution.
#[allow(clippy::too_many_arguments)]
fn mobileone_block(
has_identity: bool,
k: usize,
dim: usize,
stride: usize,
padding: usize,
groups: usize,
kernel: usize,
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride,
padding,
groups,
..Default::default()
};
let mut w = Tensor::zeros(
(out_channels, in_channels / groups, kernel, kernel),
DType::F32,
vb.device(),
)?;
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
// k is the training-time overparameterization factor, larger than 1 only in the s0 variant
for i in 0..k {
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
let conv_kxk = conv2d_no_bias(
in_channels,
out_channels,
kernel,
conv2d_cfg,
vb.pp(format!("conv_kxk.{i}.conv")),
)?;
let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
w = (w + wk)?;
b = (b + bk)?;
}
if kernel > 1 {
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
let conv_scale = conv2d_no_bias(
in_channels,
out_channels,
1,
conv2d_cfg,
vb.pp("conv_scale.conv"),
)?;
let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
// resize to 3x3
ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
w = (w + ws)?;
b = (b + bs)?;
}
// Use SE blocks if present (last layers of the s4 variant)
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
// read and reparameterize the identity bn into wi and bi
if has_identity {
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
let id = in_channels / groups;
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
for i in 0..in_channels {
if kernel > 1 {
weights[i * kernel * kernel + 4] = 1.0;
} else {
weights[i * (id + 1)] = 1.0;
}
}
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
w = (w + wi)?;
b = (b + bi)?;
}
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
Ok(Func::new(move |xs| {
let mut xs = xs.apply(&reparam_conv)?;
if let Ok(f) = &se {
xs = xs.apply(f)?;
}
xs = xs.relu()?;
Ok(xs)
}))
}
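
Note: the block above is the structural reparameterization step: every parallel branch is fused with its batch norm and summed into a single kernel/bias pair before reparam_conv is built. Writing fuse(.) for the conv+BN fusion above, the accumulated weight is roughly

W = \sum_{i=1}^{k} \mathrm{fuse}\left(W^{(i)}_{\mathrm{conv\_kxk}}\right)
  + \mathrm{pad}\left(\mathrm{fuse}\left(W_{\mathrm{conv\_scale}}\right)\right)
  + \mathrm{fuse}\left(W_{\mathrm{identity}}\right)

with the 1x1 scale branch included only when the kernel size is larger than one, the identity branch only for residual (stride-1) blocks, and the bias accumulated the same way; k is the overparameterization factor from Config.
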
// Get the number of output channels per stage taking into account the multipliers
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
let channels = STAGES[stage].channels as f32;
let alpha = cfg.alphas[stage];
match stage {
0 => std::cmp::min(64, (channels * alpha) as usize),
_ => (channels * alpha) as usize,
}
}
// Each stage is made of blocks. The first layer always downsamples with stride 2.
// All but the first block have a residual connection.
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nblocks = STAGES[idx].blocks;
let mut blocks = Vec::with_capacity(nblocks);
let mut in_channels = output_channels_per_stage(cfg, idx - 1);
for block_idx in 0..nblocks {
let out_channels = output_channels_per_stage(cfg, idx);
let (has_identity, stride) = if block_idx == 0 {
(false, 2)
} else {
(true, 1)
};
// depthwise convolution layer
blocks.push(mobileone_block(
has_identity,
cfg.k,
in_channels,
stride,
1,
in_channels,
3,
in_channels,
in_channels,
vb.pp(block_idx * 2),
)?);
// pointwise convolution layer
blocks.push(mobileone_block(
has_identity,
cfg.k,
out_channels,
1, // stride
0, // padding
1, // groups
1, // kernel
in_channels,
out_channels,
vb.pp(block_idx * 2 + 1),
)?);
in_channels = out_channels;
}
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
for block in blocks.iter() {
xs = xs.apply(block)?
}
Ok(xs)
}))
}
// Build a mobileone model for a given configuration.
fn mobileone_model(
config: &Config,
nclasses: Option<usize>,
vb: VarBuilder,
) -> Result<Func<'static>> {
let cls = match nclasses {
None => None,
Some(nclasses) => {
let outputs = output_channels_per_stage(config, 4);
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
Some(linear)
}
};
let stem_dim = output_channels_per_stage(config, 0);
let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
let vb = vb.pp("stages");
let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.apply(&stage4)?
.mean(D::Minus2)?
.mean(D::Minus1)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.apply(cls),
}
}))
}
pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
mobileone_model(cfg, Some(nclasses), vb)
}
pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
mobileone_model(cfg, None, vb)
}

View File

@ -15,10 +15,8 @@ pub mod marian;
pub mod mistral; pub mod mistral;
pub mod mixformer; pub mod mixformer;
pub mod mixtral; pub mod mixtral;
pub mod mobileone;
pub mod mpt; pub mod mpt;
pub mod persimmon; pub mod persimmon;
pub mod phi;
pub mod quantized_blip; pub mod quantized_blip;
pub mod quantized_blip_text; pub mod quantized_blip_text;
pub mod quantized_llama; pub mod quantized_llama;
@ -28,7 +26,6 @@ pub mod quantized_mixformer;
pub mod quantized_mpt; pub mod quantized_mpt;
pub mod quantized_stable_lm; pub mod quantized_stable_lm;
pub mod quantized_t5; pub mod quantized_t5;
pub mod repvgg;
pub mod resnet; pub mod resnet;
pub mod segment_anything; pub mod segment_anything;
pub mod stable_diffusion; pub mod stable_diffusion;

View File

@ -1,363 +0,0 @@
use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
/// Phi model.
/// https://huggingface.co/microsoft/phi-2
/// There is an alternative implementation of the phi model in mixformers.rs.
/// This corresponds to the model update made with the following commit:
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: Option<usize>,
pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) layer_norm_eps: f64,
pub(crate) tie_word_embeddings: bool,
pub(crate) rope_theta: f32,
pub(crate) partial_rotary_factor: f64,
pub(crate) qk_layernorm: bool,
}
impl Config {
fn num_key_value_heads(&self) -> usize {
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
}
fn head_dim(&self) -> usize {
self.hidden_size / self.num_attention_heads
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
dim: usize,
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
dim,
sin: emb.sin()?,
cos: emb.cos()?,
})
}
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
let xs_rot = xs.i((.., .., .., ..self.dim))?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: Linear,
fc2: Linear,
act: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
// This does not match the mixformers implementation where Gelu is used rather than
// GeluNew.
act: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
}
}
#[derive(Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
dense: Linear,
kv_cache: Option<(Tensor, Tensor)>,
q_layernorm: Option<LayerNorm>,
k_layernorm: Option<LayerNorm>,
rotary_emb: RotaryEmbedding,
softmax_scale: f64,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
span: tracing::Span,
}
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device)
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads();
let head_dim = cfg.head_dim();
let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
// Alternative rope scalings are not supported.
let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
(Some(q_layernorm), Some(k_layernorm))
} else {
(None, None)
};
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self {
q_proj,
k_proj,
v_proj,
dense,
kv_cache: None,
q_layernorm,
k_layernorm,
rotary_emb,
softmax_scale,
num_heads,
num_kv_heads,
head_dim,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_heads / self.num_kv_heads;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = match &self.q_layernorm {
None => query_states,
Some(ln) => query_states.apply(ln)?,
};
let key_states = match &self.k_layernorm {
None => key_states,
Some(ln) => key_states.apply(ln)?,
};
let query_states = query_states
.reshape((b_size, seq_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
// Rotary embeddings.
let seqlen_offset = match &self.kv_cache {
None => 0,
Some((prev_k, _)) => prev_k.dim(2)?,
};
let query_states = self
.rotary_emb
.apply_rotary_emb(&query_states, seqlen_offset)?;
let key_states = self
.rotary_emb
.apply_rotary_emb(&key_states, seqlen_offset)?;
// KV cache.
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &key_states], 2)?;
let v = Tensor::cat(&[prev_v, &value_states], 2)?;
(k, v)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
// Repeat kv.
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let attn_weights = (query_states
.to_dtype(DType::F32)?
.contiguous()?
.matmul(&key_states.to_dtype(DType::F32)?.t()?)?
* self.softmax_scale)?;
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
&attn_weights,
&mask.broadcast_left((b_size, self.num_heads))?,
f32::NEG_INFINITY,
)?,
};
let attn_weights =
candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
let attn_output = attn_weights.matmul(&value_states)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_size, seq_len, ()))?;
attn_output.apply(&self.dense)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: LayerNorm,
span: tracing::Span,
}
impl DecoderLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
vb.pp("input_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
span: tracing::span!(tracing::Level::TRACE, "block"),
})
}
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = xs;
let xs = xs.apply(&self.input_layernorm)?;
let attn_outputs = self.self_attn.forward(&xs, mask)?;
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
attn_outputs + feed_forward_hidden_states + residual
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Clone)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,
final_layernorm: LayerNorm,
lm_head: Linear,
span: tracing::Span,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let final_layernorm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
vb_m.pp("final_layernorm"),
)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_m = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?;
layers.push(layer)
}
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
final_layernorm,
lm_head,
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_size, seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embed_tokens)?;
let mask = if seq_len <= 1 {
None
} else {
Some(get_mask(seq_len, xs.device())?)
};
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, mask.as_ref())?;
}
xs.apply(&self.final_layernorm)?
.narrow(1, seq_len - 1, 1)?
.apply(&self.lm_head)?
.squeeze(1)
}
pub fn clear_kv_cache(&mut self) {
self.layers.iter_mut().for_each(|b| b.clear_kv_cache())
}
}
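The listing above is only the model definition; to show how the kv cache and clear_kv_cache are meant to be driven, here is a hedged greedy-decoding sketch. Weight loading and tokenization are elided, and the helper name generate is made up for illustration.

// Hedged sketch: greedy decoding against the Phi Model defined above.
// `model` is assumed to have been built with Model::new(&config, vb)?.
fn generate(model: &mut Model, prompt_ids: &[u32], steps: usize, device: &Device) -> Result<Vec<u32>> {
    model.clear_kv_cache();
    let mut tokens = prompt_ids.to_vec();
    // The first call consumes the whole prompt; later calls feed one token and
    // rely on the per-layer kv cache for all earlier positions.
    let mut input = Tensor::new(prompt_ids, device)?.unsqueeze(0)?;
    for _ in 0..steps {
        let logits = model.forward(&input)?; // shape (1, vocab_size)
        let next = logits.squeeze(0)?.argmax(D::Minus1)?.to_scalar::<u32>()?;
        tokens.push(next);
        input = Tensor::new(&[next], device)?.unsqueeze(0)?;
    }
    Ok(tokens)
}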

View File

@@ -356,7 +356,6 @@ impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
         reader: &mut R,
-        device: &Device,
     ) -> Result<Self> {
         let cpu = &Device::Cpu;
         let md_get = |s: &str| match ct.metadata.get(s) {
@@ -384,28 +383,21 @@ impl ModelWeights {
             .unwrap_or(10000f32);
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
 
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
-        let norm = RmsNorm::new(
-            ct.tensor(reader, "output_norm.weight", device)?,
-            rms_norm_eps,
-        )?;
-        let output = ct.tensor(reader, "output.weight", device)?;
+        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
+        let output = ct.tensor(reader, "output.weight")?;
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-            let attention_wo =
-                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
+            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
+            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
+            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
             let mlp_or_moe = if n_expert <= 1 {
-                let feed_forward_w1 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-                let feed_forward_w2 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
-                let feed_forward_w3 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+                let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+                let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
                 MlpOrMoe::Mlp(Mlp {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -413,15 +405,15 @@ impl ModelWeights {
                 })
             } else {
                 let feed_forward_gate_inp =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
                 let mut experts = Vec::with_capacity(n_expert);
                 for i in 0..n_expert {
                     let feed_forward_w1 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
                     let feed_forward_w2 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
                     let feed_forward_w3 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
                     experts.push(Mlp {
                         feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                         feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -434,9 +426,8 @@ impl ModelWeights {
                     experts,
                 }
             };
-            let attention_norm =
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
-            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
+            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
+            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
             let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
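For context on the change above: the left-hand side (current main) threads a target Device through gguf_file::Content::tensor, while the branch keeps the older CPU-only signature and leaves it to the caller to move data afterwards. A hedged sketch of the device-aware convention, with a made-up helper name and condensed error handling:

// Hedged sketch: loading a single tensor from a GGUF file with the
// device-aware API shown on the left-hand side of the hunk.
use candle::quantized::gguf_file;
use candle::{Device, Result, Tensor};

fn load_tok_embeddings(path: &str, device: &Device) -> Result<Tensor> {
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    let qtensor = content.tensor(&mut file, "token_embd.weight", device)?;
    // As in from_gguf above, token embeddings get dequantized before use.
    qtensor.dequantize(device)
}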

View File

@@ -311,7 +311,7 @@ impl MixFormerSequentialForCausalLM {
         let mut blocks = Vec::new();
         for i in 0..cfg.n_layer {
             let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
-            blocks.push(block);
+            blocks.push(block)
         }
         let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
         Ok(Self {
@@ -332,7 +332,7 @@ impl MixFormerSequentialForCausalLM {
             Some(get_mask(seq_len, xs.device())?)
         };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs, mask.as_ref())?;
+            xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }

View File

@@ -1,306 +0,0 @@
//! RepVGG inference implementation
//!
//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
//! https://arxiv.org/abs/2101.03697
use candle::{Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
};
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
#[derive(Clone)]
pub struct Config {
a: f32,
b: f32,
groups: usize,
stages: [usize; 4],
}
impl Config {
pub fn a0() -> Self {
Self {
a: 0.75,
b: 2.5,
groups: 1,
stages: [2, 4, 14, 1],
}
}
pub fn a1() -> Self {
Self {
a: 1.0,
b: 2.5,
groups: 1,
stages: [2, 4, 14, 1],
}
}
pub fn a2() -> Self {
Self {
a: 1.5,
b: 2.75,
groups: 1,
stages: [2, 4, 14, 1],
}
}
pub fn b0() -> Self {
Self {
a: 1.0,
b: 2.5,
groups: 1,
stages: [4, 6, 16, 1],
}
}
pub fn b1() -> Self {
Self {
a: 2.0,
b: 4.0,
groups: 1,
stages: [4, 6, 16, 1],
}
}
pub fn b2() -> Self {
Self {
a: 2.5,
b: 5.0,
groups: 1,
stages: [4, 6, 16, 1],
}
}
pub fn b3() -> Self {
Self {
a: 3.0,
b: 5.0,
groups: 1,
stages: [4, 6, 16, 1],
}
}
pub fn b1g4() -> Self {
Self {
a: 2.0,
b: 4.0,
groups: 4,
stages: [4, 6, 16, 1],
}
}
pub fn b2g4() -> Self {
Self {
a: 2.5,
b: 5.0,
groups: 4,
stages: [4, 6, 16, 1],
}
}
pub fn b3g4() -> Self {
Self {
a: 3.0,
b: 5.0,
groups: 4,
stages: [4, 6, 16, 1],
}
}
}
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
let bias = (beta - mu * &gps)?;
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
Ok((weights, bias))
}
// A RepVGG layer has a different training time and inference time architecture.
// The latter is a simple and efficient equivalent transformation of the former
// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
fn repvgg_layer(
has_identity: bool,
dim: usize,
stride: usize,
in_channels: usize,
out_channels: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride,
groups,
padding: 1,
..Default::default()
};
// read and reparameterize the 1x1 conv and bn into w1 and b1
// based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
let conv1x1 = conv2d_no_bias(
in_channels,
out_channels,
1,
conv2d_cfg,
vb.pp("conv_1x1.conv"),
)?;
let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
// resize to 3x3
w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
// read and reparameterize the 3x3 conv and bn into w3 and b3
let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
let conv3x3 = conv2d_no_bias(
in_channels,
out_channels,
3,
conv2d_cfg,
vb.pp("conv_kxk.conv"),
)?;
let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
let mut w = (w1 + w3)?;
let mut b = (b1 + b3)?;
// read and reparameterize the identity bn into wi and bi
if has_identity {
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
// create a 3x3 convolution equivalent to the identity branch
let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
// https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
let in_dim = in_channels / groups;
for i in 0..in_channels {
weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
}
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
w = (w + wi)?;
b = (b + bi)?;
}
// create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
Ok(Func::new(move |xs| {
let xs = xs.apply(&reparam_conv)?.relu()?;
Ok(xs)
}))
}
// Get the number of output channels per stage taking into account the multipliers
fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
let channels = CHANNELS_PER_STAGE[stage] as f32;
match stage {
0 => std::cmp::min(64, (channels * a) as usize),
4 => (channels * b) as usize,
_ => (channels * a) as usize,
}
}
// Each stage is made of layers. The first layer always downsamples with stride 2.
// All but the first layer have a residual connection.
// The G4 variants have a groupwise convolution instead of a dense one on odd layers
// counted across stage boundaries, so we keep track of which layer we are in the
// full model.
fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nlayers = cfg.stages[idx - 1];
let mut layers = Vec::with_capacity(nlayers);
let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
for layer_idx in 0..nlayers {
let (has_identity, stride, in_channels) = if layer_idx == 0 {
(false, 2, out_channels_prev)
} else {
(true, 1, out_channels)
};
let groups = if (prev_layers + layer_idx) % 2 == 1 {
cfg.groups
} else {
1
};
layers.push(repvgg_layer(
has_identity,
out_channels,
stride,
in_channels,
out_channels,
groups,
vb.pp(layer_idx),
)?)
}
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
for layer in layers.iter() {
xs = xs.apply(layer)?
}
Ok(xs)
}))
}
// Build a RepVGG model for a given configuration.
fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
let cls = match nclasses {
None => None,
Some(nclasses) => {
let outputs = output_channels_per_stage(config.a, config.b, 4);
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
Some(linear)
}
};
let stem_dim = output_channels_per_stage(config.a, config.b, 0);
let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
let vb = vb.pp("stages");
let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.apply(&stage4)?
.mean(D::Minus1)?
.mean(D::Minus1)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.apply(cls),
}
}))
}
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
repvgg_model(cfg, Some(nclasses), vb)
}
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
repvgg_model(cfg, None, vb)
}
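As a reference for the fuse_conv_bn helper near the top of this file, the folding it performs is the standard conv+BN identity; restated in LaTeX (with \gamma, \beta the BatchNorm scale and shift, \mu, \sigma^2 the running statistics and \epsilon the stabilizer):

% Folding a BatchNorm into the preceding (bias-free) convolution:
% each output channel's weights are rescaled and a bias is introduced.
W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad
b' = \beta - \mu \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}

This matches the code above: gps is \gamma / \sqrt{\sigma^2 + \epsilon}, the weights are broadcast-multiplied by it per output channel, and the fused bias is \beta - \mu \cdot gps.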

View File

@@ -1,7 +1,12 @@
 use super::Config;
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, hidden_size))
+}
 
 fn conv1d(
     in_channels: usize,

View File

@@ -10,33 +10,33 @@ pub struct VarBuilder {
 }
 
 impl VarBuilder {
-    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
         let mut file = std::fs::File::open(p)?;
         let content = candle::quantized::gguf_file::Content::read(&mut file)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut file, tensor_name, device)?;
+            let tensor = content.tensor(&mut file, tensor_name)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
             data: Arc::new(data),
             path: Vec::new(),
-            device: device.clone(),
+            device: Device::Cpu,
         })
     }
 
-    pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
+    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
         let mut cursor = std::io::Cursor::new(buffer);
         let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
+            let tensor = content.tensor(&mut cursor, tensor_name)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
             data: Arc::new(data),
             path: Vec::new(),
-            device: device.clone(),
+            device: Device::Cpu,
         })
     }
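The same device plumbing difference shows up in the browser-facing wasm crates below, which all build their models from an in-memory GGUF buffer. A hedged sketch of the device-aware constructor from the left-hand column; the buffer is assumed to have been fetched by the caller already:

// Hedged sketch: constructing a quantized VarBuilder from GGUF bytes with the
// device-aware signature on the left of the hunk above.
use candle::Device;
use candle_transformers::quantized_var_builder::VarBuilder;

fn builder_from_bytes(gguf_bytes: &[u8]) -> candle::Result<VarBuilder> {
    // Wasm targets stay on the CPU; native callers could pass a CUDA device instead.
    let device = Device::Cpu;
    VarBuilder::from_gguf_buffer(gguf_bytes, &device)
}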

View File

@@ -27,7 +27,7 @@ safetensors = { workspace = true }
 # Wasm specific crates.
 console_error_panic_hook = "0.1.7"
 getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.11"
+gloo = "0.8"
 js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 serde-wasm-bindgen = "0.6.0"

View File

@@ -61,7 +61,7 @@ impl Model {
         let start = Date::now();
         let model: SelectedModel = if quantized {
-            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
+            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
             let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
             SelectedModel::Q(model)
         } else {

View File

@@ -26,7 +26,7 @@ serde_json = { workspace = true }
 # Wasm specific crates.
 console_error_panic_hook = "0.1.7"
 getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.11"
+gloo = "0.8"
 js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 wasm-bindgen-futures = "0.4.37"

View File

@@ -41,7 +41,6 @@ impl Model {
     ) -> Result<Model, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let device = Device::Cpu;
         let name: ModelName = serde_json::from_slice(&config)?;
         let config: Config = serde_json::from_slice(&config)?;
@@ -51,9 +50,8 @@ impl Model {
         let start = Date::now();
         console_log!("weights len: {:?}", weights.len());
         let model = if quantized {
-            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
-                &weights, &device,
-            )?;
+            let vb =
+                candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
             console_log!("weights loaded");
             if name._name_or_path == "microsoft/phi-2" {
                 let model = QMixFormer::new_v2(&config, vb)?;

View File

@@ -27,7 +27,7 @@ safetensors = { workspace = true }
 # Wasm specific crates.
 console_error_panic_hook = "0.1.7"
 getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.11"
+gloo = "0.8"
 js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 serde-wasm-bindgen = "0.6.0"

View File

@@ -7,7 +7,6 @@ pub use candle_transformers::models::quantized_t5::{
 use candle_wasm_example_t5::console_log;
 use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
-const DEVICE: Device = Device::Cpu;
 
 #[wasm_bindgen]
 pub struct ModelEncoder {
@@ -32,7 +31,7 @@ impl ModelConditionalGeneration {
     ) -> Result<ModelConditionalGeneration, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -47,7 +46,7 @@ impl ModelConditionalGeneration {
     pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
         let input: ConditionalGenerationParams =
             serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
-        let device = &DEVICE;
+        let device = &Device::Cpu;
         self.model.clear_kv_cache();
         let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
         let prompt = input.prompt;
@@ -129,7 +128,7 @@ impl ModelEncoder {
     ) -> Result<ModelEncoder, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         config.use_cache = false;
         let tokenizer =
@@ -139,7 +138,7 @@ impl ModelEncoder {
     }
 
     pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
-        let device = &DEVICE;
+        let device = &Device::Cpu;
         let input: DecoderParams =
             serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;

View File

@@ -26,7 +26,7 @@ safetensors = { workspace = true }
 
 # Wasm specific crates.
 getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.11"
+gloo = "0.8"
 js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 wasm-bindgen-futures = "0.4.37"

View File

@@ -315,7 +315,6 @@ impl Decoder {
         let model = if md.quantized {
             let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
                 &md.weights,
-                &device,
             )?;
             Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
         } else {

View File

@@ -26,7 +26,7 @@ safetensors = { workspace = true }
 # Wasm specific crates.
 console_error_panic_hook = "0.1.7"
 getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.11"
+gloo = "0.8"
 js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 wasm-bindgen-futures = "0.4.37"

View File

@@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> {
         ]
     );
 
-    let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(