mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Merge branch 'main' into ivarflakstad/metal-reduce-2
This commit is contained in:
7
.github/dependabot.yml
vendored
Normal file
7
.github/dependabot.yml
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "cargo"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
20
Cargo.toml
20
Cargo.toml
@ -31,10 +31,18 @@ license = "MIT OR Apache-2.0"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core" }
|
||||
candle-datasets = { path = "./candle-datasets" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn" }
|
||||
candle-kernels = { path = "./candle-kernels" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
||||
candle-nn = { path = "./candle-nn" }
|
||||
candle-onnx = { path = "./candle-onnx" }
|
||||
candle-transformers = { path = "./candle-transformers" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
@ -42,20 +50,20 @@ imageproc = { version = "0.23.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
parquet = { version = "45.0.0" }
|
||||
parquet = { version = "50.0.0" }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.7.0"
|
||||
rusttype = { version = "0.9", default-features = false }
|
||||
safetensors = "0.3.1"
|
||||
safetensors = "0.4.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_plain = "1.0.2"
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tokenizers = { version = "0.15.0", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
|
@ -66,7 +66,7 @@ We also provide a some command line based examples using state of the art models
|
||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets.
|
||||
- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
|
||||
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
||||
implementation of the Mamba state space model.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
better performance than all publicly available 13b models as of 2023-09-28.
|
||||
@ -109,6 +109,9 @@ We also provide a some command line based examples using state of the art models
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [VGG](./candle-examples/examples/vgg/),
|
||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||
@ -204,7 +207,7 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
@ -12,8 +12,8 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
|
||||
candle-kernels = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
[[bench]]
|
||||
name = "bench_main"
|
||||
harness = false
|
||||
|
||||
|
@ -1,4 +1,11 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
criterion_main!(benchmarks::reduce::benches);
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::where_cond::benches
|
||||
);
|
||||
|
43
candle-core/benches/benchmarks/affine.rs
Normal file
43
candle-core/benches/benchmarks/affine.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor) {
|
||||
a.affine(12.34, 56.78).unwrap();
|
||||
}
|
||||
|
||||
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||
|
||||
let flops = b * m * k * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
|
||||
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
|
||||
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,5 +1,5 @@
|
||||
use crate::benchmarks::{bench_name, device, BenchDevice};
|
||||
use candle_core::{DType, Tensor};
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
@ -7,20 +7,19 @@ fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let device = device().unwrap();
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(bench_name("matmul"));
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
@ -35,4 +34,11 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
|
@ -1,10 +1,15 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod reduce;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
@ -25,32 +30,38 @@ impl BenchDevice for Device {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn device() -> Result<Device> {
|
||||
if cfg!(feature = "metal") {
|
||||
Device::new_metal(0)
|
||||
} else if cfg!(feature = "cuda") {
|
||||
Device::new_cuda(0)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let cpu_type = if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
};
|
||||
format!("{}_{}", cpu_type, name.into())
|
||||
}
|
||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
|
||||
format!("{}_{}", device_variant(), name.into())
|
||||
struct BenchDeviceHandler {
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
const fn device_variant() -> &'static str {
|
||||
if cfg!(feature = "metal") {
|
||||
"metal"
|
||||
} else if cfg!(feature = "cuda") {
|
||||
"cuda"
|
||||
} else if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
impl BenchDeviceHandler {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut devices = Vec::new();
|
||||
if cfg!(feature = "metal") {
|
||||
devices.push(Device::new_metal(0)?);
|
||||
} else if cfg!(feature = "cuda") {
|
||||
devices.push(Device::new_cuda(0)?);
|
||||
}
|
||||
devices.push(Device::Cpu);
|
||||
Ok(Self { devices })
|
||||
}
|
||||
}
|
||||
|
63
candle-core/benches/benchmarks/random.rs
Normal file
63
candle-core/benches/benchmarks/random.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_uniform(a: &Tensor) {
|
||||
a.rand_like(-1.0, 123.0).unwrap();
|
||||
}
|
||||
|
||||
fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,4 +1,4 @@
|
||||
use crate::benchmarks::{bench_name, device, BenchDevice};
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Storage, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use half::{bf16, f16};
|
||||
@ -12,6 +12,7 @@ fn run_arg_min(a: &Tensor) {
|
||||
a.argmin(2).unwrap();
|
||||
}
|
||||
|
||||
// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
|
||||
fn softmax(a: &Tensor) -> candle_core::Result<()> {
|
||||
use candle_core::{backend::BackendStorage, DType};
|
||||
let (storage, layout) = a.storage_and_layout();
|
||||
@ -53,20 +54,21 @@ fn softmax(a: &Tensor) -> candle_core::Result<()> {
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = device().unwrap();
|
||||
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
let (lo, up) = (-1000.0f32, 1000.0f32);
|
||||
run_softmax(c, &device, (lo, up));
|
||||
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
for device in handler.devices {
|
||||
run_softmax(c, &device, (lo, up));
|
||||
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
|
||||
run_reduce(c, &device, (lo, up));
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
run_reduce(c, &device, (lo, up));
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up));
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (lo, up));
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
}
|
||||
}
|
||||
|
||||
fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
|
||||
@ -75,8 +77,8 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
||||
}
|
||||
|
||||
let b = 1;
|
||||
let m = 2048;
|
||||
let k = 2048;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
@ -88,7 +90,7 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
||||
_ => "softmax",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(bench_name(name));
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
@ -105,8 +107,8 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
|
||||
let b = 1;
|
||||
let m = 2048;
|
||||
let k = 2048;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
|
||||
@ -119,7 +121,7 @@ fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (l
|
||||
_ => "reduce",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(bench_name(name));
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
@ -140,8 +142,8 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
(lo, up): (T, T),
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 2048;
|
||||
let k = 2048;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
|
||||
@ -154,7 +156,7 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
_ => "reduce",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(bench_name(name));
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
|
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
64
candle-core/benches/benchmarks/where_cond.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||
a.where_cond(b, c).unwrap();
|
||||
}
|
||||
|
||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||
let mut arr = [0u8; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
arr[i] = (i % 2) as u8;
|
||||
i += 1;
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
const M: usize = 1024;
|
||||
const K: usize = 1024;
|
||||
const SIZE: usize = B * M * K;
|
||||
|
||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||
|
||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
||||
|
||||
let elements = B * M * K;
|
||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(
|
||||
black_box(&tensor),
|
||||
black_box(&on_true),
|
||||
black_box(&on_false),
|
||||
);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = BenchDeviceHandler::new().unwrap();
|
||||
for d in device.devices {
|
||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,5 +1,5 @@
|
||||
use candle_core::quantized::{gguf_file, k_quants, QTensor};
|
||||
use candle_core::{Device, Result, Tensor};
|
||||
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
|
||||
use candle_core::{Device, Result};
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use rayon::prelude::*;
|
||||
|
||||
@ -11,12 +11,7 @@ enum QuantizationMode {
|
||||
}
|
||||
|
||||
impl QuantizationMode {
|
||||
fn quantize(
|
||||
&self,
|
||||
name: &str,
|
||||
tensor: QTensor,
|
||||
default: fn(&Tensor) -> Result<QTensor>,
|
||||
) -> Result<QTensor> {
|
||||
fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
|
||||
match self {
|
||||
Self::Llama => {
|
||||
// Same behavior as the llama.cpp quantization.
|
||||
@ -24,9 +19,9 @@ impl QuantizationMode {
|
||||
if should_quantize {
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
if name == "output.weight" {
|
||||
QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
|
||||
QTensor::quantize(&tensor, GgmlDType::Q6K)
|
||||
} else {
|
||||
default(&tensor)
|
||||
QTensor::quantize(&tensor, dtype)
|
||||
}
|
||||
} else {
|
||||
Ok(tensor)
|
||||
@ -60,6 +55,27 @@ enum Quantization {
|
||||
F32,
|
||||
}
|
||||
|
||||
impl Quantization {
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
Quantization::Q4_0 => GgmlDType::Q4_0,
|
||||
Quantization::Q4_1 => GgmlDType::Q4_1,
|
||||
Quantization::Q5_0 => GgmlDType::Q5_0,
|
||||
Quantization::Q5_1 => GgmlDType::Q5_1,
|
||||
Quantization::Q8_0 => GgmlDType::Q8_0,
|
||||
Quantization::Q8_1 => GgmlDType::Q8_1,
|
||||
Quantization::Q2k => GgmlDType::Q2K,
|
||||
Quantization::Q3k => GgmlDType::Q3K,
|
||||
Quantization::Q4k => GgmlDType::Q4K,
|
||||
Quantization::Q5k => GgmlDType::Q5K,
|
||||
Quantization::Q6k => GgmlDType::Q6K,
|
||||
Quantization::Q8k => GgmlDType::Q8K,
|
||||
Quantization::F16 => GgmlDType::F16,
|
||||
Quantization::F32 => GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
enum Format {
|
||||
Safetensors,
|
||||
@ -102,7 +118,7 @@ enum Command {
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file, in gguf format.
|
||||
/// The input file(s), in safetensors format.
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The output file, in gguf format.
|
||||
@ -117,6 +133,15 @@ enum Command {
|
||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||
mode: QuantizationMode,
|
||||
},
|
||||
|
||||
Dequantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
|
||||
/// The output file, in safetensors format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -125,7 +150,12 @@ struct Args {
|
||||
command: Command,
|
||||
}
|
||||
|
||||
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
||||
fn run_ls(
|
||||
file: &std::path::PathBuf,
|
||||
format: Option<Format>,
|
||||
verbose: bool,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let format = match format {
|
||||
Some(format) => format,
|
||||
None => match Format::infer(file) {
|
||||
@ -191,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
||||
}
|
||||
Format::Ggml => {
|
||||
let mut file = std::fs::File::open(file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
||||
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
|
||||
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, qtensor) in tensors.iter() {
|
||||
@ -232,37 +262,8 @@ fn run_quantize_safetensors(
|
||||
}
|
||||
println!("tensors: {}", tensors.len());
|
||||
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
let block_size = match q {
|
||||
Quantization::Q4_0 => k_quants::QK4_0,
|
||||
Quantization::Q4_1 => k_quants::QK4_1,
|
||||
Quantization::Q5_0 => k_quants::QK5_0,
|
||||
Quantization::Q5_1 => k_quants::QK5_1,
|
||||
Quantization::Q8_0 => k_quants::QK8_0,
|
||||
Quantization::Q8_1 => k_quants::QK8_1,
|
||||
Quantization::Q2k
|
||||
| Quantization::Q3k
|
||||
| Quantization::Q4k
|
||||
| Quantization::Q5k
|
||||
| Quantization::Q6k
|
||||
| Quantization::Q8k => k_quants::QK_K,
|
||||
Quantization::F16 | Quantization::F32 => 1,
|
||||
};
|
||||
let dtype = q.dtype();
|
||||
let block_size = dtype.block_size();
|
||||
|
||||
let qtensors = tensors
|
||||
.into_par_iter()
|
||||
@ -270,9 +271,9 @@ fn run_quantize_safetensors(
|
||||
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
|
||||
println!(" quantizing {name} {tensor:?} {should_quantize}");
|
||||
let tensor = if should_quantize {
|
||||
quantize_fn(&tensor)?
|
||||
QTensor::quantize(&tensor, dtype)?
|
||||
} else {
|
||||
QTensor::quantize::<f32>(&tensor)?
|
||||
QTensor::quantize(&tensor, GgmlDType::F32)?
|
||||
};
|
||||
Ok((name, tensor))
|
||||
})
|
||||
@ -285,11 +286,29 @@ fn run_quantize_safetensors(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_dequantize(
|
||||
in_file: std::path::PathBuf,
|
||||
out_file: std::path::PathBuf,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let mut in_file = std::fs::File::open(in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
q: Quantization,
|
||||
qmode: QuantizationMode,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
if in_files.is_empty() {
|
||||
candle_core::bail!("no specified input files")
|
||||
@ -315,31 +334,15 @@ fn run_quantize(
|
||||
let content = gguf_file::Content::read(&mut in_)?;
|
||||
println!("tensors: {}", content.tensor_infos.len());
|
||||
|
||||
let quantize_fn = match q {
|
||||
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
|
||||
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
|
||||
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
|
||||
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
|
||||
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
|
||||
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
|
||||
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
|
||||
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
|
||||
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
|
||||
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
|
||||
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
|
||||
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
|
||||
Quantization::F16 => QTensor::quantize::<half::f16>,
|
||||
Quantization::F32 => QTensor::quantize::<f32>,
|
||||
};
|
||||
|
||||
let dtype = q.dtype();
|
||||
let qtensors = content
|
||||
.tensor_infos
|
||||
.par_iter()
|
||||
.map(|(name, _)| {
|
||||
println!(" quantizing {name}");
|
||||
let mut in_file = std::fs::File::open(&in_files[0])?;
|
||||
let tensor = content.tensor(&mut in_file, name)?;
|
||||
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
|
||||
let tensor = content.tensor(&mut in_file, name, device)?;
|
||||
let tensor = qmode.quantize(name, tensor, dtype)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
@ -359,6 +362,7 @@ fn run_quantize(
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = Device::Cpu;
|
||||
match args.command {
|
||||
Command::Ls {
|
||||
files,
|
||||
@ -370,7 +374,7 @@ fn main() -> anyhow::Result<()> {
|
||||
if multiple_files {
|
||||
println!("--- {file:?} ---");
|
||||
}
|
||||
run_ls(file, format.clone(), verbose)?
|
||||
run_ls(file, format.clone(), verbose, &device)?
|
||||
}
|
||||
}
|
||||
Command::Quantize {
|
||||
@ -378,7 +382,8 @@ fn main() -> anyhow::Result<()> {
|
||||
out_file,
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ pub mod utils;
|
||||
mod variable;
|
||||
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use device::{Device, DeviceLocation};
|
||||
pub use device::{Device, DeviceLocation, NdArray};
|
||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
|
@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock, TryLockError};
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
@ -84,13 +85,8 @@ pub struct MetalDevice {
|
||||
command_buffer_index: Arc<RwLock<usize>>,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
/// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
|
||||
/// execution order to be linear.
|
||||
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
|
||||
/// compute graph.
|
||||
fence: metal::Fence,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`], both fences need to match
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
@ -106,6 +102,8 @@ pub struct MetalDevice {
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
||||
/// (strong_count = 1).
|
||||
buffers: AllocatedBuffers,
|
||||
/// Seed for random number generation.
|
||||
seed: Arc<Mutex<Buffer>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -221,10 +219,8 @@ impl MetalDevice {
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
|
||||
// This is necessary, for mmaped safetensors
|
||||
@ -232,12 +228,33 @@ impl MetalDevice {
|
||||
// The slice might not live long enough for metal
|
||||
// To actually fill the GPU buffer.
|
||||
// Putting this wait forces the GPU buffer to be filled
|
||||
// with the actual data allowing the CPU storage todo
|
||||
// with the actual data allowing the CPU storage to do
|
||||
// deallocate properly.
|
||||
self.wait_until_completed()?;
|
||||
Ok(real)
|
||||
}
|
||||
|
||||
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
||||
let buffer = self.allocate_buffer(
|
||||
size_in_bytes as NSUInteger,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
"allocate_zeros",
|
||||
)?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
location: 0,
|
||||
length: buffer.length(),
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.end_encoding();
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// The critical allocator algorithm
|
||||
fn allocate_buffer(
|
||||
&self,
|
||||
@ -308,35 +325,14 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
let length = self.buffer.length() as usize;
|
||||
let size = self.dtype.size_in_bytes();
|
||||
if length % size != 0 {
|
||||
crate::bail!(
|
||||
"The Metal buffer length is not aligned with dtype {:?}",
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
{
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.wait_for_fence(&self.device.fence);
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.update_fence(&self.device.fence);
|
||||
blit.end_encoding();
|
||||
}
|
||||
self.device.wait_until_completed()?;
|
||||
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
|
||||
DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
|
||||
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
||||
DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
||||
DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
|
||||
DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
|
||||
DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
|
||||
DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
|
||||
DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
|
||||
DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -353,6 +349,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32",
|
||||
DType::F16 => "affine_f16",
|
||||
DType::BF16 => "affine_bf16",
|
||||
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
@ -371,6 +368,7 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32_strided",
|
||||
DType::F16 => "affine_f16_strided",
|
||||
DType::BF16 => "affine_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
@ -647,14 +645,26 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -732,6 +742,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
@ -748,6 +759,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -778,6 +790,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
@ -792,6 +805,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
@ -847,6 +861,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
(DType::U8, DType::U32) => "where_u8_u32",
|
||||
@ -1184,8 +1199,12 @@ impl BackendStorage for MetalStorage {
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(DType::U32, DType::BF16) => "is_u32_bf16",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -1298,7 +1317,7 @@ impl BackendStorage for MetalStorage {
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let src_shape = src_l.shape();
|
||||
@ -1375,6 +1394,7 @@ impl MetalStorage {
|
||||
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||
|
||||
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||
@ -1385,6 +1405,18 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
|
||||
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
|
||||
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
|
||||
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
|
||||
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
|
||||
|
||||
("add", DType::I64) => (contiguous::add::I64, self.dtype),
|
||||
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
|
||||
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
|
||||
@ -1395,6 +1427,7 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||
|
||||
("add", DType::U32) => (contiguous::add::U32, self.dtype),
|
||||
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
|
||||
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
|
||||
@ -1405,6 +1438,7 @@ impl MetalStorage {
|
||||
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
|
||||
|
||||
("add", DType::U8) => (contiguous::add::U8, self.dtype),
|
||||
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
|
||||
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
|
||||
@ -1415,6 +1449,7 @@ impl MetalStorage {
|
||||
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1448,6 +1483,7 @@ impl MetalStorage {
|
||||
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||
|
||||
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||
@ -1460,6 +1496,20 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
|
||||
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
|
||||
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
|
||||
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
|
||||
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
|
||||
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
|
||||
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
|
||||
|
||||
("badd", DType::I64) => (strided::add::I64, self.dtype),
|
||||
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
|
||||
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
|
||||
@ -1472,6 +1522,7 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||
|
||||
("badd", DType::U32) => (strided::add::U32, self.dtype),
|
||||
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
|
||||
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
|
||||
@ -1484,6 +1535,7 @@ impl MetalStorage {
|
||||
("lt", DType::U32) => (strided::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (strided::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (strided::gt::U32, DType::U8),
|
||||
|
||||
("badd", DType::U8) => (strided::add::U8, self.dtype),
|
||||
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
|
||||
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
|
||||
@ -1496,6 +1548,7 @@ impl MetalStorage {
|
||||
("lt", DType::U8) => (strided::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (strided::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (strided::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1521,6 +1574,28 @@ impl MetalStorage {
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
|
||||
let length = self.buffer.length() as usize;
|
||||
let size = self.dtype.size_in_bytes();
|
||||
if length % size != 0 {
|
||||
crate::bail!(
|
||||
"The Metal buffer length is not aligned with dtype {:?}",
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
{
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
}
|
||||
self.device.wait_until_completed()?;
|
||||
Ok(read_to_vec(&buffer, length / size))
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
@ -1533,29 +1608,29 @@ impl BackendDevice for MetalDevice {
|
||||
command_buffer.enqueue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||
let fence = device.new_fence();
|
||||
let kernels = Arc::new(Kernels::new(fence.clone()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 20,
|
||||
_ => 10,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||
[299792458].as_ptr() as *const c_void,
|
||||
4,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
)));
|
||||
Ok(Self {
|
||||
device,
|
||||
fence,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
command_buffer_index,
|
||||
compute_per_buffer,
|
||||
buffers,
|
||||
kernels,
|
||||
seed,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||
crate::bail!("Metal set_seed not implemented")
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Metal {
|
||||
gpu_id: self.registry_id() as usize,
|
||||
@ -1567,21 +1642,8 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
location: 0,
|
||||
length: buffer.length(),
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
let size = shape.elem_count() * dtype.size_in_bytes();
|
||||
let buffer = self.allocate_zeros(size)?;
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -1608,12 +1670,31 @@ impl BackendDevice for MetalDevice {
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
min: f64,
|
||||
max: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_uniform_f32",
|
||||
DType::F16 => "rand_uniform_f16",
|
||||
DType::BF16 => "rand_uniform_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_uniform(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn rand_normal(
|
||||
@ -1623,9 +1704,43 @@ impl BackendDevice for MetalDevice {
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_normal_f32",
|
||||
DType::F16 => "rand_normal_f16",
|
||||
DType::BF16 => "rand_normal_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_normal(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
let seed: u32 = seed.try_into().map_err(|_| {
|
||||
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
|
||||
})?;
|
||||
|
||||
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
|
||||
let contents = seed_buffer.contents();
|
||||
unsafe {
|
||||
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
|
||||
}
|
||||
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -703,6 +703,7 @@ impl PthTensors {
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
use std::io::Read;
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
@ -712,14 +713,21 @@ impl PthTensors {
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case.
|
||||
if !tensor_info.layout.is_contiguous() {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let start_offset = tensor_info.layout.start_offset();
|
||||
if start_offset > 0 {
|
||||
std::io::copy(
|
||||
&mut reader.by_ref().take(start_offset as u64),
|
||||
&mut std::io::sink(),
|
||||
)?;
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
|
@ -1,7 +1,9 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
use super::{k_quants, GgmlDType};
|
||||
use crate::Result;
|
||||
#[cfg(feature = "metal")]
|
||||
use super::metal::load_quantized_metal;
|
||||
use super::{k_quants, GgmlDType, QStorage};
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -121,11 +123,22 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
raw_data: &[u8],
|
||||
size_in_bytes: usize,
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
let data: QStorage = match device {
|
||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal backend requires `metal` feature")
|
||||
}
|
||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
||||
};
|
||||
super::QTensor::new(data, dims)
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
@ -133,29 +146,50 @@ pub fn qtensor_from_ggml(
|
||||
ggml_dtype: GgmlDType,
|
||||
raw_data: &[u8],
|
||||
dims: Vec<usize>,
|
||||
device: &Device,
|
||||
) -> Result<super::QTensor> {
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let blck_size = ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
let block_size = ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
|
||||
|
||||
match ggml_dtype {
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
|
||||
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
|
||||
GgmlDType::Q4_0 => {
|
||||
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
|
||||
}
|
||||
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
|
||||
}
|
||||
}
|
||||
@ -163,6 +197,7 @@ pub fn qtensor_from_ggml(
|
||||
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
magic: VersionedMagic,
|
||||
device: &Device,
|
||||
) -> Result<(String, super::QTensor)> {
|
||||
let n_dims = reader.read_u32::<LittleEndian>()?;
|
||||
let name_len = reader.read_u32::<LittleEndian>()?;
|
||||
@ -183,11 +218,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
|
||||
}
|
||||
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
|
||||
let tensor_elems = dims.iter().product::<usize>();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
|
||||
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
|
||||
// TODO: Mmap version to avoid copying the data around?
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
|
||||
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
|
||||
Ok(tensor) => Ok((name, tensor)),
|
||||
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
|
||||
}
|
||||
@ -201,7 +236,10 @@ pub struct Content {
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
|
||||
pub fn read<R: std::io::Seek + std::io::Read>(
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Content> {
|
||||
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
|
||||
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
|
||||
reader.seek(std::io::SeekFrom::Start(0))?;
|
||||
@ -211,7 +249,7 @@ impl Content {
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
while reader.stream_position()? != last_position {
|
||||
let (name, tensor) = read_one_tensor(reader, magic)?;
|
||||
let (name, tensor) = read_one_tensor(reader, magic, device)?;
|
||||
tensors.insert(name, tensor);
|
||||
}
|
||||
Ok(Self {
|
||||
|
@ -3,7 +3,7 @@
|
||||
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
|
||||
|
||||
use super::{GgmlDType, QTensor};
|
||||
use crate::Result;
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -59,19 +59,25 @@ impl TensorInfo {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
tensor_data_offset: u64,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_elems = self.shape.elem_count();
|
||||
let blck_size = self.ggml_dtype.blck_size();
|
||||
if tensor_elems % blck_size != 0 {
|
||||
let block_size = self.ggml_dtype.block_size();
|
||||
if tensor_elems % block_size != 0 {
|
||||
crate::bail!(
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
|
||||
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
|
||||
)
|
||||
}
|
||||
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
|
||||
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
|
||||
let mut raw_data = vec![0u8; size_in_bytes];
|
||||
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
|
||||
reader.read_exact(&mut raw_data)?;
|
||||
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
|
||||
super::ggml_file::qtensor_from_ggml(
|
||||
self.ggml_dtype,
|
||||
&raw_data,
|
||||
self.shape.dims().to_vec(),
|
||||
device,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -460,12 +466,13 @@ impl Content {
|
||||
&self,
|
||||
reader: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<QTensor> {
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
Some(tensor_info) => tensor_info,
|
||||
None => crate::bail!("cannot find tensor info for {name}"),
|
||||
};
|
||||
tensor_info.read(reader, self.tensor_data_offset)
|
||||
tensor_info.read(reader, self.tensor_data_offset, device)
|
||||
}
|
||||
}
|
||||
|
||||
@ -517,10 +524,9 @@ pub fn write<W: std::io::Seek + std::io::Write>(
|
||||
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
|
||||
)
|
||||
}
|
||||
let data_ptr = tensor.as_ptr();
|
||||
let size_in_bytes = tensor.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
w.write_all(data)?;
|
||||
let data = tensor.data()?;
|
||||
let size_in_bytes = data.len();
|
||||
w.write_all(&data)?;
|
||||
let padding = 31 - (31 + size_in_bytes) % 32;
|
||||
w.write_all(&vec![0u8; padding])?;
|
||||
}
|
||||
|
@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
||||
ys_index += 1;
|
||||
}
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
|
153
candle-core/src/quantized/metal.rs
Normal file
153
candle-core/src/quantized/metal.rs
Normal file
@ -0,0 +1,153 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::{DType, MetalDevice, MetalStorage, Result};
|
||||
use metal::Buffer;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct QMetalStorage {
|
||||
dtype: GgmlDType,
|
||||
device: MetalDevice,
|
||||
buffer: Arc<Buffer>,
|
||||
}
|
||||
|
||||
impl QMetalStorage {
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
|
||||
Self {
|
||||
device,
|
||||
buffer,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
f32::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::F16 => {
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
half::f16::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
let vec: Vec<crate::quantized::BlockQ2K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
let vec: Vec<crate::quantized::BlockQ3K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let vec: Vec<crate::quantized::BlockQ4K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
let vec: Vec<crate::quantized::BlockQ5K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let vec: Vec<crate::quantized::BlockQ6K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8K => {
|
||||
let vec: Vec<crate::quantized::BlockQ8K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||
// Quantization only happens on CPU for now.
|
||||
let src = src.to_cpu::<f32>()?;
|
||||
let elem_count = src.len();
|
||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
|
||||
qcpu_storage.quantize(&src)?;
|
||||
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
|
||||
self.buffer = buffer;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &MetalDevice,
|
||||
data: &[T],
|
||||
) -> Result<QStorage> {
|
||||
let buffer = device.new_buffer_with_data(data)?;
|
||||
let device = device.clone();
|
||||
Ok(QStorage::Metal(QMetalStorage {
|
||||
dtype: T::DTYPE,
|
||||
device,
|
||||
buffer,
|
||||
}))
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
|
||||
slice.to_vec()
|
||||
}
|
@ -1,23 +1,125 @@
|
||||
use crate::{Device, Result, Shape, Tensor};
|
||||
#[cfg(feature = "metal")]
|
||||
use crate::{backend::BackendStorage, DType};
|
||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
pub mod k_quants;
|
||||
#[cfg(feature = "metal")]
|
||||
pub mod metal;
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
pub mod simd128;
|
||||
pub mod utils;
|
||||
use half::f16;
|
||||
|
||||
pub use k_quants::GgmlType;
|
||||
|
||||
pub struct QTensor {
|
||||
data: Box<dyn QuantizedType>,
|
||||
storage: QStorage,
|
||||
shape: Shape,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = dtype.cpu_zeros(elem_count);
|
||||
Ok(QStorage::Cpu(storage))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => {
|
||||
let size = elem_count * dtype.type_size() / dtype.block_size();
|
||||
let buffer = metal.allocate_zeros(size)?;
|
||||
Ok(QStorage::Metal(metal::QMetalStorage::new(
|
||||
buffer,
|
||||
metal.clone(),
|
||||
dtype,
|
||||
)))
|
||||
}
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal feature not activated");
|
||||
}
|
||||
Device::Cuda(_cuda) => {
|
||||
crate::bail!("Cuda ggml quantization not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
#[cfg(feature = "metal")]
|
||||
Metal(metal::QMetalStorage),
|
||||
}
|
||||
|
||||
impl QStorage {
|
||||
fn block_size(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => storage.buffer().length() as usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn quantize(&mut self, src: &Storage) -> Result<()> {
|
||||
match (self, src) {
|
||||
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> Result<Cow<[u8]>> {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => {
|
||||
let data_ptr = storage.as_ptr();
|
||||
let size_in_bytes = storage.storage_size_in_bytes();
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
QStorage::Metal(_storage) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum GgmlDType {
|
||||
F32,
|
||||
@ -77,6 +179,25 @@ impl GgmlDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// The block dtype
|
||||
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
|
||||
match self {
|
||||
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
|
||||
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
|
||||
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
|
||||
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
|
||||
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
|
||||
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
|
||||
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
|
||||
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
|
||||
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
|
||||
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
|
||||
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
|
||||
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
|
||||
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
|
||||
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
|
||||
}
|
||||
}
|
||||
/// The type size for blocks in bytes.
|
||||
pub fn type_size(&self) -> usize {
|
||||
use k_quants::*;
|
||||
@ -100,7 +221,7 @@ impl GgmlDType {
|
||||
}
|
||||
|
||||
/// The block size, i.e. the number of elements stored in each block.
|
||||
pub fn blck_size(&self) -> usize {
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
Self::F32 => 1,
|
||||
Self::F16 => 1,
|
||||
@ -119,9 +240,13 @@ impl GgmlDType {
|
||||
pub trait QuantizedType: Send + Sync {
|
||||
fn dtype(&self) -> GgmlDType;
|
||||
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()>;
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
|
||||
fn storage_size_in_bytes(&self) -> usize;
|
||||
fn as_ptr(&self) -> *const u8;
|
||||
fn block_size(&self) -> usize;
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
|
||||
fn size(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
@ -129,12 +254,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
|
||||
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.len() * core::mem::size_of::<T>()
|
||||
}
|
||||
|
||||
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
|
||||
T::from_float(xs, self)
|
||||
}
|
||||
|
||||
fn dtype(&self) -> GgmlDType {
|
||||
T::DTYPE
|
||||
}
|
||||
|
||||
fn to_float(&self, ys: &mut [f32]) -> Result<()> {
|
||||
T::to_float(self.as_slice(), ys)
|
||||
fn block_size(&self) -> usize {
|
||||
T::BLCK_SIZE
|
||||
}
|
||||
|
||||
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
|
||||
let mut ys = vec![0.0f32; elem_count];
|
||||
T::to_float(self.as_slice(), &mut ys)?;
|
||||
Ok(CpuStorage::F32(ys))
|
||||
}
|
||||
|
||||
fn storage_size_in_bytes(&self) -> usize {
|
||||
@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
|
||||
let dims = shape.dims();
|
||||
if dims.is_empty() {
|
||||
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||
}
|
||||
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||
if dims[dims.len() - 1] % block_size != 0 {
|
||||
crate::bail!(
|
||||
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||
T::BLCK_SIZE
|
||||
block_size
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Result<Self> {
|
||||
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape,
|
||||
})
|
||||
check_shape(&shape, storage.block_size())?;
|
||||
Ok(Self { storage, shape })
|
||||
}
|
||||
|
||||
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
|
||||
let shape = src.shape();
|
||||
check_shape::<T>(shape)?;
|
||||
let src = src
|
||||
.to_dtype(crate::DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
if src.len() % T::BLCK_SIZE != 0 {
|
||||
let block_size = dtype.block_size();
|
||||
check_shape(shape, block_size)?;
|
||||
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
|
||||
let elem_count = shape.elem_count();
|
||||
if elem_count % block_size != 0 {
|
||||
crate::bail!(
|
||||
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||
T::BLCK_SIZE
|
||||
block_size
|
||||
)
|
||||
}
|
||||
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(&src, &mut data)?;
|
||||
let mut storage = src.device().qzeros(elem_count, dtype)?;
|
||||
storage.quantize(&src.storage())?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.data.dtype()
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
@ -213,21 +345,19 @@ impl QTensor {
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
let mut f32_data = vec![0f32; self.shape.elem_count()];
|
||||
self.data.to_float(&mut f32_data)?;
|
||||
Tensor::from_vec(f32_data, &self.shape, device)
|
||||
}
|
||||
|
||||
pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
|
||||
self.data.matmul_t(mkn, lhs, dst)
|
||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
let is_variable = false;
|
||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
||||
.to_device(device)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.storage_size_in_bytes()
|
||||
self.storage.size_in_bytes()
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.data.as_ptr()
|
||||
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
|
||||
self.storage.data()
|
||||
}
|
||||
}
|
||||
|
||||
@ -294,17 +424,93 @@ impl crate::CustomOp1 for QTensor {
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let storage = storage.as_slice::<f32>()?;
|
||||
let storage =
|
||||
&storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cpu(storage) => storage,
|
||||
#[cfg(feature = "metal")]
|
||||
_ => crate::bail!("Invalid storage"),
|
||||
};
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
self.matmul_t(
|
||||
(dst_shape.elem_count() / n, k, n),
|
||||
storage,
|
||||
&mut dst_storage,
|
||||
)?;
|
||||
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
use crate::MetalError;
|
||||
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let (n, k) = self.shape.dims2()?;
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (dst_shape[0], dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let (buffer, dtype) = match &self.storage {
|
||||
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
|
||||
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
|
||||
};
|
||||
let command_buffer = device.command_buffer()?;
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
|
||||
fn from(value: GgmlDType) -> Self {
|
||||
match value {
|
||||
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
|
||||
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
|
||||
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
|
||||
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
|
||||
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
|
||||
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
|
||||
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
|
||||
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
|
||||
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
|
||||
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
|
||||
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
|
||||
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
|
||||
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
|
||||
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for QMatMul {
|
||||
|
@ -12,6 +12,14 @@ use core::arch::arm::*;
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
|
||||
// TODO: dotprod
|
||||
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
|
||||
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
||||
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
// TODO: Support dotprod when it's available outside of nightly.
|
||||
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
||||
let pl0 = vdotq_s32(v0_0ls, v1_0l);
|
||||
let ph0 = vdotq_s32(v0_0hs, v1_0h);
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||
@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
// TODO dotprod once this is the intrinsics are.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
let p0 = vdotq_s32(x0_0, y0_0);
|
||||
let p1 = vdotq_s32(x0_1, y0_1);
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
let xy = vdotq_s32(xs, ys);
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||
|
||||
// TODO: dotprod case.
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
|
||||
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
|
||||
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
}
|
||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||
@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4bits = vld1q_u8_x2(q4);
|
||||
q4 = q4.add(32);
|
||||
// TODO: dotprod
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||
);
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2) as f32;
|
||||
}
|
||||
@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||
@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
if j == 0 {
|
||||
@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
|
||||
let mut is = 0usize;
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let q2bits = vld1q_u8_x2(q2);
|
||||
q2 = q2.add(32);
|
||||
@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
|
||||
q2bytes: int8x16x2_t,
|
||||
q8bytes: int8x16x2_t,
|
||||
) -> i32 {
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
|
||||
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
|
||||
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
|
||||
}
|
||||
|
@ -2578,11 +2578,21 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Returns log(sum(exp(tensor), dim)).
|
||||
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||
let exp = self.exp()?;
|
||||
let sum = exp.sum(sum_dims)?;
|
||||
sum.log()
|
||||
}
|
||||
|
||||
/// Pointwise pow operation.
|
||||
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.mul(&self.log()?)?.exp()
|
||||
}
|
||||
|
||||
/// Broadcasting version of `pow`.
|
||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
@ -1,5 +1,7 @@
|
||||
use candle_core::{
|
||||
bail,
|
||||
quantized::{self, GgmlDType},
|
||||
test_device,
|
||||
test_utils::to_vec2_round,
|
||||
Device, Module, Result, Tensor,
|
||||
};
|
||||
@ -13,16 +15,48 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
|
||||
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
|
||||
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
fn test_matmul(
|
||||
device: &Device,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
dtype: GgmlDType,
|
||||
) -> Result<()> {
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 / (m * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let rhs = (0..(k * n))
|
||||
.map(|v| v as f32 / (n * k) as f32)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
|
||||
let mm = lhs.matmul(&rhs)?;
|
||||
let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&lhs)?;
|
||||
|
||||
let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
|
||||
.sum_all()?
|
||||
.to_scalar()?;
|
||||
let error = error / (b * m * n) as f32;
|
||||
assert!(
|
||||
error <= 0.02,
|
||||
"Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
@ -32,6 +66,7 @@ fn quantized_matmul() -> Result<()> {
|
||||
341876.0, 994283.0, 1655709.0, 2301518.0
|
||||
]
|
||||
);
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||
assert_eq!(
|
||||
mm.to_vec2::<f32>()?,
|
||||
@ -42,35 +77,49 @@ fn quantized_matmul() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
]
|
||||
);
|
||||
match device {
|
||||
Device::Metal(_) => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[84946.0, 214126.0, 344757.0, 473798.0],
|
||||
[213458.0, 604350.0, 1000469.0, 1387990.0],
|
||||
[341970.0, 994574.0, 1656181.0, 2302182.0]
|
||||
]
|
||||
),
|
||||
_ => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[85120.0, 214562.0, 345455.0, 474748.0],
|
||||
[213475.0, 604465.0, 1000686.0, 1388317.0],
|
||||
[341876.0, 994283.0, 1655709.0, 2301518.0]
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_neg() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
let mut dst = vec![42.; 3 * 4];
|
||||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||
let rhs = (0..k * n)
|
||||
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
||||
.collect::<Vec<_>>();
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
|
||||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||
assert_eq!(
|
||||
@ -90,32 +139,56 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
);
|
||||
match device {
|
||||
Device::Metal(_) => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243666.0, -19714.0, -285433.0, -550453.0],
|
||||
[23782.0, 21654.0, 19400.0, 18369.0],
|
||||
[-196102.0, 63022.0, 324233.0, 587191.0]
|
||||
]
|
||||
),
|
||||
_ => assert_eq!(
|
||||
to_vec2_round(&res, 0)?,
|
||||
&[
|
||||
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4_0() -> Result<()> {
|
||||
use k_quants::BlockQ4_0;
|
||||
test_device!(
|
||||
quantized_matmul,
|
||||
quantized_matmul_cpu,
|
||||
quantized_matmul_cuda,
|
||||
quantized_matmul_metal
|
||||
);
|
||||
test_device!(
|
||||
quantized_matmul_neg,
|
||||
quantized_matmul_neg_cpu,
|
||||
quantized_matmul_neg_cuda,
|
||||
quantized_matmul_neg_metal
|
||||
);
|
||||
|
||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_0::zeros(); 4];
|
||||
BlockQ4_0::from_float(&src, &mut quant)?;
|
||||
BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
dst.to_vec1::<f32>()?,
|
||||
&[
|
||||
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
|
||||
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
|
||||
@ -131,21 +204,21 @@ fn quantize_q4_0() -> Result<()> {
|
||||
127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4_1() -> Result<()> {
|
||||
use k_quants::BlockQ4_1;
|
||||
|
||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ4_1::zeros(); 4];
|
||||
BlockQ4_1::from_float(&src, &mut quant)?;
|
||||
BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
round_vector(&dst),
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
|
||||
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
|
||||
@ -161,21 +234,21 @@ fn quantize_q4_1() -> Result<()> {
|
||||
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5_0() -> Result<()> {
|
||||
use k_quants::BlockQ5_0;
|
||||
|
||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_0::zeros(); 4];
|
||||
BlockQ5_0::from_float(&src, &mut quant)?;
|
||||
BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
round_vector(&dst),
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
|
||||
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
|
||||
@ -191,21 +264,21 @@ fn quantize_q5_0() -> Result<()> {
|
||||
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
|
||||
]
|
||||
);
|
||||
ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5_1() -> Result<()> {
|
||||
use k_quants::BlockQ5_1;
|
||||
|
||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let mut dst = vec![0f32; 32 * 4];
|
||||
let mut quant = vec![BlockQ5_1::zeros(); 4];
|
||||
BlockQ5_1::from_float(&src, &mut quant)?;
|
||||
BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
assert_eq!(
|
||||
dst,
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
||||
@ -219,13 +292,11 @@ fn quantize_q5_1() -> Result<()> {
|
||||
124.0, 125.0, 126.0, 127.0
|
||||
]
|
||||
);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
|
||||
fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
|
||||
assert!(
|
||||
size % crate::quantized::k_quants::QK_K == 0,
|
||||
"size must be a multiple of {}",
|
||||
@ -235,10 +306,8 @@ fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
|
||||
let src = (0..size)
|
||||
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let dst = vec![0f32; size];
|
||||
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
|
||||
(src, dst)
|
||||
Tensor::from_vec(src, (size,), device)
|
||||
}
|
||||
|
||||
/// Round a vector
|
||||
@ -265,7 +334,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
/// Creates a vector similar to the ones used in GGML unit tests:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
(0..GGML_TEST_SIZE)
|
||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||
@ -284,14 +354,16 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
sum / a.len() as f32
|
||||
}
|
||||
|
||||
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
/// Similar to the GGML quantization unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||
if error > max_error {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
error,
|
||||
max_error
|
||||
@ -300,19 +372,19 @@ fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
|
||||
let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(src, &mut quant)?;
|
||||
T::to_float(&quant, dst)?;
|
||||
Ok(quant)
|
||||
}
|
||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q2K;
|
||||
|
||||
#[test]
|
||||
fn quantize_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
|
||||
|
||||
// Test some specific values
|
||||
@ -326,20 +398,30 @@ fn quantize_q2k() -> Result<()> {
|
||||
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q3k() -> Result<()> {
|
||||
use k_quants::BlockQ3K;
|
||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q3K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
|
||||
|
||||
// Test some specific values
|
||||
@ -353,20 +435,30 @@ fn quantize_q3k() -> Result<()> {
|
||||
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q4k() -> Result<()> {
|
||||
use k_quants::BlockQ4K;
|
||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q4K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
|
||||
|
||||
// Test some specific values
|
||||
@ -380,21 +472,31 @@ fn quantize_q4k() -> Result<()> {
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q5k() -> Result<()> {
|
||||
use k_quants::BlockQ5K;
|
||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q5K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
@ -404,24 +506,33 @@ fn quantize_q5k() -> Result<()> {
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q6k() -> Result<()> {
|
||||
use k_quants::BlockQ6K;
|
||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q6K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
@ -435,22 +546,31 @@ fn quantize_q6k() -> Result<()> {
|
||||
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantize_q8k() -> Result<()> {
|
||||
use k_quants::BlockQ8K;
|
||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q8K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
|
||||
let (src, mut dst) = get_test_vector(0.5, 1024);
|
||||
let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
|
||||
|
||||
// Test some specific values
|
||||
assert_eq!(
|
||||
@ -463,15 +583,79 @@ fn quantize_q8k() -> Result<()> {
|
||||
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
|
||||
|
||||
ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
|
||||
ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(
|
||||
quantize_q4_0,
|
||||
quantize_q4_0_cpu,
|
||||
quantize_q4_0_cuda,
|
||||
quantize_q4_0_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q4_1,
|
||||
quantize_q4_1_cpu,
|
||||
quantize_q4_1_cuda,
|
||||
quantize_q4_1_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5_0,
|
||||
quantize_q5_0_cpu,
|
||||
quantize_q5_0_cuda,
|
||||
quantize_q5_0_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5_1,
|
||||
quantize_q5_1_cpu,
|
||||
quantize_q5_1_cuda,
|
||||
quantize_q5_1_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q2k,
|
||||
quantize_q2k_cpu,
|
||||
quantize_q2k_cuda,
|
||||
quantize_q2k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q3k,
|
||||
quantize_q3k_cpu,
|
||||
quantize_q3k_cuda,
|
||||
quantize_q3k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q4k,
|
||||
quantize_q4k_cpu,
|
||||
quantize_q4k_cuda,
|
||||
quantize_q4k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q5k,
|
||||
quantize_q5k_cpu,
|
||||
quantize_q5k_cuda,
|
||||
quantize_q5k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q6k,
|
||||
quantize_q6k_cpu,
|
||||
quantize_q6k_cuda,
|
||||
quantize_q6k_metal
|
||||
);
|
||||
test_device!(
|
||||
quantize_q8k,
|
||||
quantize_q8k_cpu,
|
||||
quantize_q8k_cuda,
|
||||
quantize_q8k_metal
|
||||
);
|
||||
|
||||
/// Very simple dot product implementation
|
||||
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||
@ -487,54 +671,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
GgmlDType::Q5K => 0.000740,
|
||||
GgmlDType::Q6K => 0.000952,
|
||||
GgmlDType::Q4_0 => 0.001143,
|
||||
GgmlDType::Q4_1 => 0.007784,
|
||||
GgmlDType::Q4_1 => 0.008,
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q5_1 => 0.00149,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
|
||||
// Not from the ggml repo.
|
||||
GgmlDType::Q8K => 0.00065,
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
_ => bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
/// Similar to the GGML matmul unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
let a = create_ggml_like_vector(0.0);
|
||||
let b = create_ggml_like_vector(1.0);
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
|
||||
// Another example that is more likely to trigger the overflow reported in #1526
|
||||
let a = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let b = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
|
||||
let length = a.len();
|
||||
|
||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||
T::from_float(&a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
T::from_float(a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
let reference_result = vec_dot_reference(a, b);
|
||||
|
||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||
)
|
||||
}
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
|
||||
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
candle_core::bail!(
|
||||
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||
);
|
||||
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
|
||||
}
|
||||
|
||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||
// => we use a slightly higher error threshold
|
||||
const ERROR_LENIENCY: f32 = 0.00001;
|
||||
if error - ERROR_LENIENCY > ggml_error {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"Dot product error {} exceeds ggml reference error {}",
|
||||
error,
|
||||
ggml_error
|
||||
@ -543,6 +739,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_mm() -> Result<()> {
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||
fn get_random_tensors(
|
||||
m: usize,
|
||||
@ -566,6 +772,112 @@ fn get_random_tensors(
|
||||
Ok((lhs, rhs, mm))
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! quantized_matmul {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||
fn $fn_name(device: &Device) -> Result<()> {
|
||||
if device.is_cuda() {
|
||||
// TODO Enable Cuda GGML sometime maybe.
|
||||
return Ok(());
|
||||
}
|
||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
|
||||
};
|
||||
}
|
||||
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4_0_bis,
|
||||
quantized_matmul_q4_0_cpu,
|
||||
quantized_matmul_q4_0_cuda,
|
||||
quantized_matmul_q4_0_metal,
|
||||
GgmlDType::Q4_0
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4_1_bis,
|
||||
quantized_matmul_q4_1_cpu,
|
||||
quantized_matmul_q4_1_cuda,
|
||||
quantized_matmul_q4_1_metal,
|
||||
GgmlDType::Q4_1
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5_0_bis,
|
||||
quantized_matmul_q5_0_cpu,
|
||||
quantized_matmul_q5_0_cuda,
|
||||
quantized_matmul_q5_0_metal,
|
||||
GgmlDType::Q5_0
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5_1_bis,
|
||||
quantized_matmul_q5_1_cpu,
|
||||
quantized_matmul_q5_1_cuda,
|
||||
quantized_matmul_q5_1_metal,
|
||||
GgmlDType::Q5_1
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q8_0_bis,
|
||||
quantized_matmul_q8_0_cpu,
|
||||
quantized_matmul_q8_0_cuda,
|
||||
quantized_matmul_q8_0_metal,
|
||||
GgmlDType::Q8_0
|
||||
);
|
||||
// Not implemented in Ggml
|
||||
// quantized_matmul!(
|
||||
// quantized_matmul_q8_1_bis,
|
||||
// quantized_matmul_q8_1_cpu,
|
||||
// quantized_matmul_q8_1_cuda,
|
||||
// quantized_matmul_q8_1_metal,
|
||||
// GgmlDType::Q8_1
|
||||
// );
|
||||
// TODO This is bugged (also bugged in GGML
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q2k_bis,
|
||||
quantized_matmul_q2k_cpu,
|
||||
quantized_matmul_q2k_cuda,
|
||||
quantized_matmul_q2k_metal,
|
||||
GgmlDType::Q2K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q3k_bis,
|
||||
quantized_matmul_q3k_cpu,
|
||||
quantized_matmul_q3k_cuda,
|
||||
quantized_matmul_q3k_metal,
|
||||
GgmlDType::Q3K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q4k_bis,
|
||||
quantized_matmul_q4k_cpu,
|
||||
quantized_matmul_q4k_cuda,
|
||||
quantized_matmul_q4k_metal,
|
||||
GgmlDType::Q4K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q5k_bis,
|
||||
quantized_matmul_q5k_cpu,
|
||||
quantized_matmul_q5k_cuda,
|
||||
quantized_matmul_q5k_metal,
|
||||
GgmlDType::Q5K
|
||||
);
|
||||
quantized_matmul!(
|
||||
quantized_matmul_q6k_bis,
|
||||
quantized_matmul_q6k_cpu,
|
||||
quantized_matmul_q6k_cuda,
|
||||
quantized_matmul_q6k_metal,
|
||||
GgmlDType::Q6K
|
||||
);
|
||||
// Not implemented on metal
|
||||
// quantized_matmul!(
|
||||
// quantized_matmul_q8k_bis,
|
||||
// quantized_matmul_q8k_cpu,
|
||||
// quantized_matmul_q8k_cuda,
|
||||
// quantized_matmul_q8k_metal,
|
||||
// GgmlDType::Q8K
|
||||
// );
|
||||
|
||||
#[test]
|
||||
fn quantized_matmul_q2k() -> Result<()> {
|
||||
use k_quants::BlockQ2K;
|
||||
@ -578,7 +890,7 @@ fn quantized_matmul_q2k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -604,7 +916,7 @@ fn quantized_matmul_q3k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -630,7 +942,7 @@ fn quantized_matmul_q4k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -656,7 +968,7 @@ fn quantized_matmul_q5k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -683,7 +995,7 @@ fn quantized_matmul_q6k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
@ -708,7 +1020,7 @@ fn quantized_matmul_q8k() -> Result<()> {
|
||||
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
|
||||
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
|
||||
|
||||
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
|
||||
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
|
||||
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
|
||||
let mm = rhs.forward(&lhs)?;
|
||||
|
||||
|
@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logsumexp() -> Result<()> {
|
||||
fn log_sum_exp() -> Result<()> {
|
||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let output = input.logsumexp(D::Minus1)?;
|
||||
let output = input.log_sum_exp(D::Minus1)?;
|
||||
// The expectations obtained from pytorch.
|
||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||
assert_close(&output, &expected, 0.00001)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow() -> Result<()> {
|
||||
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let rhs = (&lhs - 2.)?;
|
||||
let res = lhs.pow(&rhs)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res, 4)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -11,12 +11,12 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
|
||||
csv = "1.3.0"
|
||||
cudarc = { workspace = true, optional = true }
|
||||
@ -49,11 +49,12 @@ tokio = "1.29.1"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
|
@ -4,251 +4,28 @@ use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct KernelDirectories {
|
||||
kernel_dir: &'static str,
|
||||
kernel_glob: &'static str,
|
||||
rust_target: &'static str,
|
||||
include_dirs: &'static [&'static str],
|
||||
}
|
||||
|
||||
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_dir: "examples/custom-ops/kernels/",
|
||||
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_glob: "examples/custom-ops/kernels/*.cu",
|
||||
rust_target: "examples/custom-ops/cuda_kernels.rs",
|
||||
include_dirs: &[],
|
||||
}];
|
||||
|
||||
impl KernelDirectories {
|
||||
fn maybe_build_ptx(
|
||||
&self,
|
||||
cu_file: &std::path::Path,
|
||||
ptx_file: &std::path::Path,
|
||||
compute_cap: usize,
|
||||
) -> Result<()> {
|
||||
let should_compile = if ptx_file.exists() {
|
||||
let ptx_modified = ptx_file.metadata()?.modified()?;
|
||||
let cu_modified = cu_file.metadata()?.modified()?;
|
||||
cu_modified.duration_since(ptx_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
|
||||
let include_dirs: Vec<String> =
|
||||
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
|
||||
command
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", out_dir.to_str().unwrap()])
|
||||
.arg(format!("-I/{}", self.kernel_dir))
|
||||
.args(include_dirs)
|
||||
.arg(cu_file);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open(ptx_file)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
|
||||
println!("cargo:rerun-if-changed={}", self.kernel_dir);
|
||||
let kernel_dir = PathBuf::from(self.kernel_dir);
|
||||
let out_dir = out_dir.join(self.kernel_dir);
|
||||
if !out_dir.exists() {
|
||||
std::fs::create_dir_all(&out_dir)?;
|
||||
}
|
||||
let mut cu_files = vec![];
|
||||
let mut cuh_files = vec![];
|
||||
for file in std::fs::read_dir(kernel_dir)?.flatten() {
|
||||
let file = file.path();
|
||||
match file.extension().and_then(|v| v.to_str()) {
|
||||
Some("cu") => cu_files.push(file),
|
||||
Some("cuh") => cuh_files.push(file),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut ptx_paths = vec![];
|
||||
for cu_file in cu_files.iter() {
|
||||
let file_stem = cu_file
|
||||
.file_stem()
|
||||
.with_context(|| format!("no stem {cu_file:?}"))?;
|
||||
let file_stem = file_stem.to_string_lossy().into_owned();
|
||||
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
|
||||
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
|
||||
ptx_paths.push(ptx_file);
|
||||
}
|
||||
|
||||
let regenerate_rs_file = true;
|
||||
if regenerate_rs_file {
|
||||
let mut file = std::fs::File::create(self.rust_target)?;
|
||||
for ptx_path in ptx_paths {
|
||||
let name = ptx_path
|
||||
.file_stem()
|
||||
.context("empty stem")?
|
||||
.to_string_lossy();
|
||||
file.write_all(b"#[rustfmt::skip]\n")?;
|
||||
let const_definition = format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
self.kernel_dir,
|
||||
);
|
||||
file.write_all(const_definition.as_bytes())?;
|
||||
file.write_all(b"\n")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
|
||||
let out_dir = PathBuf::from(out_dir);
|
||||
#[cfg(feature = "cuda")]
|
||||
set_cuda_include_dir()?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let compute_cap = compute_cap()?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let compute_cap = 0;
|
||||
for d in DIRS {
|
||||
d.process(&out_dir, compute_cap)?
|
||||
{
|
||||
for kdir in KERNEL_DIRS.iter() {
|
||||
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write(kdir.rust_target).unwrap()
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute cap from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse code")?
|
||||
} else {
|
||||
// Grab compute cap from nvidia-smi
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let max_nvcc_code = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
if !codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
|
||||
);
|
||||
}
|
||||
*codes.last().unwrap()
|
||||
};
|
||||
|
||||
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||
// then choose the highest gpu code in nvcc
|
||||
if compute_cap > max_nvcc_code {
|
||||
println!(
|
||||
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
|
||||
);
|
||||
compute_cap = max_nvcc_code;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
compute_cap = compute_cap_str
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
|
||||
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||
}
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let config = blip::Config::image_captioning_large();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (image_embeds, device, mut model) = if args.quantized {
|
||||
let device = Device::Cpu;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
|
||||
let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
|
||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
|
||||
(image_embeds, device, Model::Q(model))
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
|
@ -1,2 +0,0 @@
|
||||
#[rustfmt::skip]
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
|
||||
|
@ -6,7 +6,8 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[allow(unused)]
|
||||
#[rustfmt::skip]
|
||||
#[cfg(feature = "cuda")]
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (model, config) = if is_gguf {
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||
let (_vocab_size, dim) = vb
|
||||
.get_no_shape("model.embed_tokens.weight")?
|
||||
.shape()
|
||||
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_real",
|
||||
)?
|
||||
.dequantize(&candle::Device::Cpu)?;
|
||||
.dequantize(&device)?;
|
||||
let freq_cis_imag = vb
|
||||
.get(
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_imag",
|
||||
)?
|
||||
.dequantize(&candle::Device::Cpu)?;
|
||||
.dequantize(&device)?;
|
||||
|
||||
let fake_vb = candle_nn::VarBuilder::from_tensors(
|
||||
[
|
||||
@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
.into_iter()
|
||||
.collect(),
|
||||
candle::DType::F32,
|
||||
&candle::Device::Cpu,
|
||||
&device,
|
||||
);
|
||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||
|
@ -244,13 +244,14 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::config_7b_v0_1(args.use_flash_attn);
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
let model = QMistral::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
(Model::Quantized(model), device)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
|
22
candle-examples/examples/mobileone/README.md
Normal file
22
candle-examples/examples/mobileone/README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# candle-mobileone
|
||||
|
||||
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
|
||||
|
||||
This candle implementation uses a pre-trained MobileOne network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 79.33%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
|
||||
crash helmet : 2.58%
|
||||
unicycle, monocycle : 1.70%
|
||||
alp : 0.21%
|
||||
|
||||
```
|
96
candle-examples/examples/mobileone/main.rs
Normal file
96
candle-examples/examples/mobileone/main.rs
Normal file
@ -0,0 +1,96 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::mobileone;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
S0,
|
||||
S1,
|
||||
S2,
|
||||
S3,
|
||||
S4,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::S0 => "s0",
|
||||
Self::S1 => "s1",
|
||||
Self::S2 => "s2",
|
||||
Self::S3 => "s3",
|
||||
Self::S4 => "s4",
|
||||
};
|
||||
format!("timm/mobileone_{}.apple_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> mobileone::Config {
|
||||
match self {
|
||||
Self::S0 => mobileone::Config::s0(),
|
||||
Self::S1 => mobileone::Config::s1(),
|
||||
Self::S2 => mobileone::Config::s2(),
|
||||
Self::S3 => mobileone::Config::s3(),
|
||||
Self::S4 => mobileone::Config::s4(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::S0)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
MixFormer(MixFormer),
|
||||
Phi(Phi),
|
||||
Quantized(QMixFormer),
|
||||
}
|
||||
|
||||
@ -84,6 +86,7 @@ impl TextGeneration {
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::MixFormer(m) => m.forward(&input)?,
|
||||
Model::Phi(m) => m.forward(&input)?,
|
||||
Model::Quantized(m) => m.forward(&input)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
@ -117,7 +120,7 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||
enum WhichModel {
|
||||
#[value(name = "1")]
|
||||
V1,
|
||||
@ -125,6 +128,8 @@ enum WhichModel {
|
||||
V1_5,
|
||||
#[value(name = "2")]
|
||||
V2,
|
||||
#[value(name = "2-old")]
|
||||
V2Old,
|
||||
PuffinPhiV2,
|
||||
PhiHermes,
|
||||
}
|
||||
@ -169,7 +174,7 @@ struct Args {
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "1.5")]
|
||||
#[arg(long, default_value = "2")]
|
||||
model: WhichModel,
|
||||
|
||||
#[arg(long)]
|
||||
@ -230,7 +235,7 @@ fn main() -> Result<()> {
|
||||
match args.model {
|
||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
WhichModel::V2 => "microsoft/phi-2".to_string(),
|
||||
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
}
|
||||
@ -245,8 +250,9 @@ fn main() -> Result<()> {
|
||||
"main".to_string()
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||
WhichModel::V1 => "refs/pr/8".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
||||
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"main".to_string()
|
||||
}
|
||||
@ -258,7 +264,9 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
|
||||
repo.get("tokenizer.json")?
|
||||
}
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||
}
|
||||
@ -271,14 +279,14 @@ fn main() -> Result<()> {
|
||||
match args.model {
|
||||
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
||||
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
||||
WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::V2 => candle_examples::hub_load_safetensors(
|
||||
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
)?,
|
||||
@ -292,28 +300,44 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = match args.model {
|
||||
let config = || match args.model {
|
||||
WhichModel::V1 => Config::v1(),
|
||||
WhichModel::V1_5 => Config::v1_5(),
|
||||
WhichModel::V2 => Config::v2(),
|
||||
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||
};
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = if args.quantized {
|
||||
let config = config();
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&filenames[0],
|
||||
&device,
|
||||
)?;
|
||||
let model = match args.model {
|
||||
WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
|
||||
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
||||
_ => QMixFormer::new(&config, vb)?,
|
||||
};
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = match args.model {
|
||||
WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
|
||||
_ => MixFormer::new(&config, vb)?,
|
||||
};
|
||||
(Model::MixFormer(model), device)
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: PhiConfig = serde_json::from_str(&config)?;
|
||||
let phi = Phi::new(&config, vb)?;
|
||||
Model::Phi(phi)
|
||||
}
|
||||
WhichModel::V2Old => {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
||||
}
|
||||
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new(&config, vb)?)
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
@ -393,6 +417,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Phi(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Quantized(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
|
@ -132,7 +132,8 @@ impl T5ModelBuilder {
|
||||
}
|
||||
|
||||
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
|
||||
let device = Device::Cpu;
|
||||
let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{Device, Tensor};
|
||||
use candle::Tensor;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(false)?;
|
||||
|
||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||
Some("gguf") => {
|
||||
@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
ModelWeights::from_gguf(model, &mut file)?
|
||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||
}
|
||||
Some("ggml" | "bin") | Some(_) | None => {
|
||||
let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
||||
let model = ggml_file::Content::read(&mut file, &device)
|
||||
.map_err(|e| e.with_path(model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensors.iter() {
|
||||
let elem_count = tensor.shape().elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
||||
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
|
@ -236,16 +236,15 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config = Config::replit_code_v1_5_3b();
|
||||
let (model, device) = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
|
||||
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
|
||||
(model, Device::Cpu)
|
||||
let model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
|
||||
Model::Q(Q::new(&config, vb.pp("transformer"))?)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
|
||||
(model, device)
|
||||
Model::M(M::new(&config, vb.pp("transformer"))?)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
22
candle-examples/examples/repvgg/README.md
Normal file
22
candle-examples/examples/repvgg/README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# candle-repvgg
|
||||
|
||||
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
|
||||
|
||||
This candle implementation uses a pre-trained RepVGG network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 61.70%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
|
||||
unicycle, monocycle : 4.88%
|
||||
crash helmet : 0.15%
|
||||
moped : 0.04%
|
||||
|
||||
```
|
111
candle-examples/examples/repvgg/main.rs
Normal file
111
candle-examples/examples/repvgg/main.rs
Normal file
@ -0,0 +1,111 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::repvgg;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
A0,
|
||||
A1,
|
||||
A2,
|
||||
B0,
|
||||
B1,
|
||||
B2,
|
||||
B3,
|
||||
B1G4,
|
||||
B2G4,
|
||||
B3G4,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::A0 => "a0",
|
||||
Self::A1 => "a1",
|
||||
Self::A2 => "a2",
|
||||
Self::B0 => "b0",
|
||||
Self::B1 => "b1",
|
||||
Self::B2 => "b2",
|
||||
Self::B3 => "b3",
|
||||
Self::B1G4 => "b1g4",
|
||||
Self::B2G4 => "b2g4",
|
||||
Self::B3G4 => "b3g4",
|
||||
};
|
||||
format!("timm/repvgg_{}.rvgg_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> repvgg::Config {
|
||||
match self {
|
||||
Self::A0 => repvgg::Config::a0(),
|
||||
Self::A1 => repvgg::Config::a1(),
|
||||
Self::A2 => repvgg::Config::a2(),
|
||||
Self::B0 => repvgg::Config::b0(),
|
||||
Self::B1 => repvgg::Config::b1(),
|
||||
Self::B2 => repvgg::Config::b2(),
|
||||
Self::B3 => repvgg::Config::b3(),
|
||||
Self::B1G4 => repvgg::Config::b1g4(),
|
||||
Self::B2G4 => repvgg::Config::b2g4(),
|
||||
Self::B3G4 => repvgg::Config::b3g4(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::A0)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -234,13 +234,14 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
let model = QStableLM::new(&config, vb)?;
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
|
@ -557,8 +557,10 @@ fn main() -> Result<()> {
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let mut model = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&weights_filename,
|
||||
&device,
|
||||
)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
let vb =
|
||||
|
@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
bindgen_cuda = "0.1.1"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
num_cpus = "1.15.0"
|
||||
rayon = "1.7.0"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
|
||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||
|
@ -2,44 +2,32 @@
|
||||
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||
use anyhow::{Context, Result};
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const KERNEL_FILES: [&str; 17] = [
|
||||
"flash_api.cu",
|
||||
"flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"flash_fwd_hdim96_bf16_sm80.cu",
|
||||
"kernels/flash_api.cu",
|
||||
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
|
||||
];
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|
||||
|_| num_cpus::get_physical(),
|
||||
|s| usize::from_str(&s).unwrap(),
|
||||
);
|
||||
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(num_cpus)
|
||||
.build_global()
|
||||
.unwrap();
|
||||
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
for kernel_file in KERNEL_FILES.iter() {
|
||||
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
|
||||
println!("cargo:rerun-if-changed={kernel_file}");
|
||||
}
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
||||
@ -66,223 +54,30 @@ fn main() -> Result<()> {
|
||||
))
|
||||
}
|
||||
};
|
||||
set_cuda_include_dir()?;
|
||||
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
|
||||
let compute_cap = compute_cap()?;
|
||||
let kernels = KERNEL_FILES.iter().collect();
|
||||
let builder = bindgen_cuda::Builder::default()
|
||||
.kernel_paths(kernels)
|
||||
.out_dir(build_dir.clone())
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
|
||||
let out_file = build_dir.join("libflashattention.a");
|
||||
builder.build_lib(out_file);
|
||||
|
||||
let kernel_dir = PathBuf::from("kernels");
|
||||
let cu_files: Vec<_> = KERNEL_FILES
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let mut obj_file = out_dir.join(f);
|
||||
obj_file.set_extension("o");
|
||||
(kernel_dir.join(f), obj_file)
|
||||
})
|
||||
.collect();
|
||||
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
|
||||
let should_compile = if out_file.exists() {
|
||||
kernel_dir
|
||||
.read_dir()
|
||||
.expect("kernels folder should exist")
|
||||
.any(|entry| {
|
||||
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
|
||||
let in_modified = entry.metadata().unwrap().modified().unwrap();
|
||||
in_modified.duration_since(*out_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
cu_files
|
||||
.par_iter()
|
||||
.map(|(cu_file, obj_file)| {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("-c")
|
||||
.args(["-o", obj_file.to_str().unwrap()])
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(cu_file);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
&command,
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<()>>()?;
|
||||
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("--lib")
|
||||
.args(["-o", out_file.to_str().unwrap()])
|
||||
.args(obj_files);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
&command,
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||
println!("cargo:rustc-link-lib=flashattention");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||
|
||||
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
|
||||
finishing to run for some reason. Calling nvcc manually worked fine.
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
.include("cutlass/include")
|
||||
.flag("--expt-relaxed-constexpr")
|
||||
.flag("--default-stream")
|
||||
.flag("per-thread")
|
||||
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
|
||||
.compile("flashattn");
|
||||
*/
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse compute cap")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
glob = "0.3.1"
|
||||
rayon = "1.7.0"
|
||||
bindgen_cuda = "0.1.1"
|
||||
|
@ -1,243 +1,8 @@
|
||||
use std::io::Write;
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
cuda::set_include_dir();
|
||||
let (write, kernel_paths) = cuda::build_ptx();
|
||||
if write {
|
||||
let mut file = std::fs::File::create("src/lib.rs").unwrap();
|
||||
for kernel_path in kernel_paths {
|
||||
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
|
||||
file.write_all(
|
||||
format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
name
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.unwrap();
|
||||
file.write_all(&[b'\n']).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod cuda {
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
pub fn set_include_dir() {
|
||||
use std::path::PathBuf;
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
// We can't actually set a env!() value from another crate,
|
||||
// so we have to do that here.
|
||||
|
||||
// use PathBuf;
|
||||
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
|
||||
#[cfg(feature = "ci-check")]
|
||||
let root: PathBuf = "ci".into();
|
||||
|
||||
#[cfg(not(feature = "ci-check"))]
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
}
|
||||
|
||||
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
||||
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
|
||||
println!("cargo:rerun-if-changed=src/");
|
||||
// for path in &kernel_paths {
|
||||
// println!("cargo:rerun-if-changed={}", path.display());
|
||||
// }
|
||||
|
||||
for path in &mut include_directories {
|
||||
// println!("cargo:rerun-if-changed={}", path.display());
|
||||
let destination =
|
||||
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
|
||||
std::fs::copy(path.clone(), destination).unwrap();
|
||||
// remove the filename from the path so it's just the directory
|
||||
path.pop();
|
||||
}
|
||||
|
||||
include_directories.sort();
|
||||
include_directories.dedup();
|
||||
|
||||
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
|
||||
|
||||
#[allow(unused)]
|
||||
let include_options: Vec<String> = include_directories
|
||||
.into_iter()
|
||||
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let children = kernel_paths
|
||||
.par_iter()
|
||||
.flat_map(|p| {
|
||||
let mut output = p.clone();
|
||||
output.set_extension("ptx");
|
||||
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
|
||||
|
||||
let ignore = if output_filename.exists() {
|
||||
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
|
||||
let in_modified = p.metadata().unwrap().modified().unwrap();
|
||||
out_modified.duration_since(in_modified).is_ok()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if ignore {
|
||||
None
|
||||
} else {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", &out_dir])
|
||||
// Flash attention only
|
||||
// .arg("--expt-relaxed-constexpr")
|
||||
.args(&include_options);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(p);
|
||||
Some((p, command.spawn()
|
||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
|
||||
// some old ones
|
||||
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
|
||||
for (kernel_path, child) in children {
|
||||
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
(write, kernel_paths)
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse code")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(compute_cap)
|
||||
}
|
||||
let builder = bindgen_cuda::Builder::default();
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write("src/lib.rs").unwrap();
|
||||
}
|
||||
|
@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
half = { version = "2.3.1", features = [
|
||||
"num-traits",
|
||||
"use-intrinsics",
|
||||
"rand_distr",
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
|
@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define AFFINE(FN_NAME, TYPENAME) \
|
||||
#define AFFINE(FN_NAME, T) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
device const T *input, \
|
||||
device T *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = TYPENAME(float(input[id]) * mul + add); \
|
||||
output[id] = T(fma(float(input[id]), mul, add)); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
device const T *input, \
|
||||
device T *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
|
||||
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
|
||||
}
|
||||
|
||||
#define POWF(FN_NAME, TYPENAME) \
|
||||
@ -117,7 +117,7 @@ ELU(elu_f32, float)
|
||||
ELU(elu_f16, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
AFFINE(affine_bf16, bfloat);
|
||||
POWF(powf_bf16, bfloat);
|
||||
ELU(elu_bf16, bfloat);
|
||||
|
@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
|
||||
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
||||
|
||||
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
||||
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
|
||||
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
|
||||
INT64_BINARY_OP_OUT(gt, x > y)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
|
@ -28,7 +28,7 @@ kernel void FN_NAME( \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
#endif
|
||||
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
||||
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||
|
||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||
|
@ -12,9 +12,11 @@ const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
@ -61,8 +63,12 @@ macro_rules! primitive {
|
||||
}
|
||||
};
|
||||
}
|
||||
primitive!(bool);
|
||||
primitive!(usize);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(u32);
|
||||
primitive!(u64);
|
||||
primitive!(f32);
|
||||
|
||||
impl<T> EncoderParam for &[T] {
|
||||
@ -117,6 +123,8 @@ pub enum Source {
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
Random,
|
||||
Quantized,
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
@ -174,8 +182,8 @@ macro_rules! ops{
|
||||
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
|
||||
recip
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -215,17 +223,15 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
|
||||
pub struct Kernels {
|
||||
libraries: RwLock<Libraries>,
|
||||
pipelines: RwLock<Pipelines>,
|
||||
fence: metal::Fence,
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
pub fn new(fence: metal::Fence) -> Self {
|
||||
pub fn new() -> Self {
|
||||
let libraries = RwLock::new(Libraries::new());
|
||||
let pipelines = RwLock::new(Pipelines::new());
|
||||
Self {
|
||||
libraries,
|
||||
pipelines,
|
||||
fence,
|
||||
}
|
||||
}
|
||||
|
||||
@ -239,6 +245,8 @@ impl Kernels {
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -345,7 +353,6 @@ pub fn call_unary_contiguous(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, output));
|
||||
@ -354,7 +361,6 @@ pub fn call_unary_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -376,7 +382,6 @@ pub fn call_unary_strided(
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -398,7 +403,6 @@ pub fn call_unary_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -417,7 +421,6 @@ pub fn call_binary_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, left, right, output));
|
||||
@ -428,7 +431,6 @@ pub fn call_binary_contiguous(
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -453,7 +455,6 @@ pub fn call_binary_strided(
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let width: usize = shape.iter().product();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -478,7 +479,6 @@ pub fn call_binary_strided(
|
||||
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -497,7 +497,6 @@ pub fn call_cast_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, (input, input_offset), output));
|
||||
@ -506,7 +505,6 @@ pub fn call_cast_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -526,7 +524,6 @@ pub fn call_cast_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -548,7 +545,6 @@ pub fn call_cast_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -671,7 +667,6 @@ pub fn call_last_softmax(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -702,7 +697,6 @@ pub fn call_last_softmax(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -722,7 +716,6 @@ pub fn call_affine(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, add, input, output));
|
||||
@ -731,7 +724,6 @@ pub fn call_affine(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -754,7 +746,6 @@ pub fn call_affine_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -775,7 +766,6 @@ pub fn call_affine_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -794,7 +784,6 @@ pub fn call_powf(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
@ -803,7 +792,6 @@ pub fn call_powf(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -825,7 +813,6 @@ pub fn call_powf_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -845,7 +832,6 @@ pub fn call_powf_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -864,7 +850,6 @@ pub fn call_elu(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
@ -873,7 +858,6 @@ pub fn call_elu(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -895,7 +879,6 @@ pub fn call_elu_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -915,7 +898,6 @@ pub fn call_elu_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -937,7 +919,6 @@ pub fn call_where_cond_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let size: usize = shape.iter().product();
|
||||
@ -966,7 +947,6 @@ pub fn call_where_cond_strided(
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -993,7 +973,6 @@ pub fn call_index_select(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1016,7 +995,6 @@ pub fn call_index_select(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1045,7 +1023,6 @@ pub fn call_gather(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1068,7 +1045,6 @@ pub fn call_gather(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1097,7 +1073,6 @@ pub fn call_scatter_add(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1120,7 +1095,6 @@ pub fn call_scatter_add(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1150,7 +1124,6 @@ pub fn call_index_add(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1174,7 +1147,6 @@ pub fn call_index_add(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1378,7 +1350,6 @@ pub fn call_gemm(
|
||||
let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
@ -1418,12 +1389,10 @@ pub fn call_gemm(
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
// println!("grid size {grid_size:?} group size {group_size:?}");
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1448,7 +1417,6 @@ pub fn call_im2col1d_strided(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1468,7 +1436,6 @@ pub fn call_im2col1d_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1498,7 +1465,6 @@ pub fn call_im2col_strided(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1520,7 +1486,6 @@ pub fn call_im2col_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1546,7 +1511,6 @@ pub fn call_upsample_nearest_2d(
|
||||
let scale_h = shape[3] as f32 / out_h as f32;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1564,7 +1528,243 @@ pub fn call_upsample_nearest_2d(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_uniform(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
min: f32,
|
||||
max: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
if min >= max {
|
||||
return Err(MetalKernelError::LoadLibraryError(
|
||||
"min must be less than max".to_string(),
|
||||
));
|
||||
}
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, min, max, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_normal(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
mean: f32,
|
||||
stddev: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, mean, stddev, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GgmlDType {
|
||||
Q4_0,
|
||||
Q4_1,
|
||||
Q5_0,
|
||||
Q5_1,
|
||||
Q8_0,
|
||||
Q8_1,
|
||||
Q2K,
|
||||
Q3K,
|
||||
Q4K,
|
||||
Q5K,
|
||||
Q6K,
|
||||
Q8K,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
pub fn call_quantized_matmul_t(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
dtype: GgmlDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &Buffer,
|
||||
lhs_offset: usize,
|
||||
rhs: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// Everything is in reverse
|
||||
let ne00 = k as i64;
|
||||
let ne01 = n as i64;
|
||||
let ne02 = b as i64;
|
||||
let ne03 = 1 as i64;
|
||||
|
||||
let nb00 = 0i64;
|
||||
let nb01 = 0 as i64;
|
||||
let nb02 = 0 as i64;
|
||||
|
||||
let ne10 = k as i64;
|
||||
let ne11 = m as i64;
|
||||
let ne12 = b as i64;
|
||||
let ne13 = 1 as i64;
|
||||
|
||||
let nb10 = 0i64;
|
||||
let nb11 = 0i64;
|
||||
let nb12 = 0i64;
|
||||
|
||||
let ne0 = n as i64;
|
||||
let ne1 = m as i64;
|
||||
let r2: u32 = (ne12 / ne02) as u32;
|
||||
let r3: u32 = (ne13 / ne03) as u32;
|
||||
|
||||
let (nth0, nth1, align) = match dtype {
|
||||
GgmlDType::Q4_0
|
||||
| GgmlDType::Q4_1
|
||||
| GgmlDType::Q5_0
|
||||
| GgmlDType::Q5_1
|
||||
| GgmlDType::Q8_0
|
||||
| GgmlDType::Q8_1 => {
|
||||
let nth0 = 8;
|
||||
let nth1 = 8;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
// Fixing a bug in Metal for GGML
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q3K | GgmlDType::Q5K => {
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let align = 2;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::F16 | GgmlDType::Q8K => {
|
||||
// Original implem uses rows
|
||||
let nth0 = 32;
|
||||
let nth1 = 1;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::F32 => {
|
||||
let nth0 = 32;
|
||||
let nth1 = 1;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
};
|
||||
let thread_groups_count = MTLSize {
|
||||
width: divide(ne01 as usize, align),
|
||||
height: ne11 as u64,
|
||||
depth: (ne12 * ne13) as u64,
|
||||
};
|
||||
let threads_per_threadgroup = MTLSize {
|
||||
width: nth0,
|
||||
height: nth1,
|
||||
depth: 1,
|
||||
};
|
||||
let name = match dtype {
|
||||
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
|
||||
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
|
||||
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
|
||||
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
|
||||
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
|
||||
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
|
||||
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
|
||||
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
|
||||
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
|
||||
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
|
||||
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
|
||||
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
|
||||
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
|
||||
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
|
||||
};
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
rhs,
|
||||
(lhs, lhs_offset),
|
||||
output,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3
|
||||
)
|
||||
);
|
||||
encoder.set_threadgroup_memory_length(0, 8192);
|
||||
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
|
5107
candle-metal-kernels/src/quantized.metal
Normal file
5107
candle-metal-kernels/src/quantized.metal
Normal file
File diff suppressed because it is too large
Load Diff
206
candle-metal-kernels/src/random.metal
Normal file
206
candle-metal-kernels/src/random.metal
Normal file
@ -0,0 +1,206 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_integer>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Constants
|
||||
// 2^32 and 1/2^32. Useful for converting between float and uint.
|
||||
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
|
||||
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
|
||||
// 2 * pi
|
||||
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
|
||||
static constexpr constant int3 S1 = {13, 19, 12};
|
||||
static constexpr constant int3 S2 = {2, 25, 4};
|
||||
static constexpr constant int3 S3 = {3, 11, 17};
|
||||
|
||||
// Used to prevent bad seeds.
|
||||
static constexpr constant uint64_t PHI[16] = {
|
||||
0x9E3779B97F4A7C15,
|
||||
0xF39CC0605CEDC834,
|
||||
0x1082276BF3A27251,
|
||||
0xF86C6A11D0C18E95,
|
||||
0x2767F0B153D27B7F,
|
||||
0x0347045B5BF1827F,
|
||||
0x01886F0928403002,
|
||||
0xC1D64BA40F335E36,
|
||||
0xF06AD7AE9717877E,
|
||||
0x85839D6EFFBD7DC6,
|
||||
0x64D325D1C5371682,
|
||||
0xCADD0CCCFDFFBBE1,
|
||||
0x626E33B8D04B4331,
|
||||
0xBBF73C790D94F79D,
|
||||
0x471C4AB3ED3D82A5,
|
||||
0xFEC507705E4AE6E5,
|
||||
};
|
||||
|
||||
// Combined Tausworthe and LCG Random Number Generator.
|
||||
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
|
||||
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
|
||||
struct HybridTaus {
|
||||
|
||||
float state;
|
||||
|
||||
HybridTaus() thread = default;
|
||||
HybridTaus() threadgroup = default;
|
||||
HybridTaus() device = default;
|
||||
HybridTaus() constant = default;
|
||||
|
||||
// Generate seeds for each thread.
|
||||
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
|
||||
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
|
||||
}
|
||||
|
||||
// Tausworthe generator.
|
||||
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
|
||||
uint b = (((z << s.x) ^ z) >> s.y);
|
||||
return (((z & M) << s.z) ^ b);
|
||||
}
|
||||
|
||||
// LCG generator.
|
||||
METAL_FUNC static uint lcg(const uint z) {
|
||||
return (1664525 * z + 1013904223UL);
|
||||
}
|
||||
|
||||
// Initialize the RNG state.
|
||||
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
|
||||
uint4 seed = seed_per_thread(seeds);
|
||||
|
||||
// Seed #1
|
||||
uint z1 = taus(seed.x, S1, 4294967294UL);
|
||||
uint z2 = taus(seed.y, S2, 4294967288UL);
|
||||
uint z3 = taus(seed.z, S3, 4294967280UL);
|
||||
uint z4 = lcg(seed.x);
|
||||
|
||||
// Seed #2
|
||||
uint r1 = (z1^z2^z3^z4^seed.y);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
// Seed #3
|
||||
r1 = (z1^z2^z3^z4^seed.z);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
// Seed #4
|
||||
r1 = (z1^z2^z3^z4^seed.w);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
HybridTaus rng;
|
||||
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||
return rng;
|
||||
}
|
||||
|
||||
METAL_FUNC float rand() {
|
||||
uint seed = this->state * UNIF01_NORM32;
|
||||
uint z1 = taus(seed, S1, 429496729UL);
|
||||
uint z2 = taus(seed, S2, 4294967288UL);
|
||||
uint z3 = taus(seed, S3, 429496280UL);
|
||||
uint z4 = lcg(seed);
|
||||
|
||||
thread float result = this->state;
|
||||
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T> METAL_FUNC void rand_uniform(
|
||||
constant size_t &size,
|
||||
constant float &min,
|
||||
constant float &max,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
float diff = abs(min - max);
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
}
|
||||
|
||||
// Create Gaussian normal distribution using Box-Muller transform:
|
||||
// https://en.wikipedia.org/wiki/Box–Muller_transform
|
||||
template<typename T> METAL_FUNC void normal(
|
||||
constant size_t &size,
|
||||
constant float &mean,
|
||||
constant float &stddev,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
float u1 = rng.rand();
|
||||
float u2 = rng.rand();
|
||||
|
||||
float cosval;
|
||||
float sinval = sincos(TWO_PI * u2, cosval);
|
||||
float mag = stddev * sqrt(-2.0 * log(u1));
|
||||
float z0 = mag * cosval + mean;
|
||||
float z1 = mag * sinval + mean;
|
||||
|
||||
out[tid] = static_cast<T>(z0);
|
||||
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - tid] = static_cast<T>(z1);
|
||||
}
|
||||
|
||||
#define UNIFORM_OP(NAME, T) \
|
||||
kernel void rand_uniform_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant float &min, \
|
||||
constant float &max, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
rand_uniform<T>(size, min, max, seed, out, tid); \
|
||||
} \
|
||||
|
||||
#define NORMAL_OP(NAME, T) \
|
||||
kernel void rand_normal_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant float &mean, \
|
||||
constant float &stddev, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
normal<T>(size, mean, stddev, seed, out, tid); \
|
||||
} \
|
||||
|
||||
|
||||
#define RANDOM_OPS(NAME, T) \
|
||||
UNIFORM_OP(NAME, T) \
|
||||
NORMAL_OP(NAME, T) \
|
||||
|
||||
RANDOM_OPS(f32, float)
|
||||
RANDOM_OPS(f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
RANDOM_OPS(bf16, bfloat)
|
||||
#endif
|
@ -649,10 +649,9 @@ REDUCE(Max, fast_max_i64, int64_t)
|
||||
|
||||
ARG_REDUCE(ArgMin, fast_argmin_i64, int64_t)
|
||||
ARG_REDUCE(ArgMax, fast_argmax_i64, int64_t)
|
||||
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(Sum, fast_sum_bf16, bfloat)
|
||||
REDUCE(Mul, fast_mul_bf16, bfloat)
|
||||
REDUCE(Max, fast_max_bf16, bfloat)
|
||||
|
@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template<typename T, typename ID>
|
||||
METAL_FUNC void where_cond(
|
||||
constant size_t &numel,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t *strides_t,
|
||||
constant size_t *strides_f,
|
||||
device const ID *ids,
|
||||
device const T *t,
|
||||
device const T *f,
|
||||
device T *out,
|
||||
uint i [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (i >= numel){
|
||||
return;
|
||||
}
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
||||
}
|
||||
|
||||
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID_TYPENAME *ids, \
|
||||
device const TYPENAME *t, \
|
||||
device const TYPENAME *f, \
|
||||
device TYPENAME *out ,\
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= numel){ \
|
||||
return; \
|
||||
} \
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
||||
} \
|
||||
#define WHERE_OP(T, ID, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID *ids, \
|
||||
device const T *t, \
|
||||
device const T *f, \
|
||||
device T *out, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
@ -54,10 +70,14 @@ kernel void FN_NAME( \
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||
#endif
|
@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use metal::{Buffer, Device, MTLResourceOptions};
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
|
||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||
let ptr = data.as_ptr() as *const c_void;
|
||||
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
||||
device.new_buffer_with_data(ptr, size, options)
|
||||
}
|
||||
@ -37,8 +37,7 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
||||
|
||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -60,8 +59,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
|
||||
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -96,8 +94,7 @@ fn run_strided<T: Clone>(
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
call_unary_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -248,10 +245,37 @@ fn binary_add_f32() {
|
||||
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_ops_bf16() {
|
||||
let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
|
||||
let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect();
|
||||
|
||||
macro_rules! binary_op {
|
||||
($opname:ident, $opexpr:expr) => {{
|
||||
let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
|
||||
let expected: Vec<bf16> = lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
|
||||
.collect();
|
||||
assert_eq!(results, expected);
|
||||
}};
|
||||
}
|
||||
|
||||
binary_op!(add, |x, y| x + y);
|
||||
binary_op!(sub, |x, y| x - y);
|
||||
binary_op!(mul, |x, y| x * y);
|
||||
binary_op!(div, |x, y| x / y);
|
||||
binary_op!(min, |x: bf16, y| x.min(y));
|
||||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -296,10 +320,92 @@ fn cast_u32_f32() {
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u32_bf16() {
|
||||
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f32_bf16() {
|
||||
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u8() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
|
||||
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f16() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
|
||||
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f16_bf16() {
|
||||
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
@ -334,8 +440,7 @@ fn run_affine_strided<T: Clone>(
|
||||
add: f64,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
@ -396,14 +501,14 @@ fn index_select() {
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
||||
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [2, 5];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
@ -419,20 +524,46 @@ fn index_select_f16() {
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
|
||||
assert_eq!(
|
||||
approx_f16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u32_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u8_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u8, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_dim1() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 1;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||
@ -444,6 +575,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape: &[usize],
|
||||
ids: &[I],
|
||||
dim: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
||||
@ -457,14 +589,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
let name = match core::mem::size_of::<T>() {
|
||||
4 => "is_u32_f32",
|
||||
2 => "is_u32_f16",
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
call_index_select(
|
||||
&device,
|
||||
&command_buffer,
|
||||
@ -499,8 +624,7 @@ fn cos_f16() {
|
||||
|
||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -520,25 +644,16 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
&input,
|
||||
0,
|
||||
&output,
|
||||
);
|
||||
match result {
|
||||
Ok(_) => {
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error: {}", e);
|
||||
panic!("damn!");
|
||||
}
|
||||
}
|
||||
).unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, out_length)
|
||||
}
|
||||
|
||||
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -656,8 +771,7 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -733,8 +847,7 @@ fn run_gemm<T: Clone>(
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -812,3 +925,124 @@ fn gemm() {
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
|
||||
|
||||
let seed = device.new_buffer_with_data(
|
||||
&seed as *const u32 as *const core::ffi::c_void,
|
||||
std::mem::size_of::<u32>() as NSUInteger,
|
||||
options,
|
||||
);
|
||||
|
||||
if name.starts_with("rand_uniform") {
|
||||
call_random_uniform(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
call_random_normal(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random() {
|
||||
fn calc_mean(data: &[f32]) -> f32 {
|
||||
let sum = data.iter().sum::<f32>() as f32;
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
sum / count as f32
|
||||
}
|
||||
|
||||
fn calc_stddev(data: &[f32]) -> f32 {
|
||||
let mean = calc_mean(data);
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
|
||||
let variance = data
|
||||
.iter()
|
||||
.map(|value| {
|
||||
let diff = mean - (*value as f32);
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f32>()
|
||||
/ count as f32;
|
||||
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
let shape = vec![1024, 10];
|
||||
|
||||
let length = shape.iter().product::<usize>();
|
||||
let seed = 299792458;
|
||||
|
||||
let min = -30.0;
|
||||
let max = 30.0;
|
||||
let mean = 100.0;
|
||||
let stddev = 50.0;
|
||||
|
||||
macro_rules! validate_random {
|
||||
($type:ty) => {
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_uniform_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
min,
|
||||
max,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
results.iter().for_each(|v| {
|
||||
assert!(*v >= min && *v <= max);
|
||||
});
|
||||
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_normal_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
mean,
|
||||
stddev,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
||||
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
||||
};
|
||||
}
|
||||
|
||||
validate_random!(f32);
|
||||
validate_random!(f16);
|
||||
validate_random!(bf16);
|
||||
}
|
||||
|
@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
}
|
||||
template <typename T> METAL_FUNC T relu(T in){
|
||||
if (in < 0) {
|
||||
return 0;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -110,7 +116,7 @@ UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
|
||||
UNARY_OP(relu)
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -120,7 +126,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
@ -129,6 +135,7 @@ BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
BFLOAT_UNARY_OP(gelu)
|
||||
BFLOAT_UNARY_OP(abs)
|
||||
BFLOAT_UNARY_OP(ceil)
|
||||
BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
#endif
|
||||
|
@ -11,7 +11,7 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle = { workspace = true }
|
||||
half = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
@ -20,7 +20,7 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
metal = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
@ -222,7 +222,10 @@ impl Benchmark for QMatMul {
|
||||
type RunResult = Tensor;
|
||||
fn preprocess() -> Result<Self::PreProcessData> {
|
||||
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
|
||||
let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
|
||||
let mm = candle::quantized::QTensor::new(
|
||||
candle::quantized::QStorage::Cpu(Box::new(zeros)),
|
||||
(4096, 11008),
|
||||
)?;
|
||||
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
|
||||
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
|
||||
Ok((mm, arg))
|
||||
|
@ -6,6 +6,7 @@ use serde::Deserialize;
|
||||
pub enum Activation {
|
||||
#[default]
|
||||
Gelu,
|
||||
#[serde(alias = "gelu_new")]
|
||||
NewGelu,
|
||||
Relu,
|
||||
Relu2,
|
||||
|
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle = { path = "../candle-core", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
@ -20,4 +20,3 @@ prost-build = "0.12.1"
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
|
||||
|
@ -254,6 +254,12 @@ pub fn simple_eval(
|
||||
let output = input0.broadcast_div(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Pow" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
|
@ -15,9 +15,9 @@ crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
|
@ -33,7 +33,9 @@ def has_mkl() -> bool:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
|
||||
def load_ggml(
|
||||
path: Union[str, PathLike], device: Optional[Device] = None
|
||||
) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
|
||||
"""
|
||||
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||
@ -41,7 +43,9 @@ def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str,
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
|
||||
def load_gguf(
|
||||
path: Union[str, PathLike], device: Optional[Device] = None
|
||||
) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
|
||||
"""
|
||||
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||
and the second maps metadata keys to metadata values.
|
||||
|
@ -1074,20 +1074,20 @@ impl PyTensor {
|
||||
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
|
||||
use ::candle::quantized;
|
||||
let res = match quantized_dtype.to_lowercase().as_str() {
|
||||
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
|
||||
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
|
||||
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
|
||||
"q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
|
||||
"q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
|
||||
"q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
|
||||
"q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
|
||||
"q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
|
||||
"q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
|
||||
"q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
|
||||
"q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
|
||||
"q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
|
||||
"f16" => quantized::QTensor::quantize::<f16>(self),
|
||||
"f32" => quantized::QTensor::quantize::<f32>(self),
|
||||
"q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
|
||||
"q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
|
||||
"q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
|
||||
"q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
|
||||
"q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
|
||||
"q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
|
||||
"q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
|
||||
"q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
|
||||
"q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
|
||||
"q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
|
||||
"q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
|
||||
"q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
|
||||
"f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
|
||||
"f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
|
||||
dt => {
|
||||
return Err(PyErr::new::<PyValueError, _>(format!(
|
||||
"unknown quantized-dtype {dt}"
|
||||
@ -1278,13 +1278,19 @@ fn save_safetensors(
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
|
||||
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
|
||||
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
|
||||
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
|
||||
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||
fn load_ggml(
|
||||
path: &str,
|
||||
device: Option<PyDevice>,
|
||||
py: Python<'_>,
|
||||
) -> PyResult<(PyObject, PyObject, PyObject)> {
|
||||
let mut file = std::fs::File::open(path)?;
|
||||
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let ggml =
|
||||
::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
|
||||
let tensors = ggml
|
||||
.tensors
|
||||
.into_iter()
|
||||
@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
|
||||
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
|
||||
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
|
||||
/// and the second maps metadata keys to metadata values.
|
||||
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
|
||||
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||
fn load_gguf(
|
||||
path: &str,
|
||||
device: Option<PyDevice>,
|
||||
py: Python<'_>,
|
||||
) -> PyResult<(PyObject, PyObject)> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
use ::candle::quantized::gguf_file;
|
||||
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
|
||||
let v: PyObject = match v {
|
||||
@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
|
||||
.tensor_infos
|
||||
.keys()
|
||||
.map(|key| {
|
||||
let qtensor = gguf.tensor(&mut file, key)?;
|
||||
let qtensor = gguf.tensor(&mut file, key, &device)?;
|
||||
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
|
||||
})
|
||||
.collect::<::candle::Result<Vec<_>>>()
|
||||
|
@ -12,9 +12,9 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-nn = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
@ -112,11 +112,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
#[allow(dead_code)]
|
||||
pr: f64,
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -136,11 +136,6 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
@ -409,7 +404,7 @@ impl Llama {
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
|
333
candle-transformers/src/models/mobileone.rs
Normal file
333
candle-transformers/src/models/mobileone.rs
Normal file
@ -0,0 +1,333 @@
|
||||
//! MobileOne inference implementation based on timm and candle-repvgg
|
||||
//!
|
||||
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
|
||||
//! https://arxiv.org/abs/2206.04040
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
|
||||
Func, VarBuilder,
|
||||
};
|
||||
|
||||
struct StageConfig {
|
||||
blocks: usize,
|
||||
channels: usize,
|
||||
}
|
||||
|
||||
// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
|
||||
// by concatenating the 5th stage (starts with stride 1) to the previous one.
|
||||
const STAGES: [StageConfig; 5] = [
|
||||
StageConfig {
|
||||
blocks: 1,
|
||||
channels: 64,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 2,
|
||||
channels: 64,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 8,
|
||||
channels: 128,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 10,
|
||||
channels: 256,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 1,
|
||||
channels: 512,
|
||||
},
|
||||
];
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
/// overparameterization factor
|
||||
k: usize,
|
||||
/// per-stage channel number multipliers
|
||||
alphas: [f32; 5],
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn s0() -> Self {
|
||||
Self {
|
||||
k: 4,
|
||||
alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
|
||||
}
|
||||
}
|
||||
pub fn s1() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
|
||||
}
|
||||
}
|
||||
pub fn s2() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
|
||||
}
|
||||
}
|
||||
pub fn s3() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
|
||||
}
|
||||
}
|
||||
pub fn s4() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SE blocks are used in the last stages of the s4 variant.
|
||||
fn squeeze_and_excitation(
|
||||
in_channels: usize,
|
||||
squeeze_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
||||
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
|
||||
|
||||
residual.broadcast_mul(&xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
// A mobileone block has a different training time and inference time architecture.
|
||||
// The latter is a simple and efficient equivalent transformation of the former
|
||||
// realized by a structural reparameterization technique, where convolutions
|
||||
// along with identity branches and batchnorm layers are fused into a single convolution.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn mobileone_block(
|
||||
has_identity: bool,
|
||||
k: usize,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
groups: usize,
|
||||
kernel: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding,
|
||||
groups,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut w = Tensor::zeros(
|
||||
(out_channels, in_channels / groups, kernel, kernel),
|
||||
DType::F32,
|
||||
vb.device(),
|
||||
)?;
|
||||
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
|
||||
|
||||
// k is the training-time overparameterization factor, larger than 1 only in the s0 variant
|
||||
for i in 0..k {
|
||||
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
|
||||
let conv_kxk = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
conv2d_cfg,
|
||||
vb.pp(format!("conv_kxk.{i}.conv")),
|
||||
)?;
|
||||
let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
|
||||
w = (w + wk)?;
|
||||
b = (b + bk)?;
|
||||
}
|
||||
|
||||
if kernel > 1 {
|
||||
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
|
||||
let conv_scale = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_scale.conv"),
|
||||
)?;
|
||||
|
||||
let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
|
||||
// resize to 3x3
|
||||
ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
|
||||
ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
w = (w + ws)?;
|
||||
b = (b + bs)?;
|
||||
}
|
||||
|
||||
// Use SE blocks if present (last layers of the s4 variant)
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
if has_identity {
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
|
||||
|
||||
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
|
||||
|
||||
let id = in_channels / groups;
|
||||
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
|
||||
for i in 0..in_channels {
|
||||
if kernel > 1 {
|
||||
weights[i * kernel * kernel + 4] = 1.0;
|
||||
} else {
|
||||
weights[i * (id + 1)] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
}
|
||||
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&reparam_conv)?;
|
||||
if let Ok(f) = &se {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
xs = xs.relu()?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Get the number of output channels per stage taking into account the multipliers
|
||||
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
|
||||
let channels = STAGES[stage].channels as f32;
|
||||
let alpha = cfg.alphas[stage];
|
||||
|
||||
match stage {
|
||||
0 => std::cmp::min(64, (channels * alpha) as usize),
|
||||
_ => (channels * alpha) as usize,
|
||||
}
|
||||
}
|
||||
|
||||
// Each stage is made of blocks. The first layer always downsamples with stride 2.
|
||||
// All but the first block have a residual connection.
|
||||
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nblocks = STAGES[idx].blocks;
|
||||
let mut blocks = Vec::with_capacity(nblocks);
|
||||
|
||||
let mut in_channels = output_channels_per_stage(cfg, idx - 1);
|
||||
|
||||
for block_idx in 0..nblocks {
|
||||
let out_channels = output_channels_per_stage(cfg, idx);
|
||||
let (has_identity, stride) = if block_idx == 0 {
|
||||
(false, 2)
|
||||
} else {
|
||||
(true, 1)
|
||||
};
|
||||
|
||||
// depthwise convolution layer
|
||||
blocks.push(mobileone_block(
|
||||
has_identity,
|
||||
cfg.k,
|
||||
in_channels,
|
||||
stride,
|
||||
1,
|
||||
in_channels,
|
||||
3,
|
||||
in_channels,
|
||||
in_channels,
|
||||
vb.pp(block_idx * 2),
|
||||
)?);
|
||||
|
||||
// pointwise convolution layer
|
||||
blocks.push(mobileone_block(
|
||||
has_identity,
|
||||
cfg.k,
|
||||
out_channels,
|
||||
1, // stride
|
||||
0, // padding
|
||||
1, // groups
|
||||
1, // kernel
|
||||
in_channels,
|
||||
out_channels,
|
||||
vb.pp(block_idx * 2 + 1),
|
||||
)?);
|
||||
|
||||
in_channels = out_channels;
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a mobileone model for a given configuration.
|
||||
fn mobileone_model(
|
||||
config: &Config,
|
||||
nclasses: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let outputs = output_channels_per_stage(config, 4);
|
||||
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem_dim = output_channels_per_stage(config, 0);
|
||||
let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
|
||||
let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
|
||||
let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
|
||||
let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.mean(D::Minus2)?
|
||||
.mean(D::Minus1)?;
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
mobileone_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
mobileone_model(cfg, None, vb)
|
||||
}
|
@ -15,8 +15,10 @@ pub mod marian;
|
||||
pub mod mistral;
|
||||
pub mod mixformer;
|
||||
pub mod mixtral;
|
||||
pub mod mobileone;
|
||||
pub mod mpt;
|
||||
pub mod persimmon;
|
||||
pub mod phi;
|
||||
pub mod quantized_blip;
|
||||
pub mod quantized_blip_text;
|
||||
pub mod quantized_llama;
|
||||
@ -26,6 +28,7 @@ pub mod quantized_mixformer;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod repvgg;
|
||||
pub mod resnet;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
|
363
candle-transformers/src/models/phi.rs
Normal file
363
candle-transformers/src/models/phi.rs
Normal file
@ -0,0 +1,363 @@
|
||||
use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
|
||||
/// Phi model.
|
||||
/// https://huggingface.co/microsoft/phi-2
|
||||
/// There is an alternative implementation of the phi model in mixformers.rs.
|
||||
/// This corresponds to the model update made with the following commit:
|
||||
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub(crate) vocab_size: usize,
|
||||
pub(crate) hidden_size: usize,
|
||||
pub(crate) intermediate_size: usize,
|
||||
pub(crate) num_hidden_layers: usize,
|
||||
pub(crate) num_attention_heads: usize,
|
||||
pub(crate) num_key_value_heads: Option<usize>,
|
||||
pub(crate) hidden_act: Activation,
|
||||
pub(crate) max_position_embeddings: usize,
|
||||
pub(crate) layer_norm_eps: f64,
|
||||
pub(crate) tie_word_embeddings: bool,
|
||||
pub(crate) rope_theta: f32,
|
||||
pub(crate) partial_rotary_factor: f64,
|
||||
pub(crate) qk_layernorm: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn num_key_value_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
|
||||
}
|
||||
|
||||
fn head_dim(&self) -> usize {
|
||||
self.hidden_size / self.num_attention_heads
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
dim: usize,
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((cfg.max_position_embeddings, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
dim,
|
||||
sin: emb.sin()?,
|
||||
cos: emb.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
|
||||
let xs_rot = xs.i((.., .., .., ..self.dim))?;
|
||||
let xs_pass = xs.i((.., .., .., self.dim..))?;
|
||||
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
||||
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
|
||||
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
act: Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
|
||||
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
|
||||
Ok(Self {
|
||||
fc1,
|
||||
fc2,
|
||||
// This does not match the mixformers implementation where Gelu is used rather than
|
||||
// GeluNew.
|
||||
act: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
dense: Linear,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
q_layernorm: Option<LayerNorm>,
|
||||
k_layernorm: Option<LayerNorm>,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
softmax_scale: f64,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads();
|
||||
let head_dim = cfg.head_dim();
|
||||
let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
|
||||
// Alternative rope scalings are not supported.
|
||||
let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
|
||||
let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
|
||||
let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
|
||||
let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
|
||||
(Some(q_layernorm), Some(k_layernorm))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
dense,
|
||||
kv_cache: None,
|
||||
q_layernorm,
|
||||
k_layernorm,
|
||||
rotary_emb,
|
||||
softmax_scale,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||
})
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_heads / self.num_kv_heads;
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||
xs.unsqueeze(2)?
|
||||
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = match &self.q_layernorm {
|
||||
None => query_states,
|
||||
Some(ln) => query_states.apply(ln)?,
|
||||
};
|
||||
let key_states = match &self.k_layernorm {
|
||||
None => key_states,
|
||||
Some(ln) => key_states.apply(ln)?,
|
||||
};
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_size, seq_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
// Rotary embeddings.
|
||||
let seqlen_offset = match &self.kv_cache {
|
||||
None => 0,
|
||||
Some((prev_k, _)) => prev_k.dim(2)?,
|
||||
};
|
||||
let query_states = self
|
||||
.rotary_emb
|
||||
.apply_rotary_emb(&query_states, seqlen_offset)?;
|
||||
let key_states = self
|
||||
.rotary_emb
|
||||
.apply_rotary_emb(&key_states, seqlen_offset)?;
|
||||
|
||||
// KV cache.
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let v = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
// Repeat kv.
|
||||
let key_states = self.repeat_kv(key_states)?.contiguous()?;
|
||||
let value_states = self.repeat_kv(value_states)?.contiguous()?;
|
||||
|
||||
let attn_weights = (query_states
|
||||
.to_dtype(DType::F32)?
|
||||
.contiguous()?
|
||||
.matmul(&key_states.to_dtype(DType::F32)?.t()?)?
|
||||
* self.softmax_scale)?;
|
||||
let attn_weights = match mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => masked_fill(
|
||||
&attn_weights,
|
||||
&mask.broadcast_left((b_size, self.num_heads))?,
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
let attn_weights =
|
||||
candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_size, seq_len, ()))?;
|
||||
attn_output.apply(&self.dense)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MLP,
|
||||
input_layernorm: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_eps,
|
||||
vb.pp("input_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "block"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.input_layernorm)?;
|
||||
let attn_outputs = self.self_attn.forward(&xs, mask)?;
|
||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||
attn_outputs + feed_forward_hidden_states + residual
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
final_layernorm: LayerNorm,
|
||||
lm_head: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let final_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_eps,
|
||||
vb_m.pp("final_layernorm"),
|
||||
)?;
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_m = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
final_layernorm,
|
||||
lm_head,
|
||||
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.embed_tokens)?;
|
||||
let mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, mask.as_ref())?;
|
||||
}
|
||||
xs.apply(&self.final_layernorm)?
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.lm_head)?
|
||||
.squeeze(1)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.layers.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||
}
|
||||
}
|
@ -356,6 +356,7 @@ impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
@ -383,21 +384,28 @@ impl ModelWeights {
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
|
||||
let output = ct.tensor(reader, "output.weight")?;
|
||||
let norm = RmsNorm::new(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
)?;
|
||||
let output = ct.tensor(reader, "output.weight", device)?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
|
||||
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
let mlp_or_moe = if n_expert <= 1 {
|
||||
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
|
||||
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
|
||||
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
MlpOrMoe::Mlp(Mlp {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
@ -405,15 +413,15 @@ impl ModelWeights {
|
||||
})
|
||||
} else {
|
||||
let feed_forward_gate_inp =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
|
||||
let mut experts = Vec::with_capacity(n_expert);
|
||||
for i in 0..n_expert {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
|
||||
experts.push(Mlp {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
@ -426,8 +434,9 @@ impl ModelWeights {
|
||||
experts,
|
||||
}
|
||||
};
|
||||
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
|
||||
let attention_norm =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
|
||||
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||
|
@ -311,7 +311,7 @@ impl MixFormerSequentialForCausalLM {
|
||||
let mut blocks = Vec::new();
|
||||
for i in 0..cfg.n_layer {
|
||||
let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
|
||||
blocks.push(block)
|
||||
blocks.push(block);
|
||||
}
|
||||
let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
|
||||
Ok(Self {
|
||||
@ -332,7 +332,7 @@ impl MixFormerSequentialForCausalLM {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
xs = block.forward(&xs, mask.as_ref())?;
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
|
306
candle-transformers/src/models/repvgg.rs
Normal file
306
candle-transformers/src/models/repvgg.rs
Normal file
@ -0,0 +1,306 @@
|
||||
//! RepVGG inference implementation
|
||||
//!
|
||||
//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
|
||||
//! https://arxiv.org/abs/2101.03697
|
||||
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
a: f32,
|
||||
b: f32,
|
||||
groups: usize,
|
||||
stages: [usize; 4],
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn a0() -> Self {
|
||||
Self {
|
||||
a: 0.75,
|
||||
b: 2.5,
|
||||
groups: 1,
|
||||
stages: [2, 4, 14, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn a1() -> Self {
|
||||
Self {
|
||||
a: 1.0,
|
||||
b: 2.5,
|
||||
groups: 1,
|
||||
stages: [2, 4, 14, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn a2() -> Self {
|
||||
Self {
|
||||
a: 1.5,
|
||||
b: 2.75,
|
||||
groups: 1,
|
||||
stages: [2, 4, 14, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b0() -> Self {
|
||||
Self {
|
||||
a: 1.0,
|
||||
b: 2.5,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b1() -> Self {
|
||||
Self {
|
||||
a: 2.0,
|
||||
b: 4.0,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b2() -> Self {
|
||||
Self {
|
||||
a: 2.5,
|
||||
b: 5.0,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b3() -> Self {
|
||||
Self {
|
||||
a: 3.0,
|
||||
b: 5.0,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b1g4() -> Self {
|
||||
Self {
|
||||
a: 2.0,
|
||||
b: 4.0,
|
||||
groups: 4,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b2g4() -> Self {
|
||||
Self {
|
||||
a: 2.5,
|
||||
b: 5.0,
|
||||
groups: 4,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b3g4() -> Self {
|
||||
Self {
|
||||
a: 3.0,
|
||||
b: 5.0,
|
||||
groups: 4,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
// A RepVGG layer has a different training time and inference time architecture.
|
||||
// The latter is a simple and efficient equivalent transformation of the former
|
||||
// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
|
||||
// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
|
||||
fn repvgg_layer(
|
||||
has_identity: bool,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
groups: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// read and reparameterize the 1x1 conv and bn into w1 and b1
|
||||
// based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
|
||||
|
||||
let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
|
||||
let conv1x1 = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_1x1.conv"),
|
||||
)?;
|
||||
|
||||
let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
|
||||
|
||||
// resize to 3x3
|
||||
w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
|
||||
w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
// read and reparameterize the 3x3 conv and bn into w3 and b3
|
||||
let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
|
||||
let conv3x3 = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_kxk.conv"),
|
||||
)?;
|
||||
|
||||
let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
|
||||
|
||||
let mut w = (w1 + w3)?;
|
||||
let mut b = (b1 + b3)?;
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
if has_identity {
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
|
||||
|
||||
// create a 3x3 convolution equivalent to the identity branch
|
||||
let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
|
||||
|
||||
// https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
|
||||
let in_dim = in_channels / groups;
|
||||
for i in 0..in_channels {
|
||||
weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
}
|
||||
|
||||
// create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&reparam_conv)?.relu()?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Get the number of output channels per stage taking into account the multipliers
|
||||
fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
|
||||
let channels = CHANNELS_PER_STAGE[stage] as f32;
|
||||
|
||||
match stage {
|
||||
0 => std::cmp::min(64, (channels * a) as usize),
|
||||
4 => (channels * b) as usize,
|
||||
_ => (channels * a) as usize,
|
||||
}
|
||||
}
|
||||
|
||||
// Each stage is made of layers. The first layer always downsamples with stride 2.
|
||||
// All but the first layer have a residual connection.
|
||||
// The G4 variants have a groupwise convolution instead of a dense one on odd layers
|
||||
// counted across stage boundaries, so we keep track of which layer we are in the
|
||||
// full model.
|
||||
fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nlayers = cfg.stages[idx - 1];
|
||||
let mut layers = Vec::with_capacity(nlayers);
|
||||
let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
|
||||
let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
|
||||
let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
|
||||
|
||||
for layer_idx in 0..nlayers {
|
||||
let (has_identity, stride, in_channels) = if layer_idx == 0 {
|
||||
(false, 2, out_channels_prev)
|
||||
} else {
|
||||
(true, 1, out_channels)
|
||||
};
|
||||
|
||||
let groups = if (prev_layers + layer_idx) % 2 == 1 {
|
||||
cfg.groups
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
layers.push(repvgg_layer(
|
||||
has_identity,
|
||||
out_channels,
|
||||
stride,
|
||||
in_channels,
|
||||
out_channels,
|
||||
groups,
|
||||
vb.pp(layer_idx),
|
||||
)?)
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for layer in layers.iter() {
|
||||
xs = xs.apply(layer)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a RepVGG model for a given configuration.
|
||||
fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let outputs = output_channels_per_stage(config.a, config.b, 4);
|
||||
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem_dim = output_channels_per_stage(config.a, config.b, 0);
|
||||
let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
|
||||
let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
|
||||
let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
|
||||
let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.mean(D::Minus1)?
|
||||
.mean(D::Minus1)?;
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
repvgg_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
repvgg_model(cfg, None, vb)
|
||||
}
|
@ -1,12 +1,7 @@
|
||||
use super::Config;
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
|
||||
fn conv1d(
|
||||
in_channels: usize,
|
||||
|
@ -10,33 +10,33 @@ pub struct VarBuilder {
|
||||
}
|
||||
|
||||
impl VarBuilder {
|
||||
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
|
||||
let mut file = std::fs::File::open(p)?;
|
||||
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
|
||||
let mut data = std::collections::HashMap::new();
|
||||
for tensor_name in content.tensor_infos.keys() {
|
||||
let tensor = content.tensor(&mut file, tensor_name)?;
|
||||
let tensor = content.tensor(&mut file, tensor_name, device)?;
|
||||
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
||||
}
|
||||
Ok(Self {
|
||||
data: Arc::new(data),
|
||||
path: Vec::new(),
|
||||
device: Device::Cpu,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
|
||||
pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
|
||||
let mut cursor = std::io::Cursor::new(buffer);
|
||||
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
|
||||
let mut data = std::collections::HashMap::new();
|
||||
for tensor_name in content.tensor_infos.keys() {
|
||||
let tensor = content.tensor(&mut cursor, tensor_name)?;
|
||||
let tensor = content.tensor(&mut cursor, tensor_name, device)?;
|
||||
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
||||
}
|
||||
Ok(Self {
|
||||
data: Arc::new(data),
|
||||
path: Vec::new(),
|
||||
device: Device::Cpu,
|
||||
device: device.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -27,7 +27,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
serde-wasm-bindgen = "0.6.0"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
|
@ -61,7 +61,7 @@ impl Model {
|
||||
|
||||
let start = Date::now();
|
||||
let model: SelectedModel = if quantized {
|
||||
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
|
||||
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
|
||||
SelectedModel::Q(model)
|
||||
} else {
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -26,7 +26,7 @@ serde_json = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
|
@ -41,6 +41,7 @@ impl Model {
|
||||
) -> Result<Model, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let device = Device::Cpu;
|
||||
let name: ModelName = serde_json::from_slice(&config)?;
|
||||
let config: Config = serde_json::from_slice(&config)?;
|
||||
|
||||
@ -50,8 +51,9 @@ impl Model {
|
||||
let start = Date::now();
|
||||
console_log!("weights len: {:?}", weights.len());
|
||||
let model = if quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
|
||||
&weights, &device,
|
||||
)?;
|
||||
console_log!("weights loaded");
|
||||
if name._name_or_path == "microsoft/phi-2" {
|
||||
let model = QMixFormer::new_v2(&config, vb)?;
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
# App crates.
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -27,7 +27,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
serde-wasm-bindgen = "0.6.0"
|
||||
|
@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
|
||||
use candle_wasm_example_t5::console_log;
|
||||
use tokenizers::Tokenizer;
|
||||
use wasm_bindgen::prelude::*;
|
||||
const DEVICE: Device = Device::Cpu;
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct ModelEncoder {
|
||||
@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
|
||||
) -> Result<ModelConditionalGeneration, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
|
||||
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let input: ConditionalGenerationParams =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let device = &Device::Cpu;
|
||||
let device = &DEVICE;
|
||||
self.model.clear_kv_cache();
|
||||
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
|
||||
let prompt = input.prompt;
|
||||
@ -128,7 +129,7 @@ impl ModelEncoder {
|
||||
) -> Result<ModelEncoder, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
config.use_cache = false;
|
||||
let tokenizer =
|
||||
@ -138,7 +139,7 @@ impl ModelEncoder {
|
||||
}
|
||||
|
||||
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let device = &Device::Cpu;
|
||||
let device = &DEVICE;
|
||||
let input: DecoderParams =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -26,7 +26,7 @@ safetensors = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -315,6 +315,7 @@ impl Decoder {
|
||||
let model = if md.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
|
||||
&md.weights,
|
||||
&device,
|
||||
)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
|
@ -9,8 +9,8 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -26,7 +26,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -7,7 +7,7 @@ keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
|
@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
]
|
||||
);
|
||||
|
||||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
|
||||
let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?;
|
||||
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
|
||||
let res = matmul.forward(&tensor_lhs)?;
|
||||
assert_eq!(
|
||||
|
Reference in New Issue
Block a user