mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Compare commits
6 Commits
ivarflakst
...
tmp_no_rot
Author | SHA1 | Date | |
---|---|---|---|
9cd0cc1f65 | |||
12fc4af8f2 | |||
9130b6c4b6 | |||
7b4389099a | |||
6f8584091e | |||
f97fcd4712 |
7
.github/dependabot.yml
vendored
7
.github/dependabot.yml
vendored
@ -1,7 +0,0 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "cargo"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
20
Cargo.toml
20
Cargo.toml
@ -31,18 +31,10 @@ license = "MIT OR Apache-2.0"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core" }
|
||||
candle-datasets = { path = "./candle-datasets" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn" }
|
||||
candle-kernels = { path = "./candle-kernels" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
||||
candle-nn = { path = "./candle-nn" }
|
||||
candle-onnx = { path = "./candle-onnx" }
|
||||
candle-transformers = { path = "./candle-transformers" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
|
||||
@ -50,20 +42,20 @@ imageproc = { version = "0.23.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
parquet = { version = "50.0.0" }
|
||||
parquet = { version = "45.0.0" }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
rayon = "1.7.0"
|
||||
rusttype = { version = "0.9", default-features = false }
|
||||
safetensors = "0.4.1"
|
||||
safetensors = "0.3.1"
|
||||
serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_plain = "1.0.2"
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.15.0", default-features = false }
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
|
@ -66,7 +66,7 @@ We also provide a some command line based examples using state of the art models
|
||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets.
|
||||
- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
|
||||
- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
|
||||
implementation of the Mamba state space model.
|
||||
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
|
||||
better performance than all publicly available 13b models as of 2023-09-28.
|
||||
@ -109,9 +109,6 @@ We also provide a some command line based examples using state of the art models
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
using self-supervision (can be used for imagenet classification, depth
|
||||
evaluation, segmentation).
|
||||
- [VGG](./candle-examples/examples/vgg/),
|
||||
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||
generate captions for an image.
|
||||
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||
@ -207,7 +204,7 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Image to text.
|
||||
- BLIP.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
@ -12,8 +12,8 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
@ -46,5 +46,6 @@ accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
|
||||
[[bench]]
|
||||
name = "bench_main"
|
||||
name = "matmul"
|
||||
harness = false
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
//benchmarks::affine::benches,
|
||||
//benchmarks::matmul::benches,
|
||||
//benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
//benchmarks::where_cond::benches
|
||||
);
|
@ -1,43 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor) {
|
||||
a.affine(12.34, 56.78).unwrap();
|
||||
}
|
||||
|
||||
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||
|
||||
let flops = b * m * k * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
|
||||
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
|
||||
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,67 +0,0 @@
|
||||
pub(crate) mod affine;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod reduce;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
fn sync(&self) -> Result<()> {
|
||||
match self {
|
||||
Device::Cpu => Ok(()),
|
||||
Device::Cuda(device) => {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Ok(device.synchronize()?);
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
#[cfg(feature = "metal")]
|
||||
return Ok(device.wait_until_completed()?);
|
||||
#[cfg(not(feature = "metal"))]
|
||||
panic!("Metal device without metal feature enabled: {:?}", device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let cpu_type = if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
};
|
||||
format!("{}_{}", cpu_type, name.into())
|
||||
}
|
||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BenchDeviceHandler {
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
impl BenchDeviceHandler {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut devices = Vec::new();
|
||||
if cfg!(feature = "metal") {
|
||||
devices.push(Device::new_metal(0)?);
|
||||
} else if cfg!(feature = "cuda") {
|
||||
devices.push(Device::new_cuda(0)?);
|
||||
}
|
||||
devices.push(Device::Cpu);
|
||||
Ok(Self { devices })
|
||||
}
|
||||
}
|
@ -1,63 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_uniform(a: &Tensor) {
|
||||
a.rand_like(-1.0, 123.0).unwrap();
|
||||
}
|
||||
|
||||
fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,239 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Storage, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use half::{bf16, f16};
|
||||
use std::ops::Deref;
|
||||
use std::time::Instant;
|
||||
|
||||
fn run_sum(a: &Tensor) {
|
||||
a.sum(2).unwrap();
|
||||
}
|
||||
fn run_arg_min(a: &Tensor) {
|
||||
a.argmin(2).unwrap();
|
||||
}
|
||||
|
||||
// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
|
||||
fn softmax(a: &Tensor) -> candle_core::Result<()> {
|
||||
use candle_core::{backend::BackendStorage, DType};
|
||||
let (storage, layout) = a.storage_and_layout();
|
||||
|
||||
let device = a.device();
|
||||
|
||||
if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
let name = match a.dtype() {
|
||||
DType::F32 => "softmax_f32",
|
||||
DType::F16 => "softmax_f16",
|
||||
DType::BF16 => "softmax_bf16",
|
||||
dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
||||
};
|
||||
|
||||
let n = layout.stride().len();
|
||||
if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
|
||||
candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
|
||||
}
|
||||
|
||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
|
||||
candle_metal_kernels::call_last_softmax(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
let (lo, up) = (-1000.0f32, 1000.0f32);
|
||||
for device in handler.devices {
|
||||
run_softmax(c, &device, (lo, up));
|
||||
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
|
||||
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
|
||||
|
||||
run_reduce(c, &device, (lo, up), false);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), false);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_reduce(c, &device, (lo, up), true);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), true);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
|
||||
if !device.is_metal() {
|
||||
return;
|
||||
}
|
||||
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => "softmax_f32",
|
||||
DType::F16 => "softmax_f16",
|
||||
DType::BF16 => "softmax_bf16",
|
||||
_ => "softmax",
|
||||
};
|
||||
softmax(&a).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
softmax(black_box(&a)).unwrap();
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"reduce_f32_strided"
|
||||
} else {
|
||||
"reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"reduce_f16_strided"
|
||||
} else {
|
||||
"reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"reduce_bf16_strided"
|
||||
} else {
|
||||
"reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "reduce",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_sum(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"arg_reduce_f32_strided"
|
||||
} else {
|
||||
"arg_reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"arg_reduce_f16_strided"
|
||||
} else {
|
||||
"arg_reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"arg_reduce_bf16_strided"
|
||||
} else {
|
||||
"arg_reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_arg_min(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,64 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
|
||||
a.where_cond(b, c).unwrap();
|
||||
}
|
||||
|
||||
const fn create_cond_arr<const N: usize>() -> [u8; N] {
|
||||
let mut arr = [0u8; N];
|
||||
let mut i = 0;
|
||||
while i < N {
|
||||
arr[i] = (i % 2) as u8;
|
||||
i += 1;
|
||||
}
|
||||
arr
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
const M: usize = 1024;
|
||||
const K: usize = 1024;
|
||||
const SIZE: usize = B * M * K;
|
||||
|
||||
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
|
||||
|
||||
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
|
||||
let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
|
||||
let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
|
||||
|
||||
let elements = B * M * K;
|
||||
// E.g. 2 f32 tensors + 1 u8 tensor
|
||||
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(
|
||||
black_box(&tensor),
|
||||
black_box(&on_true),
|
||||
black_box(&on_false),
|
||||
);
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let device = BenchDeviceHandler::new().unwrap();
|
||||
for d in device.devices {
|
||||
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
|
||||
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
|
||||
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -1,25 +1,25 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let device = Device::new_metal(0).unwrap();
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
let mut group = c.benchmark_group("matmul_metal");
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
@ -27,18 +27,16 @@ fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
for _i in 0..iters {
|
||||
run(black_box(&lhs), black_box(&rhs));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
if let Device::Metal(device) = &device {
|
||||
device.wait_until_completed().unwrap();
|
||||
} else {
|
||||
panic!("Expected metal device");
|
||||
}
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
@ -118,7 +118,7 @@ enum Command {
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file(s), in safetensors format.
|
||||
/// The input file, in gguf format.
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The output file, in gguf format.
|
||||
@ -133,15 +133,6 @@ enum Command {
|
||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||
mode: QuantizationMode,
|
||||
},
|
||||
|
||||
Dequantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
|
||||
/// The output file, in safetensors format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -286,23 +277,6 @@ fn run_quantize_safetensors(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_dequantize(
|
||||
in_file: std::path::PathBuf,
|
||||
out_file: std::path::PathBuf,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
let mut in_file = std::fs::File::open(in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
|
||||
let tensor = tensor.dequantize(device)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
@ -383,7 +357,6 @@ fn main() -> anyhow::Result<()> {
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ pub mod utils;
|
||||
mod variable;
|
||||
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use device::{Device, DeviceLocation, NdArray};
|
||||
pub use device::{Device, DeviceLocation};
|
||||
pub use dtype::{DType, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
|
@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
use std::sync::{Arc, RwLock, TryLockError};
|
||||
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
@ -85,8 +84,13 @@ pub struct MetalDevice {
|
||||
command_buffer_index: Arc<RwLock<usize>>,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
/// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
|
||||
/// execution order to be linear.
|
||||
/// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
|
||||
/// compute graph.
|
||||
// fence: metal::Fence,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
/// Heavily used by [`candle_metal_kernels`], both fences need to match
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
@ -102,8 +106,6 @@ pub struct MetalDevice {
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
||||
/// (strong_count = 1).
|
||||
buffers: AllocatedBuffers,
|
||||
/// Seed for random number generation.
|
||||
seed: Arc<Mutex<Buffer>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -129,6 +131,10 @@ impl MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
// pub(crate) fn fence(&self) -> &metal::Fence {
|
||||
// &self.fence
|
||||
// }
|
||||
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
@ -219,8 +225,10 @@ impl MetalDevice {
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
// blit.wait_for_fence(&self.fence);
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
// blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
|
||||
// This is necessary, for mmaped safetensors
|
||||
@ -228,7 +236,7 @@ impl MetalDevice {
|
||||
// The slice might not live long enough for metal
|
||||
// To actually fill the GPU buffer.
|
||||
// Putting this wait forces the GPU buffer to be filled
|
||||
// with the actual data allowing the CPU storage to do
|
||||
// with the actual data allowing the CPU storage todo
|
||||
// deallocate properly.
|
||||
self.wait_until_completed()?;
|
||||
Ok(real)
|
||||
@ -243,6 +251,7 @@ impl MetalDevice {
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
// blit.wait_for_fence(&self.fence);
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
@ -251,6 +260,7 @@ impl MetalDevice {
|
||||
},
|
||||
0,
|
||||
);
|
||||
// blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
Ok(buffer)
|
||||
}
|
||||
@ -349,7 +359,6 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32",
|
||||
DType::F16 => "affine_f16",
|
||||
DType::BF16 => "affine_bf16",
|
||||
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
@ -368,7 +377,6 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32_strided",
|
||||
DType::F16 => "affine_f16_strided",
|
||||
DType::BF16 => "affine_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
@ -489,7 +497,6 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let device = self.device.clone();
|
||||
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
@ -503,69 +510,13 @@ impl BackendStorage for MetalStorage {
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
if layout.is_contiguous() {
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
|
||||
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
|
||||
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
|
||||
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
|
||||
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
|
||||
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
|
||||
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
|
||||
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
|
||||
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
|
||||
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
|
||||
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
|
||||
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
|
||||
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
|
||||
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
|
||||
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
|
||||
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
|
||||
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
|
||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
|
||||
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
|
||||
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
|
||||
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
|
||||
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
|
||||
(k, dtype) => {
|
||||
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.shape().elem_count(),
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
return Ok(Self::new(buffer, device, self.dtype));
|
||||
}
|
||||
|
||||
for &dim_idx in sum_dims.iter() {
|
||||
dims.push(src_dims[dim_idx]);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||
@ -597,7 +548,7 @@ impl BackendStorage for MetalStorage {
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
|
||||
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
|
||||
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
@ -645,26 +596,14 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -742,7 +681,6 @@ impl BackendStorage for MetalStorage {
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
@ -759,7 +697,6 @@ impl BackendStorage for MetalStorage {
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -790,7 +727,6 @@ impl BackendStorage for MetalStorage {
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
@ -805,7 +741,6 @@ impl BackendStorage for MetalStorage {
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
@ -861,7 +796,6 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
(DType::U8, DType::U32) => "where_u8_u32",
|
||||
@ -1199,12 +1133,8 @@ impl BackendStorage for MetalStorage {
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(DType::U32, DType::BF16) => "is_u32_bf16",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -1394,7 +1324,6 @@ impl MetalStorage {
|
||||
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||
|
||||
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||
@ -1405,18 +1334,6 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
|
||||
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
|
||||
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
|
||||
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
|
||||
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
|
||||
|
||||
("add", DType::I64) => (contiguous::add::I64, self.dtype),
|
||||
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
|
||||
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
|
||||
@ -1427,7 +1344,6 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||
|
||||
("add", DType::U32) => (contiguous::add::U32, self.dtype),
|
||||
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
|
||||
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
|
||||
@ -1438,7 +1354,6 @@ impl MetalStorage {
|
||||
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
|
||||
|
||||
("add", DType::U8) => (contiguous::add::U8, self.dtype),
|
||||
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
|
||||
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
|
||||
@ -1449,7 +1364,6 @@ impl MetalStorage {
|
||||
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1483,7 +1397,6 @@ impl MetalStorage {
|
||||
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||
|
||||
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||
@ -1496,20 +1409,6 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
|
||||
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
|
||||
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
|
||||
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
|
||||
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
|
||||
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
|
||||
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
|
||||
|
||||
("badd", DType::I64) => (strided::add::I64, self.dtype),
|
||||
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
|
||||
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
|
||||
@ -1522,7 +1421,6 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||
|
||||
("badd", DType::U32) => (strided::add::U32, self.dtype),
|
||||
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
|
||||
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
|
||||
@ -1535,7 +1433,6 @@ impl MetalStorage {
|
||||
("lt", DType::U32) => (strided::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (strided::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (strided::gt::U32, DType::U8),
|
||||
|
||||
("badd", DType::U8) => (strided::add::U8, self.dtype),
|
||||
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
|
||||
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
|
||||
@ -1548,7 +1445,6 @@ impl MetalStorage {
|
||||
("lt", DType::U8) => (strided::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (strided::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (strided::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1590,7 +1486,9 @@ impl MetalStorage {
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
// blit.wait_for_fence(&self.device.fence);
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
// blit.update_fence(&self.device.fence);
|
||||
blit.end_encoding();
|
||||
}
|
||||
self.device.wait_until_completed()?;
|
||||
@ -1608,29 +1506,29 @@ impl BackendDevice for MetalDevice {
|
||||
command_buffer.enqueue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||
// let fence = device.new_fence();
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 10,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||
[299792458].as_ptr() as *const c_void,
|
||||
4,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
)));
|
||||
Ok(Self {
|
||||
device,
|
||||
// fence,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
command_buffer_index,
|
||||
compute_per_buffer,
|
||||
buffers,
|
||||
kernels,
|
||||
seed,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||
crate::bail!("Metal set_seed not implemented")
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Metal {
|
||||
gpu_id: self.registry_id() as usize,
|
||||
@ -1670,31 +1568,12 @@ impl BackendDevice for MetalDevice {
|
||||
&self,
|
||||
shape: &Shape,
|
||||
dtype: DType,
|
||||
min: f64,
|
||||
max: f64,
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_uniform_f32",
|
||||
DType::F16 => "rand_uniform_f16",
|
||||
DType::BF16 => "rand_uniform_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_uniform(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn rand_normal(
|
||||
@ -1704,43 +1583,9 @@ impl BackendDevice for MetalDevice {
|
||||
mean: f64,
|
||||
stddev: f64,
|
||||
) -> Result<Self::Storage> {
|
||||
let name = match dtype {
|
||||
DType::F32 => "rand_normal_f32",
|
||||
DType::F16 => "rand_normal_f16",
|
||||
DType::BF16 => "rand_normal_bf16",
|
||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_random_normal(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
let seed: u32 = seed.try_into().map_err(|_| {
|
||||
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
|
||||
})?;
|
||||
|
||||
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
|
||||
let contents = seed_buffer.contents();
|
||||
unsafe {
|
||||
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
|
||||
}
|
||||
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
|
||||
|
||||
Ok(())
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -703,7 +703,6 @@ impl PthTensors {
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
use std::io::Read;
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
@ -713,21 +712,14 @@ impl PthTensors {
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case.
|
||||
if !tensor_info.layout.is_contiguous() {
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let start_offset = tensor_info.layout.start_offset();
|
||||
if start_offset > 0 {
|
||||
std::io::copy(
|
||||
&mut reader.by_ref().take(start_offset as u64),
|
||||
&mut std::io::sink(),
|
||||
)?;
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
|
@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
||||
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||
ys_index += 1;
|
||||
}
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
||||
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
|
@ -32,7 +32,9 @@ impl QMetalStorage {
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
// blit.wait_for_fence(&self.device.fence());
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
// blit.update_fence(&self.device.fence());
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
|
@ -12,14 +12,6 @@ use core::arch::arm::*;
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
|
||||
// TODO: dotprod
|
||||
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
|
||||
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
||||
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
@ -51,8 +43,15 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
let pl0 = vdotq_s32(v0_0ls, v1_0l);
|
||||
let ph0 = vdotq_s32(v0_0hs, v1_0h);
|
||||
// TODO: Support dotprod when it's available outside of nightly.
|
||||
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||
@ -83,8 +82,14 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
let p0 = vdotq_s32(x0_0, y0_0);
|
||||
let p1 = vdotq_s32(x0_1, y0_1);
|
||||
// TODO dotprod once this is the intrinsics are.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
@ -113,7 +118,10 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy = vdotq_s32(xs, ys);
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
@ -183,16 +191,30 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
@ -212,16 +234,29 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
// TODO: dotprod case.
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
@ -298,14 +333,28 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||
|
||||
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
|
||||
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
|
||||
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
|
||||
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
}
|
||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||
@ -368,15 +417,22 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4bits = vld1q_u8_x2(q4);
|
||||
q4 = q4.add(32);
|
||||
// TODO: dotprod
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||
);
|
||||
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
@ -384,9 +440,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||
);
|
||||
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2) as f32;
|
||||
}
|
||||
@ -464,14 +526,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||
@ -496,14 +571,27 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
if j == 0 {
|
||||
@ -561,6 +649,7 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
|
||||
let mut is = 0usize;
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let q2bits = vld1q_u8_x2(q2);
|
||||
q2 = q2.add(32);
|
||||
@ -607,7 +696,14 @@ unsafe fn multiply_accum_with_scale(
|
||||
q2bytes: int8x16x2_t,
|
||||
q8bytes: int8x16x2_t,
|
||||
) -> i32 {
|
||||
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
|
||||
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
|
||||
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||
}
|
||||
|
@ -426,7 +426,9 @@ impl Tensor {
|
||||
if buffer_size != shape.elem_count() {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
// println!("from vec {buffer_size}");
|
||||
let storage = device.storage_owned(data)?;
|
||||
// println!("Created storage");
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
@ -2578,21 +2580,11 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Returns log(sum(exp(tensor), dim)).
|
||||
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||
let exp = self.exp()?;
|
||||
let sum = exp.sum(sum_dims)?;
|
||||
sum.log()
|
||||
}
|
||||
|
||||
/// Pointwise pow operation.
|
||||
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.mul(&self.log()?)?.exp()
|
||||
}
|
||||
|
||||
/// Broadcasting version of `pow`.
|
||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
@ -1,5 +1,4 @@
|
||||
use candle_core::{
|
||||
bail,
|
||||
quantized::{self, GgmlDType},
|
||||
test_device,
|
||||
test_utils::to_vec2_round,
|
||||
@ -47,10 +46,6 @@ fn test_matmul(
|
||||
}
|
||||
|
||||
fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
|
||||
@ -105,10 +100,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantized_matmul_neg(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let (m, k, n) = (3, 64, 4);
|
||||
let lhs = (0..(m * k))
|
||||
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||
@ -178,10 +169,6 @@ test_device!(
|
||||
);
|
||||
|
||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
@ -209,10 +196,6 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
@ -239,10 +222,6 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
@ -269,10 +248,6 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
@ -334,8 +309,7 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a vector similar to the ones used in GGML unit tests:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
(0..GGML_TEST_SIZE)
|
||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||
@ -354,8 +328,7 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
sum / a.len() as f32
|
||||
}
|
||||
|
||||
/// Similar to the GGML quantization unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||
@ -363,7 +336,7 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
||||
let dst = quant.dequantize(device)?;
|
||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||
if error > max_error {
|
||||
bail!(
|
||||
candle_core::bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
error,
|
||||
max_error
|
||||
@ -373,10 +346,6 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
||||
}
|
||||
|
||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q2K;
|
||||
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
@ -411,10 +380,6 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q3K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -448,10 +413,6 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q4K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -485,10 +446,6 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q5K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -506,7 +463,7 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
|
||||
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
|
||||
);
|
||||
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
@ -522,10 +479,6 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q6K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -559,10 +512,6 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q8K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -671,66 +620,54 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
GgmlDType::Q5K => 0.000740,
|
||||
GgmlDType::Q6K => 0.000952,
|
||||
GgmlDType::Q4_0 => 0.001143,
|
||||
GgmlDType::Q4_1 => 0.008,
|
||||
GgmlDType::Q4_1 => 0.007784,
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.00149,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
|
||||
// Not from the ggml repo.
|
||||
GgmlDType::Q8K => 0.00065,
|
||||
_ => bail!("No GGML results for quantization type {dtype:?}",),
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
/// Similar to the GGML matmul unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
let a = create_ggml_like_vector(0.0);
|
||||
let b = create_ggml_like_vector(1.0);
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
|
||||
// Another example that is more likely to trigger the overflow reported in #1526
|
||||
let a = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let b = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
|
||||
let length = a.len();
|
||||
|
||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||
T::from_float(a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(b, &mut b_quant)?;
|
||||
T::from_float(&a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(a, b);
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
|
||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||
bail!(
|
||||
candle_core::bail!(
|
||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||
)
|
||||
}
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
|
||||
candle_core::bail!(
|
||||
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||
);
|
||||
}
|
||||
|
||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||
// => we use a slightly higher error threshold
|
||||
const ERROR_LENIENCY: f32 = 0.00001;
|
||||
if error - ERROR_LENIENCY > ggml_error {
|
||||
bail!(
|
||||
candle_core::bail!(
|
||||
"Dot product error {} exceeds ggml reference error {}",
|
||||
error,
|
||||
ggml_error
|
||||
@ -739,16 +676,6 @@ fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Res
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quantized_mm() -> Result<()> {
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
|
||||
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||
fn get_random_tensors(
|
||||
m: usize,
|
||||
@ -778,10 +705,6 @@ macro_rules! quantized_matmul {
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||
fn $fn_name(device: &Device) -> Result<()> {
|
||||
if device.is_cuda() {
|
||||
// TODO Enable Cuda GGML sometime maybe.
|
||||
return Ok(());
|
||||
}
|
||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1245,23 +1245,11 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_sum_exp() -> Result<()> {
|
||||
fn logsumexp() -> Result<()> {
|
||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let output = input.log_sum_exp(D::Minus1)?;
|
||||
let output = input.logsumexp(D::Minus1)?;
|
||||
// The expectations obtained from pytorch.
|
||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||
assert_close(&output, &expected, 0.00001)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow() -> Result<()> {
|
||||
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let rhs = (&lhs - 2.)?;
|
||||
let res = lhs.pow(&rhs)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res, 4)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -11,12 +11,12 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
|
||||
|
||||
csv = "1.3.0"
|
||||
cudarc = { workspace = true, optional = true }
|
||||
@ -49,12 +49,11 @@ tokio = "1.29.1"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
|
@ -4,28 +4,251 @@ use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct KernelDirectories {
|
||||
kernel_glob: &'static str,
|
||||
kernel_dir: &'static str,
|
||||
rust_target: &'static str,
|
||||
include_dirs: &'static [&'static str],
|
||||
}
|
||||
|
||||
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_glob: "examples/custom-ops/kernels/*.cu",
|
||||
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_dir: "examples/custom-ops/kernels/",
|
||||
rust_target: "examples/custom-ops/cuda_kernels.rs",
|
||||
include_dirs: &[],
|
||||
}];
|
||||
|
||||
impl KernelDirectories {
|
||||
fn maybe_build_ptx(
|
||||
&self,
|
||||
cu_file: &std::path::Path,
|
||||
ptx_file: &std::path::Path,
|
||||
compute_cap: usize,
|
||||
) -> Result<()> {
|
||||
let should_compile = if ptx_file.exists() {
|
||||
let ptx_modified = ptx_file.metadata()?.modified()?;
|
||||
let cu_modified = cu_file.metadata()?.modified()?;
|
||||
cu_modified.duration_since(ptx_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
|
||||
let include_dirs: Vec<String> =
|
||||
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
|
||||
command
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", out_dir.to_str().unwrap()])
|
||||
.arg(format!("-I/{}", self.kernel_dir))
|
||||
.args(include_dirs)
|
||||
.arg(cu_file);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open(ptx_file)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
|
||||
println!("cargo:rerun-if-changed={}", self.kernel_dir);
|
||||
let kernel_dir = PathBuf::from(self.kernel_dir);
|
||||
let out_dir = out_dir.join(self.kernel_dir);
|
||||
if !out_dir.exists() {
|
||||
std::fs::create_dir_all(&out_dir)?;
|
||||
}
|
||||
let mut cu_files = vec![];
|
||||
let mut cuh_files = vec![];
|
||||
for file in std::fs::read_dir(kernel_dir)?.flatten() {
|
||||
let file = file.path();
|
||||
match file.extension().and_then(|v| v.to_str()) {
|
||||
Some("cu") => cu_files.push(file),
|
||||
Some("cuh") => cuh_files.push(file),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut ptx_paths = vec![];
|
||||
for cu_file in cu_files.iter() {
|
||||
let file_stem = cu_file
|
||||
.file_stem()
|
||||
.with_context(|| format!("no stem {cu_file:?}"))?;
|
||||
let file_stem = file_stem.to_string_lossy().into_owned();
|
||||
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
|
||||
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
|
||||
ptx_paths.push(ptx_file);
|
||||
}
|
||||
|
||||
let regenerate_rs_file = true;
|
||||
if regenerate_rs_file {
|
||||
let mut file = std::fs::File::create(self.rust_target)?;
|
||||
for ptx_path in ptx_paths {
|
||||
let name = ptx_path
|
||||
.file_stem()
|
||||
.context("empty stem")?
|
||||
.to_string_lossy();
|
||||
file.write_all(b"#[rustfmt::skip]\n")?;
|
||||
let const_definition = format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
self.kernel_dir,
|
||||
);
|
||||
file.write_all(const_definition.as_bytes())?;
|
||||
file.write_all(b"\n")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
|
||||
let out_dir = PathBuf::from(out_dir);
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
for kdir in KERNEL_DIRS.iter() {
|
||||
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write(kdir.rust_target).unwrap()
|
||||
}
|
||||
set_cuda_include_dir()?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let compute_cap = compute_cap()?;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let compute_cap = 0;
|
||||
for d in DIRS {
|
||||
d.process(&out_dir, compute_cap)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute cap from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse code")?
|
||||
} else {
|
||||
// Grab compute cap from nvidia-smi
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let max_nvcc_code = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
if !codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
|
||||
);
|
||||
}
|
||||
*codes.last().unwrap()
|
||||
};
|
||||
|
||||
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||
// then choose the highest gpu code in nvcc
|
||||
if compute_cap > max_nvcc_code {
|
||||
println!(
|
||||
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
|
||||
);
|
||||
compute_cap = max_nvcc_code;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
compute_cap = compute_cap_str
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
|
||||
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||
}
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -0,0 +1,2 @@
|
||||
#[rustfmt::skip]
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
|
||||
|
@ -6,8 +6,7 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(unused)]
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
@ -165,14 +165,14 @@ fn main() -> Result<()> {
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
for index in 0..args.sample_len {
|
||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||
(1, index_pos)
|
||||
let context_size = if cache.use_kv_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, context_index)?;
|
||||
let logits = llama.forward(&input, index_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
|
@ -1,22 +0,0 @@
|
||||
# candle-mobileone
|
||||
|
||||
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
|
||||
|
||||
This candle implementation uses a pre-trained MobileOne network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 79.33%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
|
||||
crash helmet : 2.58%
|
||||
unicycle, monocycle : 1.70%
|
||||
alp : 0.21%
|
||||
|
||||
```
|
@ -1,96 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::mobileone;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
S0,
|
||||
S1,
|
||||
S2,
|
||||
S3,
|
||||
S4,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::S0 => "s0",
|
||||
Self::S1 => "s1",
|
||||
Self::S2 => "s2",
|
||||
Self::S3 => "s3",
|
||||
Self::S4 => "s4",
|
||||
};
|
||||
format!("timm/mobileone_{}.apple_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> mobileone::Config {
|
||||
match self {
|
||||
Self::S0 => mobileone::Config::s0(),
|
||||
Self::S1 => mobileone::Config::s1(),
|
||||
Self::S2 => mobileone::Config::s2(),
|
||||
Self::S3 => mobileone::Config::s3(),
|
||||
Self::S4 => mobileone::Config::s4(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::S0)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -8,7 +8,6 @@ use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
@ -19,7 +18,6 @@ use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
MixFormer(MixFormer),
|
||||
Phi(Phi),
|
||||
Quantized(QMixFormer),
|
||||
}
|
||||
|
||||
@ -86,7 +84,6 @@ impl TextGeneration {
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::MixFormer(m) => m.forward(&input)?,
|
||||
Model::Phi(m) => m.forward(&input)?,
|
||||
Model::Quantized(m) => m.forward(&input)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
@ -120,7 +117,7 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum WhichModel {
|
||||
#[value(name = "1")]
|
||||
V1,
|
||||
@ -128,8 +125,6 @@ enum WhichModel {
|
||||
V1_5,
|
||||
#[value(name = "2")]
|
||||
V2,
|
||||
#[value(name = "2-old")]
|
||||
V2Old,
|
||||
PuffinPhiV2,
|
||||
PhiHermes,
|
||||
}
|
||||
@ -174,7 +169,7 @@ struct Args {
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "2")]
|
||||
#[arg(long, default_value = "1.5")]
|
||||
model: WhichModel,
|
||||
|
||||
#[arg(long)]
|
||||
@ -235,7 +230,7 @@ fn main() -> Result<()> {
|
||||
match args.model {
|
||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
||||
WhichModel::V2 => "microsoft/phi-2".to_string(),
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
}
|
||||
@ -250,9 +245,8 @@ fn main() -> Result<()> {
|
||||
"main".to_string()
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 => "refs/pr/8".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
||||
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"main".to_string()
|
||||
}
|
||||
@ -264,9 +258,7 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
|
||||
repo.get("tokenizer.json")?
|
||||
}
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||
}
|
||||
@ -279,14 +271,14 @@ fn main() -> Result<()> {
|
||||
match args.model {
|
||||
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
||||
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
||||
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
|
||||
WhichModel::V2 => candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
)?,
|
||||
@ -300,44 +292,33 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = || match args.model {
|
||||
let config = match args.model {
|
||||
WhichModel::V1 => Config::v1(),
|
||||
WhichModel::V1_5 => Config::v1_5(),
|
||||
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
||||
WhichModel::V2 => Config::v2(),
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = if args.quantized {
|
||||
let config = config();
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&filenames[0],
|
||||
&device,
|
||||
)?;
|
||||
println!("Loaded vb");
|
||||
let model = match args.model {
|
||||
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
|
||||
WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
|
||||
_ => QMixFormer::new(&config, vb)?,
|
||||
};
|
||||
println!("Loaded model");
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: PhiConfig = serde_json::from_str(&config)?;
|
||||
let phi = Phi::new(&config, vb)?;
|
||||
Model::Phi(phi)
|
||||
}
|
||||
WhichModel::V2Old => {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
||||
}
|
||||
WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new(&config, vb)?)
|
||||
}
|
||||
}
|
||||
let model = match args.model {
|
||||
WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
|
||||
_ => MixFormer::new(&config, vb)?,
|
||||
};
|
||||
Model::MixFormer(model)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
@ -417,10 +398,6 @@ fn mmlu<P: AsRef<std::path::Path>>(
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Phi(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Quantized(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
|
@ -236,15 +236,18 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let device = Device::Cpu;
|
||||
let config = Config::replit_code_v1_5_3b();
|
||||
let model = if args.quantized {
|
||||
let (model, device) = if args.quantized {
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
|
||||
Model::Q(Q::new(&config, vb.pp("transformer"))?)
|
||||
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
|
||||
(model, Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
Model::M(M::new(&config, vb.pp("transformer"))?)
|
||||
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
|
||||
(model, device)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
@ -1,22 +0,0 @@
|
||||
# candle-repvgg
|
||||
|
||||
[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697).
|
||||
|
||||
This candle implementation uses a pre-trained RepVGG network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 61.70%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
|
||||
unicycle, monocycle : 4.88%
|
||||
crash helmet : 0.15%
|
||||
moped : 0.04%
|
||||
|
||||
```
|
@ -1,111 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::repvgg;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
A0,
|
||||
A1,
|
||||
A2,
|
||||
B0,
|
||||
B1,
|
||||
B2,
|
||||
B3,
|
||||
B1G4,
|
||||
B2G4,
|
||||
B3G4,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::A0 => "a0",
|
||||
Self::A1 => "a1",
|
||||
Self::A2 => "a2",
|
||||
Self::B0 => "b0",
|
||||
Self::B1 => "b1",
|
||||
Self::B2 => "b2",
|
||||
Self::B3 => "b3",
|
||||
Self::B1G4 => "b1g4",
|
||||
Self::B2G4 => "b2g4",
|
||||
Self::B3G4 => "b3g4",
|
||||
};
|
||||
format!("timm/repvgg_{}.rvgg_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> repvgg::Config {
|
||||
match self {
|
||||
Self::A0 => repvgg::Config::a0(),
|
||||
Self::A1 => repvgg::Config::a1(),
|
||||
Self::A2 => repvgg::Config::a2(),
|
||||
Self::B0 => repvgg::Config::b0(),
|
||||
Self::B1 => repvgg::Config::b1(),
|
||||
Self::B2 => repvgg::Config::b2(),
|
||||
Self::B3 => repvgg::Config::b3(),
|
||||
Self::B1G4 => repvgg::Config::b1g4(),
|
||||
Self::B2G4 => repvgg::Config::b2g4(),
|
||||
Self::B3G4 => repvgg::Config::b3g4(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::A0)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
bindgen_cuda = "0.1.1"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
|
||||
num_cpus = "1.15.0"
|
||||
rayon = "1.7.0"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
|
||||
|
@ -2,32 +2,44 @@
|
||||
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||
use anyhow::{Context, Result};
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const KERNEL_FILES: [&str; 17] = [
|
||||
"kernels/flash_api.cu",
|
||||
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
|
||||
"flash_api.cu",
|
||||
"flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"flash_fwd_hdim96_bf16_sm80.cu",
|
||||
];
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|
||||
|_| num_cpus::get_physical(),
|
||||
|s| usize::from_str(&s).unwrap(),
|
||||
);
|
||||
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(num_cpus)
|
||||
.build_global()
|
||||
.unwrap();
|
||||
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
for kernel_file in KERNEL_FILES.iter() {
|
||||
println!("cargo:rerun-if-changed={kernel_file}");
|
||||
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
|
||||
}
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
||||
@ -54,30 +66,223 @@ fn main() -> Result<()> {
|
||||
))
|
||||
}
|
||||
};
|
||||
set_cuda_include_dir()?;
|
||||
|
||||
let kernels = KERNEL_FILES.iter().collect();
|
||||
let builder = bindgen_cuda::Builder::default()
|
||||
.kernel_paths(kernels)
|
||||
.out_dir(build_dir.clone())
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
|
||||
let compute_cap = compute_cap()?;
|
||||
|
||||
let out_file = build_dir.join("libflashattention.a");
|
||||
builder.build_lib(out_file);
|
||||
|
||||
let kernel_dir = PathBuf::from("kernels");
|
||||
let cu_files: Vec<_> = KERNEL_FILES
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let mut obj_file = out_dir.join(f);
|
||||
obj_file.set_extension("o");
|
||||
(kernel_dir.join(f), obj_file)
|
||||
})
|
||||
.collect();
|
||||
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
|
||||
let should_compile = if out_file.exists() {
|
||||
kernel_dir
|
||||
.read_dir()
|
||||
.expect("kernels folder should exist")
|
||||
.any(|entry| {
|
||||
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
|
||||
let in_modified = entry.metadata().unwrap().modified().unwrap();
|
||||
in_modified.duration_since(*out_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
cu_files
|
||||
.par_iter()
|
||||
.map(|(cu_file, obj_file)| {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("-c")
|
||||
.args(["-o", obj_file.to_str().unwrap()])
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(cu_file);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
&command,
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<()>>()?;
|
||||
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("--lib")
|
||||
.args(["-o", out_file.to_str().unwrap()])
|
||||
.args(obj_files);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
&command,
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||
println!("cargo:rustc-link-lib=flashattention");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||
|
||||
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
|
||||
finishing to run for some reason. Calling nvcc manually worked fine.
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
.include("cutlass/include")
|
||||
.flag("--expt-relaxed-constexpr")
|
||||
.flag("--default-stream")
|
||||
.flag("per-thread")
|
||||
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
|
||||
.compile("flashattn");
|
||||
*/
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse compute cap")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -1,62 +0,0 @@
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_causal, typename Engine, typename Layout>
|
||||
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
|
||||
const int col_idx_offset_,
|
||||
const int max_seqlen_k,
|
||||
const int row_idx_offset,
|
||||
const int max_seqlen_q,
|
||||
const int warp_row_stride,
|
||||
const float alibi_slope) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // Bias depends on both row_idx and col_idx
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace flash
|
@ -14,12 +14,9 @@ struct BlockInfo {
|
||||
template<typename Params>
|
||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
|
||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
|
||||
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
|
||||
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
|
||||
{
|
||||
}
|
||||
|
||||
@ -35,10 +32,8 @@ struct BlockInfo {
|
||||
|
||||
const int sum_s_q;
|
||||
const int sum_s_k;
|
||||
const int actual_seqlen_q;
|
||||
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
|
||||
const int seqlen_k_cache;
|
||||
const int actual_seqlen_k;
|
||||
const uint32_t actual_seqlen_q;
|
||||
const uint32_t actual_seqlen_k;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -7,6 +7,15 @@
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
// #ifdef OLD_GENERATOR_PATH
|
||||
// #include <ATen/CUDAGeneratorImpl.h>
|
||||
// #else
|
||||
// #include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
// #endif
|
||||
//
|
||||
// #include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
@ -44,7 +53,6 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
void * __restrict__ oaccum_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_batch_stride;
|
||||
@ -56,10 +64,9 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
void * __restrict__ softmax_lseaccum_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
||||
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
@ -69,30 +76,8 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
// If provided, the actual length of each k sequence.
|
||||
int * __restrict__ seqused_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The K_new and V_new matrices.
|
||||
void * __restrict__ knew_ptr;
|
||||
void * __restrict__ vnew_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t knew_batch_stride;
|
||||
index_t vnew_batch_stride;
|
||||
index_t knew_row_stride;
|
||||
index_t vnew_row_stride;
|
||||
index_t knew_head_stride;
|
||||
index_t vnew_head_stride;
|
||||
|
||||
// The cos and sin matrices for rotary embedding.
|
||||
void * __restrict__ rotary_cos_ptr;
|
||||
void * __restrict__ rotary_sin_ptr;
|
||||
|
||||
// The indices to index into the KV cache.
|
||||
int *__restrict__ cache_batch_idx;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
// uint32_t p_dropout_in_uint;
|
||||
@ -103,22 +88,11 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
float rp_dropout;
|
||||
float scale_softmax_rp_dropout;
|
||||
|
||||
// Local window size
|
||||
int window_size_left, window_size_right;
|
||||
// Random state.
|
||||
// at::PhiloxCudaState philox_args;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
bool is_seqlens_k_cumulative;
|
||||
|
||||
bool is_rotary_interleaved;
|
||||
|
||||
int num_splits; // For split-KV version
|
||||
|
||||
void * __restrict__ alibi_slopes_ptr;
|
||||
index_t alibi_slopes_batch_stride;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -158,14 +132,10 @@ struct Flash_bwd_params : public Flash_fwd_params {
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void *__restrict__ dsoftmax_sum;
|
||||
|
||||
bool deterministic;
|
||||
index_t dq_accum_split_stride;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
|
@ -1,15 +1,17 @@
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
// } else {
|
||||
// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
|
||||
// }
|
||||
});
|
||||
});
|
||||
// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
|
||||
// });
|
||||
// }
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" void run_mha(
|
||||
@ -18,7 +20,6 @@ extern "C" void run_mha(
|
||||
void *v_ptr,
|
||||
void *o_ptr,
|
||||
void *softmax_lse_ptr,
|
||||
void *alibi_slopes_ptr,
|
||||
|
||||
int32_t *cu_seqlens_q_ptr,
|
||||
int32_t *cu_seqlens_k_ptr,
|
||||
@ -27,7 +28,6 @@ extern "C" void run_mha(
|
||||
uint32_t k_batch_stride,
|
||||
uint32_t v_batch_stride,
|
||||
uint32_t o_batch_stride,
|
||||
uint32_t alibi_slopes_batch_stride,
|
||||
|
||||
uint32_t q_row_stride,
|
||||
uint32_t k_row_stride,
|
||||
@ -51,11 +51,8 @@ extern "C" void run_mha(
|
||||
uint32_t seqlen_q_rounded,
|
||||
uint32_t seqlen_k_rounded,
|
||||
|
||||
int is_bf16,
|
||||
int is_causal,
|
||||
|
||||
int window_size_left,
|
||||
int window_size_right
|
||||
int is_bf16
|
||||
) {
|
||||
Flash_fwd_params params;
|
||||
// Reset the parameters
|
||||
@ -68,14 +65,12 @@ extern "C" void run_mha(
|
||||
params.o_ptr = o_ptr;
|
||||
|
||||
params.softmax_lse_ptr = softmax_lse_ptr;
|
||||
params.alibi_slopes_ptr = alibi_slopes_ptr;
|
||||
|
||||
// All stride are in elements, not bytes.
|
||||
params.q_batch_stride = q_batch_stride;
|
||||
params.k_batch_stride = k_batch_stride;
|
||||
params.v_batch_stride = v_batch_stride;
|
||||
params.o_batch_stride = o_batch_stride;
|
||||
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
|
||||
|
||||
params.q_row_stride = q_row_stride;
|
||||
params.k_row_stride = k_row_stride;
|
||||
@ -97,6 +92,7 @@ extern "C" void run_mha(
|
||||
params.seqlen_k_rounded = seqlen_k_rounded;
|
||||
params.d = d;
|
||||
params.d_rounded = d_rounded;
|
||||
params.is_causal = is_causal;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
@ -110,14 +106,6 @@ extern "C" void run_mha(
|
||||
params.cu_seqlens_q = cu_seqlens_q_ptr;
|
||||
params.cu_seqlens_k = cu_seqlens_k_ptr;
|
||||
params.p_ptr = nullptr; // used for `return_softmax`.
|
||||
params.seqused_k = nullptr;
|
||||
|
||||
params.is_causal = is_causal;
|
||||
params.window_size_left = window_size_left;
|
||||
params.window_size_right = window_size_right;
|
||||
|
||||
params.is_seqlens_k_cumulative = true;
|
||||
params.num_splits = 1;
|
||||
|
||||
cudaStream_t stream = 0; // Use the default stream.
|
||||
run_mha_fwd(params, stream);
|
||||
|
@ -1,10 +1,19 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
|
||||
// }
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,32 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
|
||||
// // 1st ones are good for H100, A100
|
||||
// // 2nd one is good for A6000 bc we get slightly better occupancy
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
|
||||
// // 1st one is good for H100, A100, A6000
|
||||
// }
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,17 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,27 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
|
||||
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
|
||||
// // For A100, H100, 1st is fastest.
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,16 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
@ -1,10 +1,27 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// // This one is slightly faster for causal?
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
|
||||
// });
|
||||
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
|
||||
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,23 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// // For dropout there might be a lot of register spilling?
|
||||
// // These two are very slow due to register spilling
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
|
||||
// // This one is slightly slower
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,19 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
|
||||
// }
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,26 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
// // Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||
// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
|
||||
// }
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,17 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
}
|
@ -1,10 +1,23 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
|
||||
// // This 3rd one is good for H100, and A100, A6000
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
|
||||
// // These two are always slower
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
|
||||
}
|
||||
}
|
@ -4,18 +4,20 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
#include "block_info.h"
|
||||
#include "kernel_traits.h"
|
||||
#include "utils.h"
|
||||
#include "softmax.h"
|
||||
|
||||
#include "alibi.h"
|
||||
#include "philox.cuh"
|
||||
|
||||
namespace flash {
|
||||
|
||||
@ -23,6 +25,49 @@ using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_M,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||
Stride<_1, Int<MMAStride_M>> >{},
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_M,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||
Stride<_1, Int<MMAStride_M>> >{},
|
||||
// TODO: Shouldn't this be size<1>?
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
||||
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
||||
Tensor2 &acc_o, float softmax_scale_log2) {
|
||||
@ -32,7 +77,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
|
||||
flash::reduce_sum(scores, scores_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(scores_max);
|
||||
cute::copy(scores_max, scores_max_prev);
|
||||
copy(scores_max, scores_max_prev);
|
||||
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
@ -58,22 +103,23 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
|
||||
inline __device__ void write_softmax_to_gmem(
|
||||
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
|
||||
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
|
||||
) {
|
||||
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
|
||||
Layout l = tOrP.layout();
|
||||
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
|
||||
// TODO(laurent): reactivate the following
|
||||
// CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
|
||||
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
|
||||
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
@ -92,65 +138,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
|
||||
|
||||
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
|
||||
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
||||
if (Is_causal || Is_local) {
|
||||
n_block_max = std::min(n_block_max,
|
||||
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
|
||||
if (Is_causal) {
|
||||
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
||||
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
||||
// }
|
||||
}
|
||||
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
|
||||
// Otherwise we might read OOB elements from gK and gV.
|
||||
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
|
||||
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
|
||||
// exit early and no one saves the rng state.
|
||||
// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
||||
// auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
// params.rng_state[0] = std::get<0>(seeds);
|
||||
// params.rng_state[1] = std::get<1>(seeds);
|
||||
// params.rng_state[0] = 0;
|
||||
// params.rng_state[1] = 0;
|
||||
// }
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
||||
clear(tOrO);
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tOgO); ++m) {
|
||||
const int row = get<0>(tOcO(0, m, 0));
|
||||
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
|
||||
}
|
||||
return;
|
||||
}
|
||||
// if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
|
||||
|
||||
// We iterate over the blocks in reverse order. This is because the last block is the only one
|
||||
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
|
||||
@ -188,10 +185,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
|
||||
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
|
||||
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
|
||||
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
@ -213,18 +208,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
|
||||
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
|
||||
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
|
||||
// TODO: this might need to change if we change the mma instruction in SM70
|
||||
@ -275,8 +268,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
// // Copy rmem to smem
|
||||
@ -292,14 +285,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
cute::cp_async_fence();
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
||||
// __syncthreads();
|
||||
@ -309,7 +302,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
}
|
||||
|
||||
// auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
@ -320,19 +313,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
|
||||
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
|
||||
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
||||
constexpr int n_masking_steps = (!Is_causal && !Is_local)
|
||||
? 1
|
||||
: ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
|
||||
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
|
||||
#pragma unroll
|
||||
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
@ -343,42 +330,28 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
// Advance gV
|
||||
if (masking_step > 0) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
|
||||
smem_thr_copy_Q, smem_thr_copy_K
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
// if (cute::thread0()) { print(acc_s); }
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread0()) { print_tensor(scores); }
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
}
|
||||
|
||||
if (!Is_causal && !Is_local) {
|
||||
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
if (!Is_causal) {
|
||||
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
|
||||
@ -391,24 +364,20 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
// Idk why it's get<1> and not get<0> of the stride.
|
||||
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
|
||||
// I can't get the stride from idx_row
|
||||
flash::apply_mask_local</*HasWSLeft=*/Is_local>(
|
||||
scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q, kNWarps * 16,
|
||||
params.window_size_left, params.window_size_right
|
||||
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
|
||||
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
|
||||
);
|
||||
// if (cute::thread0()) { print_tensor(scores); }
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
|
||||
}
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > n_block_min) {
|
||||
if (n_block > 0) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
@ -416,24 +385,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||
masking_step == 0
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
int block_col_idx = n_block * (kBlockN / 32);
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
copy(tOrP, tOrP_copy);
|
||||
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
|
||||
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
@ -442,38 +411,37 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
}
|
||||
// if (cute::thread0()) { print(tOrP); }
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
if (n_masking_steps > 1 && n_block <= n_block_min) {
|
||||
if (n_masking_steps > 1 && n_block <= 0) {
|
||||
--n_block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These are the iterations where we don't need masking on S
|
||||
for (; n_block >= n_block_min; --n_block) {
|
||||
for (; n_block >= 0; --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
// Advance gV
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
|
||||
smem_thr_copy_Q, smem_thr_copy_K
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > n_block_min) {
|
||||
if (n_block > 0) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
@ -481,44 +449,22 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
}
|
||||
|
||||
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
|
||||
flash::apply_mask_local(
|
||||
scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q, kNWarps * 16,
|
||||
params.window_size_left, params.window_size_right
|
||||
);
|
||||
}
|
||||
|
||||
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
int block_col_idx = n_block * (kBlockN / 32);
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
copy(tOrP, tOrP_copy);
|
||||
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
|
||||
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
@ -526,7 +472,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
@ -550,15 +496,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
Tensor rO = flash::convert_type<Element>(acc_o);
|
||||
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
|
||||
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// sO has the same size as sQ, so we don't need to sync here.
|
||||
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
|
||||
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
copy(smem_thr_copy_O, taccOrO, taccOsO);
|
||||
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
@ -569,15 +515,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
||||
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
||||
copy(gmem_thr_copy_O, tOsO, tOrO);
|
||||
|
||||
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
|
||||
@ -603,15 +548,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
@ -627,7 +571,7 @@ inline __device__ void compute_attn(const Params ¶ms) {
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
|
||||
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -4,14 +4,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// #include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "static_switch.h"
|
||||
#include "flash.h"
|
||||
#include "flash_fwd_kernel.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
@ -25,39 +26,35 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.b, params.h);
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_q as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
});
|
||||
});
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
// if (smem_size >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
// }
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
constexpr int Headdim = 32;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
@ -67,7 +64,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
constexpr int Headdim = 64;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
@ -89,7 +86,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
constexpr int Headdim = 96;
|
||||
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
@ -115,7 +112,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
constexpr int Headdim = 128;
|
||||
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
@ -152,7 +149,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
constexpr int Headdim = 160;
|
||||
// auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
@ -182,7 +179,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
constexpr int Headdim = 192;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
@ -201,7 +198,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 224;
|
||||
constexpr int Headdim = 224;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
@ -227,7 +224,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 256;
|
||||
constexpr int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_sm, max_smem_per_block;
|
||||
|
@ -91,20 +91,17 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposedNoSwizzle{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -113,8 +110,7 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
@ -142,11 +138,11 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
@ -155,30 +151,10 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
using GmemLayoutAtomOaccum = std::conditional_t<
|
||||
kBlockKSmem == 32,
|
||||
Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
|
||||
Stride< _8, _1>>,
|
||||
Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopyOaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
|
||||
using GmemTiledCopyRotcossin = decltype(
|
||||
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
|
||||
GmemLayoutAtomRotcossin{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
|
||||
using GmemTiledCopyRotcossinCont = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
GmemLayoutAtomRotcossin{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
|
||||
};
|
||||
|
||||
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
|
||||
@ -247,19 +223,16 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
SmemLayoutAtomKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomKtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutKtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// Maybe the KtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||
|
||||
// TODO: generalize to other values of kBlockN
|
||||
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||
@ -277,30 +250,24 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdS{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kPBlockN>>>;
|
||||
using SmemLayoutAtomPdStransposed = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kPBlockN>>>{}));
|
||||
using SmemLayoutPdStransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposed{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposedNoSwizzle{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomQdOtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomdKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -325,11 +292,13 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
|
||||
static constexpr int kSmemdPsumCount = kBlockM;
|
||||
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
|
||||
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
|
||||
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
|
||||
static constexpr int kSmemSize = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||
|
@ -1,159 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits_sm90 {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
@ -8,7 +8,8 @@
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
|
||||
#include "philox.cuh"
|
||||
#include "utils.h"
|
||||
@ -116,18 +117,15 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
@ -139,30 +137,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasWSLeft=true, typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride,
|
||||
const int window_size_left, const int window_size_right) {
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
||||
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
|
||||
const uint32_t warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||
const uint32_t row_idx_offset = row_idx_offset_;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
const uint32_t row_idx = row_idx_base + i * 8;
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
@ -176,19 +174,10 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride) {
|
||||
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
|
||||
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
|
||||
max_seqlen_q, warp_row_stride, -1, 0);
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
|
||||
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
@ -197,7 +186,7 @@ inline __device__ void apply_mask_causal_w_idx(
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
@ -215,8 +204,8 @@ inline __device__ void apply_mask_causal_w_idx(
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
|
||||
unsigned long long seed, unsigned long long offset,
|
||||
int block_row_start, int block_col_start,
|
||||
int block_row_stride) {
|
||||
uint32_t block_row_start, uint32_t block_col_start,
|
||||
uint32_t block_row_stride) {
|
||||
// tensor has shape (8, MMA_M, MMA_N / 2)
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
|
@ -87,6 +87,46 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ float2 half2_unpack(uint32_t a);
|
||||
|
||||
template <>
|
||||
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
|
||||
return __half22float2(reinterpret_cast<__half2 (&)>(a));
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template <>
|
||||
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
|
||||
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert two half2's or bf162's into float, then take their dot product.
|
||||
template <typename T>
|
||||
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
|
||||
float2 af = flash::half2_unpack<T>(a);
|
||||
float2 bf = flash::half2_unpack<T>(b);
|
||||
return af.x * bf.x + af.y * bf.y;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
|
||||
template<typename T>
|
||||
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
|
||||
float sum;
|
||||
sum = flash::hfma2_to_float<T>(a.x, b.x);
|
||||
sum += flash::hfma2_to_float<T>(a.y, b.y);
|
||||
sum += flash::hfma2_to_float<T>(a.z, b.z);
|
||||
sum += flash::hfma2_to_float<T>(a.w, b.w);
|
||||
return sum;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
@ -133,12 +173,10 @@ static __device__ inline T run(T x, Operator &op) {
|
||||
|
||||
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
|
||||
typename ThrCopyA, typename ThrCopyB>
|
||||
typename TiledMma, typename TiledCopy0, typename TiledCopy1>
|
||||
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
Tensor4 const& tCsB, TiledMma tiled_mma,
|
||||
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
|
||||
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
|
||||
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
@ -146,13 +184,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
|
||||
if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
|
||||
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
|
||||
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
|
||||
if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
|
||||
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
|
||||
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
@ -161,20 +199,19 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy, typename ThrCopy>
|
||||
typename TiledMma, typename TiledCopy>
|
||||
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
|
||||
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
|
||||
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
@ -188,10 +225,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
|
||||
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
|
||||
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -207,13 +241,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
|
||||
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
|
||||
// TD [2023-08-13]: Same error as above on Cutlass 3.2
|
||||
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
// get<0, 1>(l),
|
||||
// get<1, 1, 1>(l));
|
||||
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
|
||||
get<1>(get<0>(l)),
|
||||
get<1>(get<1>(get<1>(l))));
|
||||
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
get<0, 1>(l),
|
||||
get<1, 1, 1>(l));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -289,9 +319,9 @@ void cp_async_wait() {
|
||||
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
||||
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
|
||||
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
@ -305,13 +335,13 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || predicate_K(k)) {
|
||||
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
copy(thr_copy, S(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
} else if (Clear_OOB_MN) {
|
||||
cute::clear(D(_, m, _));
|
||||
clear(D(_, m, _));
|
||||
}
|
||||
}
|
||||
// TD [2023-04-13]: Strange that the code below can cause race condition.
|
||||
@ -320,7 +350,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// copy(tiled_copy, S(_, m, _), D(_, m, _));
|
||||
// copy(thr_copy, S(_, m, _), D(_, m, _));
|
||||
// } else if (Clear_OOB_MN) {
|
||||
// clear(D(_, m, _));
|
||||
// }
|
||||
@ -332,7 +362,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
// copy(thr_copy, S(_, m, k), D(_, m, k));
|
||||
// } else if (Clear_OOB_MN) {
|
||||
// clear(D(_, m, k));
|
||||
// }
|
||||
|
@ -7,8 +7,6 @@ extern "C" {
|
||||
v_ptr: *const c_void,
|
||||
o_ptr: *const c_void,
|
||||
softmax_lse_ptr: *const c_void,
|
||||
alibi_slopes_ptr: *const c_void,
|
||||
|
||||
cu_seqlens_q_ptr: *const i32,
|
||||
cu_seqlens_k_ptr: *const i32,
|
||||
|
||||
@ -16,7 +14,6 @@ extern "C" {
|
||||
k_batch_stride: u32,
|
||||
v_batch_stride: u32,
|
||||
o_batch_stride: u32,
|
||||
alibi_slopes_batch_stride: u32,
|
||||
|
||||
q_row_stride: u32,
|
||||
k_row_stride: u32,
|
||||
@ -40,11 +37,8 @@ extern "C" {
|
||||
seqlen_q_rounded: u32,
|
||||
seqlen_k_rounded: u32,
|
||||
|
||||
is_bf16: c_int,
|
||||
is_causal: c_int,
|
||||
|
||||
window_size_left: c_int,
|
||||
window_size_right: c_int,
|
||||
is_bf16: c_int,
|
||||
);
|
||||
|
||||
}
|
||||
|
@ -3,14 +3,12 @@ mod ffi;
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||
use candle::cuda_backend::WrapErr;
|
||||
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
|
||||
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
|
||||
use half::{bf16, f16};
|
||||
|
||||
pub struct FlashAttn {
|
||||
pub softmax_scale: f32,
|
||||
pub alibi_slopes: Option<Tensor>,
|
||||
pub window_size_left: Option<usize>,
|
||||
pub window_size_right: Option<usize>,
|
||||
pub causal: bool,
|
||||
}
|
||||
|
||||
fn round_multiple(x: usize, m: usize) -> usize {
|
||||
@ -87,51 +85,6 @@ impl FlashAttn {
|
||||
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
|
||||
}
|
||||
|
||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||
if alibi_slopes.dtype() != DType::F32 {
|
||||
candle::bail!(
|
||||
"DType mismatch alibi_slopes {:?}, expected {:?}",
|
||||
alibi_slopes.dtype(),
|
||||
DType::F32
|
||||
);
|
||||
}
|
||||
|
||||
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
|
||||
|
||||
if num_heads != alibi_slopes_layout.shape().dims1()? {
|
||||
candle::bail!(
|
||||
"shape mismatch alibi_slopes {:?}, expected {:?}",
|
||||
alibi_slopes_layout.shape(),
|
||||
(num_heads)
|
||||
);
|
||||
}
|
||||
|
||||
let alibi_slopes = match &*alibi_slopes {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
|
||||
};
|
||||
|
||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||
|
||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
||||
} else {
|
||||
std::ptr::null()
|
||||
};
|
||||
|
||||
// if window_size_left > self.max_seqlen_k or None => -1
|
||||
let mut window_size_left = self
|
||||
.window_size_left
|
||||
.filter(|v| v <= &seqlen_k)
|
||||
.map(|v| v as i32)
|
||||
.unwrap_or(-1);
|
||||
|
||||
// if window_size_right > self.max_seqlen_k or None => -1
|
||||
let mut window_size_right = self
|
||||
.window_size_right
|
||||
.filter(|v| v <= &seqlen_k)
|
||||
.map(|v| v as i32)
|
||||
.unwrap_or(-1);
|
||||
|
||||
let head_size = round_multiple(head_size_og, 8);
|
||||
let head_size_rounded = round_multiple(head_size, 32);
|
||||
let seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
@ -141,22 +94,9 @@ impl FlashAttn {
|
||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
|
||||
|
||||
let causal = if self.causal { 1 } else { 0 };
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
let is_causal = if window_size_left < 0 && window_size_right == 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if window_size_left < 0 && window_size_right >= 0 {
|
||||
window_size_left = seqlen_k as i32;
|
||||
}
|
||||
if window_size_left >= 0 && window_size_right < 0 {
|
||||
window_size_right = seqlen_k as i32;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
||||
@ -169,14 +109,12 @@ impl FlashAttn {
|
||||
v_ptr,
|
||||
dst_ptr,
|
||||
softmax_lse_ptr,
|
||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
||||
/* cu_seqlens_q_ptr */ std::ptr::null(),
|
||||
/* cu_seqlens_k_ptr */ std::ptr::null(),
|
||||
/* q_batch_stride */ q_stride[0] as u32,
|
||||
/* k_batch_stride */ k_stride[0] as u32,
|
||||
/* v_batch_stride */ v_stride[0] as u32,
|
||||
/* o_batch_stride */ o_stride[0] as u32,
|
||||
/* alibi_slopes_batch_stride */ 0,
|
||||
/* q_row_stride */ q_stride[q_rank - 3] as u32,
|
||||
/* k_row_stride */ k_stride[k_rank - 3] as u32,
|
||||
/* v_row_stride */ v_stride[v_rank - 3] as u32,
|
||||
@ -195,10 +133,8 @@ impl FlashAttn {
|
||||
/* seqlen_k */ seqlen_k as u32,
|
||||
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
|
||||
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
|
||||
/* is_causal */ causal,
|
||||
/* is_bf16 */ is_bf16,
|
||||
/* is_causal */ is_causal,
|
||||
/* window_size_left */ window_size_left,
|
||||
/* window_size_right */ window_size_right,
|
||||
)
|
||||
}
|
||||
|
||||
@ -261,137 +197,20 @@ pub fn flash_attn(
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
let window_size_left = None;
|
||||
let window_size_right = if causal { Some(0) } else { None };
|
||||
|
||||
let op = FlashAttn {
|
||||
softmax_scale,
|
||||
alibi_slopes: None,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
/// Flash-attention v2 layer.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
|
||||
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
|
||||
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
|
||||
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
|
||||
/// * `window_size_left` - Limit left attention to value tokens.
|
||||
/// * `window_size_right` - Limit right attention to value tokens.
|
||||
///
|
||||
/// # Causal mask
|
||||
///
|
||||
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
|
||||
/// of `Q @ K^T`
|
||||
///
|
||||
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
|
||||
pub fn flash_attn_windowed(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
window_size_left: Option<usize>,
|
||||
window_size_right: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let op = FlashAttn {
|
||||
softmax_scale,
|
||||
alibi_slopes: None,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
/// Flash-attention v2 layer.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
|
||||
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
|
||||
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
|
||||
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
|
||||
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
|
||||
///
|
||||
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
|
||||
pub fn flash_attn_alibi(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
alibi_slopes: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
let window_size_left = None;
|
||||
let window_size_right = if causal { Some(0) } else { None };
|
||||
|
||||
let op = FlashAttn {
|
||||
softmax_scale,
|
||||
alibi_slopes: Some(alibi_slopes.clone()),
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
/// Flash-attention v2 layer.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
|
||||
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
|
||||
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
|
||||
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
|
||||
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
|
||||
/// * `window_size_left` - Limit left attention to value tokens.
|
||||
/// * `window_size_right` - Limit right attention to value tokens.
|
||||
///
|
||||
/// # Causal mask
|
||||
///
|
||||
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
|
||||
/// of `Q @ K^T`
|
||||
///
|
||||
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
|
||||
pub fn flash_attn_alibi_windowed(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
alibi_slopes: &Tensor,
|
||||
softmax_scale: f32,
|
||||
window_size_left: Option<usize>,
|
||||
window_size_right: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let op = FlashAttn {
|
||||
softmax_scale,
|
||||
alibi_slopes: Some(alibi_slopes.clone()),
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
causal,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
struct FlashAttnVarLen {
|
||||
pub softmax_scale: f32,
|
||||
pub max_seqlen_q: usize,
|
||||
pub max_seqlen_k: usize,
|
||||
pub seqlens_q: Tensor,
|
||||
pub seqlens_k: Tensor,
|
||||
pub alibi_slopes: Option<Tensor>,
|
||||
pub window_size_left: Option<usize>,
|
||||
pub window_size_right: Option<usize>,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
max_seqlen_q: usize,
|
||||
max_seqlen_k: usize,
|
||||
seqlens_q: Tensor,
|
||||
seqlens_k: Tensor,
|
||||
}
|
||||
|
||||
impl FlashAttnVarLen {
|
||||
@ -492,54 +311,7 @@ impl FlashAttnVarLen {
|
||||
if nseqlens_k != nseqlens_q {
|
||||
candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
|
||||
}
|
||||
|
||||
let batch_size = nseqlens_q - 1;
|
||||
|
||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||
if alibi_slopes.dtype() != DType::F32 {
|
||||
candle::bail!(
|
||||
"DType mismatch alibi_slopes {:?}, expected {:?}",
|
||||
alibi_slopes.dtype(),
|
||||
DType::F32
|
||||
);
|
||||
}
|
||||
|
||||
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
|
||||
|
||||
if num_heads != alibi_slopes_layout.shape().dims1()? {
|
||||
candle::bail!(
|
||||
"shape mismatch alibi_slopes {:?}, expected {:?}",
|
||||
alibi_slopes_layout.shape(),
|
||||
(num_heads)
|
||||
);
|
||||
}
|
||||
|
||||
let alibi_slopes = match &*alibi_slopes {
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
|
||||
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
|
||||
};
|
||||
|
||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||
|
||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
||||
} else {
|
||||
std::ptr::null()
|
||||
};
|
||||
|
||||
// if window_size_left > self.max_seqlen_k or None => -1
|
||||
let mut window_size_left = self
|
||||
.window_size_left
|
||||
.filter(|v| v <= &self.max_seqlen_k)
|
||||
.map(|v| v as i32)
|
||||
.unwrap_or(-1);
|
||||
|
||||
// if window_size_right > self.max_seqlen_k or None => -1
|
||||
let mut window_size_right = self
|
||||
.window_size_right
|
||||
.filter(|v| v <= &self.max_seqlen_k)
|
||||
.map(|v| v as i32)
|
||||
.unwrap_or(-1);
|
||||
|
||||
let head_size = round_multiple(head_size_og, 8);
|
||||
let head_size_rounded = round_multiple(head_size, 32);
|
||||
let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
|
||||
@ -551,22 +323,9 @@ impl FlashAttnVarLen {
|
||||
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
|
||||
.w()?;
|
||||
|
||||
let causal = if self.causal { 1 } else { 0 };
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
let is_causal = if window_size_left < 0 && window_size_right == 0 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if window_size_left < 0 && window_size_right >= 0 {
|
||||
window_size_left = self.max_seqlen_k as i32;
|
||||
}
|
||||
if window_size_left >= 0 && window_size_right < 0 {
|
||||
window_size_right = self.max_seqlen_k as i32;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
||||
@ -581,14 +340,12 @@ impl FlashAttnVarLen {
|
||||
v_ptr,
|
||||
dst_ptr,
|
||||
softmax_lse_ptr,
|
||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
||||
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
|
||||
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
|
||||
/* q_batch_stride */ 0,
|
||||
/* k_batch_stride */ 0,
|
||||
/* v_batch_stride */ 0,
|
||||
/* o_batch_stride */ 0,
|
||||
/* alibi_slopes_batch_stride */ 0,
|
||||
/* q_row_stride */ q_stride[q_rank - 3] as u32,
|
||||
/* k_row_stride */ k_stride[k_rank - 3] as u32,
|
||||
/* v_row_stride */ v_stride[v_rank - 3] as u32,
|
||||
@ -607,10 +364,8 @@ impl FlashAttnVarLen {
|
||||
/* seqlen_k */ self.max_seqlen_k as u32,
|
||||
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
|
||||
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
|
||||
/* is_causal */ causal,
|
||||
/* is_bf16 */ is_bf16,
|
||||
/* is_causal */ is_causal,
|
||||
/* window_size_left */ window_size_left,
|
||||
/* window_size_right */ window_size_right,
|
||||
)
|
||||
}
|
||||
|
||||
@ -685,176 +440,13 @@ pub fn flash_attn_varlen(
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
let window_size_left = None;
|
||||
let window_size_right = if causal { Some(0) } else { None };
|
||||
|
||||
let op = FlashAttnVarLen {
|
||||
softmax_scale,
|
||||
causal,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
seqlens_q: seqlens_q.clone(),
|
||||
seqlens_k: seqlens_k.clone(),
|
||||
alibi_slopes: None,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Flash-attention v2 layer with variable-length batching.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
|
||||
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
|
||||
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
|
||||
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
|
||||
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
|
||||
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
|
||||
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
|
||||
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
|
||||
/// * `window_size_left` - Limit left attention to value tokens.
|
||||
/// * `window_size_right` - Limit right attention to value tokens.
|
||||
///
|
||||
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
|
||||
/// `seqlen_1 + seqlen_2`, etc.
|
||||
///
|
||||
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
|
||||
///
|
||||
/// # Causal mask
|
||||
///
|
||||
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
|
||||
/// of `Q @ K^T`
|
||||
pub fn flash_attn_varlen_windowed(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
seqlens_q: &Tensor,
|
||||
seqlens_k: &Tensor,
|
||||
max_seqlen_q: usize,
|
||||
max_seqlen_k: usize,
|
||||
softmax_scale: f32,
|
||||
window_size_left: Option<usize>,
|
||||
window_size_right: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let op = FlashAttnVarLen {
|
||||
softmax_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
seqlens_q: seqlens_q.clone(),
|
||||
seqlens_k: seqlens_k.clone(),
|
||||
alibi_slopes: None,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Flash-attention v2 layer with variable-length batching.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
|
||||
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
|
||||
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
|
||||
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
|
||||
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
|
||||
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
|
||||
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
|
||||
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
|
||||
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
|
||||
///
|
||||
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
|
||||
/// `seqlen_1 + seqlen_2`, etc.
|
||||
///
|
||||
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
|
||||
pub fn flash_attn_varlen_alibi(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
alibi_slopes: &Tensor,
|
||||
seqlens_q: &Tensor,
|
||||
seqlens_k: &Tensor,
|
||||
max_seqlen_q: usize,
|
||||
max_seqlen_k: usize,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
let window_size_left = None;
|
||||
let window_size_right = if causal { Some(0) } else { None };
|
||||
|
||||
let op = FlashAttnVarLen {
|
||||
softmax_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
seqlens_q: seqlens_q.clone(),
|
||||
seqlens_k: seqlens_k.clone(),
|
||||
alibi_slopes: Some(alibi_slopes.clone()),
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Flash-attention v2 layer with variable-length batching.
|
||||
///
|
||||
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
|
||||
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
|
||||
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
|
||||
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
|
||||
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
|
||||
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
|
||||
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
|
||||
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
|
||||
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
|
||||
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
|
||||
/// * `window_size_left` - Limit left attention to value tokens.
|
||||
/// * `window_size_right` - Limit right attention to value tokens.
|
||||
///
|
||||
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
|
||||
/// `seqlen_1 + seqlen_2`, etc.
|
||||
///
|
||||
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
|
||||
///
|
||||
/// # Causal mask
|
||||
///
|
||||
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
|
||||
/// of `Q @ K^T`
|
||||
pub fn flash_attn_varlen_alibi_windowed(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
alibi_slopes: &Tensor,
|
||||
seqlens_q: &Tensor,
|
||||
seqlens_k: &Tensor,
|
||||
max_seqlen_q: usize,
|
||||
max_seqlen_k: usize,
|
||||
softmax_scale: f32,
|
||||
window_size_left: Option<usize>,
|
||||
window_size_right: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let op = FlashAttnVarLen {
|
||||
softmax_scale,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
seqlens_q: seqlens_q.clone(),
|
||||
seqlens_k: seqlens_k.clone(),
|
||||
alibi_slopes: Some(alibi_slopes.clone()),
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
};
|
||||
q.apply_op3(k, v, op)
|
||||
}
|
||||
|
@ -12,4 +12,6 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
|
||||
[build-dependencies]
|
||||
bindgen_cuda = "0.1.1"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
glob = "0.3.1"
|
||||
rayon = "1.7.0"
|
||||
|
@ -1,8 +1,243 @@
|
||||
use std::io::Write;
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
let builder = bindgen_cuda::Builder::default();
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write("src/lib.rs").unwrap();
|
||||
cuda::set_include_dir();
|
||||
let (write, kernel_paths) = cuda::build_ptx();
|
||||
if write {
|
||||
let mut file = std::fs::File::create("src/lib.rs").unwrap();
|
||||
for kernel_path in kernel_paths {
|
||||
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
|
||||
file.write_all(
|
||||
format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
name
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.unwrap();
|
||||
file.write_all(&[b'\n']).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod cuda {
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
pub fn set_include_dir() {
|
||||
use std::path::PathBuf;
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
// We can't actually set a env!() value from another crate,
|
||||
// so we have to do that here.
|
||||
|
||||
// use PathBuf;
|
||||
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
|
||||
#[cfg(feature = "ci-check")]
|
||||
let root: PathBuf = "ci".into();
|
||||
|
||||
#[cfg(not(feature = "ci-check"))]
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
}
|
||||
|
||||
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
||||
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
|
||||
println!("cargo:rerun-if-changed=src/");
|
||||
// for path in &kernel_paths {
|
||||
// println!("cargo:rerun-if-changed={}", path.display());
|
||||
// }
|
||||
|
||||
for path in &mut include_directories {
|
||||
// println!("cargo:rerun-if-changed={}", path.display());
|
||||
let destination =
|
||||
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
|
||||
std::fs::copy(path.clone(), destination).unwrap();
|
||||
// remove the filename from the path so it's just the directory
|
||||
path.pop();
|
||||
}
|
||||
|
||||
include_directories.sort();
|
||||
include_directories.dedup();
|
||||
|
||||
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
|
||||
|
||||
#[allow(unused)]
|
||||
let include_options: Vec<String> = include_directories
|
||||
.into_iter()
|
||||
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let children = kernel_paths
|
||||
.par_iter()
|
||||
.flat_map(|p| {
|
||||
let mut output = p.clone();
|
||||
output.set_extension("ptx");
|
||||
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
|
||||
|
||||
let ignore = if output_filename.exists() {
|
||||
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
|
||||
let in_modified = p.metadata().unwrap().modified().unwrap();
|
||||
out_modified.duration_since(in_modified).is_ok()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if ignore {
|
||||
None
|
||||
} else {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", &out_dir])
|
||||
// Flash attention only
|
||||
// .arg("--expt-relaxed-constexpr")
|
||||
.args(&include_options);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(p);
|
||||
Some((p, command.spawn()
|
||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
|
||||
// some old ones
|
||||
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
|
||||
for (kernel_path, child) in children {
|
||||
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
(write, kernel_paths)
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse code")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(compute_cap)
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1 @@
|
||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
|
||||
|
@ -9,17 +9,12 @@ keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = [
|
||||
"num-traits",
|
||||
"use-intrinsics",
|
||||
"rand_distr",
|
||||
] }
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
rand = "0.8.5"
|
||||
|
@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define AFFINE(FN_NAME, T) \
|
||||
#define AFFINE(FN_NAME, TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const T *input, \
|
||||
device T *output, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = T(fma(float(input[id]), mul, add)); \
|
||||
output[id] = TYPENAME(float(input[id]) * mul + add); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
constant float &add, \
|
||||
device const T *input, \
|
||||
device T *output, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \
|
||||
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
|
||||
}
|
||||
|
||||
#define POWF(FN_NAME, TYPENAME) \
|
||||
@ -117,7 +117,7 @@ ELU(elu_f32, float)
|
||||
ELU(elu_f16, half)
|
||||
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
AFFINE(affine_bf16, bfloat);
|
||||
POWF(powf_bf16, bfloat);
|
||||
ELU(elu_bf16, bfloat);
|
||||
|
@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
|
||||
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
|
||||
|
||||
#define INT64_BINARY_OP_OUT(NAME, FN) \
|
||||
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
|
||||
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
|
||||
INT64_BINARY_OP_OUT(gt, x > y)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
|
@ -28,7 +28,7 @@ kernel void FN_NAME( \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
|
||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -42,34 +42,7 @@ kernel void FN_NAME_STRIDED( \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
|
||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
@ -85,14 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
||||
#endif
|
||||
|
@ -173,10 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
||||
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||
|
@ -12,9 +12,8 @@ const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
|
||||
@ -63,12 +62,10 @@ macro_rules! primitive {
|
||||
}
|
||||
};
|
||||
}
|
||||
primitive!(bool);
|
||||
primitive!(usize);
|
||||
primitive!(i32);
|
||||
primitive!(i64);
|
||||
primitive!(i32);
|
||||
primitive!(u32);
|
||||
primitive!(u64);
|
||||
primitive!(f32);
|
||||
|
||||
impl<T> EncoderParam for &[T] {
|
||||
@ -123,7 +120,6 @@ pub enum Source {
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
Random,
|
||||
Quantized,
|
||||
}
|
||||
|
||||
@ -182,8 +178,8 @@ macro_rules! ops{
|
||||
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
|
||||
recip
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -223,6 +219,7 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
|
||||
pub struct Kernels {
|
||||
libraries: RwLock<Libraries>,
|
||||
pipelines: RwLock<Pipelines>,
|
||||
// fence: metal::Fence,
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
@ -232,6 +229,7 @@ impl Kernels {
|
||||
Self {
|
||||
libraries,
|
||||
pipelines,
|
||||
// fence,
|
||||
}
|
||||
}
|
||||
|
||||
@ -245,7 +243,6 @@ impl Kernels {
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
@ -353,6 +350,7 @@ pub fn call_unary_contiguous(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, output));
|
||||
@ -361,6 +359,7 @@ pub fn call_unary_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -382,6 +381,7 @@ pub fn call_unary_strided(
|
||||
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -403,6 +403,7 @@ pub fn call_unary_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -421,6 +422,7 @@ pub fn call_binary_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, left, right, output));
|
||||
@ -431,6 +433,7 @@ pub fn call_binary_contiguous(
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -455,6 +458,7 @@ pub fn call_binary_strided(
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let width: usize = shape.iter().product();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -479,6 +483,7 @@ pub fn call_binary_strided(
|
||||
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -497,6 +502,7 @@ pub fn call_cast_contiguous(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, (input, input_offset), output));
|
||||
@ -505,6 +511,7 @@ pub fn call_cast_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -524,6 +531,7 @@ pub fn call_cast_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
@ -545,6 +553,7 @@ pub fn call_cast_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -564,6 +573,7 @@ pub fn call_reduce_contiguous(
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -592,6 +602,7 @@ pub fn call_reduce_contiguous(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -613,6 +624,7 @@ pub fn call_reduce_strided(
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -623,8 +635,7 @@ pub fn call_reduce_strided(
|
||||
strides,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output,
|
||||
out_length
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
@ -649,6 +660,7 @@ pub fn call_reduce_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -667,6 +679,7 @@ pub fn call_last_softmax(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -697,6 +710,7 @@ pub fn call_last_softmax(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -716,6 +730,7 @@ pub fn call_affine(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, add, input, output));
|
||||
@ -724,6 +739,7 @@ pub fn call_affine(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -746,6 +762,7 @@ pub fn call_affine_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -766,6 +783,7 @@ pub fn call_affine_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -784,6 +802,7 @@ pub fn call_powf(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
@ -792,6 +811,7 @@ pub fn call_powf(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -813,6 +833,7 @@ pub fn call_powf_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -832,6 +853,7 @@ pub fn call_powf_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -850,6 +872,7 @@ pub fn call_elu(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
@ -858,6 +881,7 @@ pub fn call_elu(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -879,6 +903,7 @@ pub fn call_elu_strided(
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -898,6 +923,7 @@ pub fn call_elu_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -919,6 +945,7 @@ pub fn call_where_cond_strided(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let size: usize = shape.iter().product();
|
||||
@ -947,6 +974,7 @@ pub fn call_where_cond_strided(
|
||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -973,6 +1001,7 @@ pub fn call_index_select(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -995,6 +1024,7 @@ pub fn call_index_select(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1023,6 +1053,7 @@ pub fn call_gather(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1045,6 +1076,7 @@ pub fn call_gather(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1073,6 +1105,7 @@ pub fn call_scatter_add(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1095,6 +1128,7 @@ pub fn call_scatter_add(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1124,6 +1158,7 @@ pub fn call_index_add(
|
||||
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1147,6 +1182,7 @@ pub fn call_index_add(
|
||||
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
@ -1350,6 +1386,7 @@ pub fn call_gemm(
|
||||
let block_bytes = block_elements * bytes;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
@ -1393,6 +1430,7 @@ pub fn call_gemm(
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1417,6 +1455,7 @@ pub fn call_im2col1d_strided(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1436,6 +1475,7 @@ pub fn call_im2col1d_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1465,6 +1505,7 @@ pub fn call_im2col_strided(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1486,6 +1527,7 @@ pub fn call_im2col_strided(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1511,6 +1553,7 @@ pub fn call_upsample_nearest_2d(
|
||||
let scale_h = shape[3] as f32 / out_h as f32;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
@ -1528,73 +1571,7 @@ pub fn call_upsample_nearest_2d(
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_uniform(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
min: f32,
|
||||
max: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
if min >= max {
|
||||
return Err(MetalKernelError::LoadLibraryError(
|
||||
"min must be less than max".to_string(),
|
||||
));
|
||||
}
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, min, max, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_random_normal(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
mean: f32,
|
||||
stddev: f32,
|
||||
length: usize,
|
||||
seed: &Buffer,
|
||||
buffer: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let odd = (length % 2 != 0) as usize;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, mean, stddev, seed, buffer));
|
||||
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
@ -1733,6 +1710,7 @@ pub fn call_quantized_matmul_t(
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
//encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
@ -1765,6 +1743,7 @@ pub fn call_quantized_matmul_t(
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||
//encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
|
@ -1,206 +0,0 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_integer>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Constants
|
||||
// 2^32 and 1/2^32. Useful for converting between float and uint.
|
||||
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
|
||||
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
|
||||
// 2 * pi
|
||||
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
|
||||
static constexpr constant int3 S1 = {13, 19, 12};
|
||||
static constexpr constant int3 S2 = {2, 25, 4};
|
||||
static constexpr constant int3 S3 = {3, 11, 17};
|
||||
|
||||
// Used to prevent bad seeds.
|
||||
static constexpr constant uint64_t PHI[16] = {
|
||||
0x9E3779B97F4A7C15,
|
||||
0xF39CC0605CEDC834,
|
||||
0x1082276BF3A27251,
|
||||
0xF86C6A11D0C18E95,
|
||||
0x2767F0B153D27B7F,
|
||||
0x0347045B5BF1827F,
|
||||
0x01886F0928403002,
|
||||
0xC1D64BA40F335E36,
|
||||
0xF06AD7AE9717877E,
|
||||
0x85839D6EFFBD7DC6,
|
||||
0x64D325D1C5371682,
|
||||
0xCADD0CCCFDFFBBE1,
|
||||
0x626E33B8D04B4331,
|
||||
0xBBF73C790D94F79D,
|
||||
0x471C4AB3ED3D82A5,
|
||||
0xFEC507705E4AE6E5,
|
||||
};
|
||||
|
||||
// Combined Tausworthe and LCG Random Number Generator.
|
||||
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
|
||||
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
|
||||
struct HybridTaus {
|
||||
|
||||
float state;
|
||||
|
||||
HybridTaus() thread = default;
|
||||
HybridTaus() threadgroup = default;
|
||||
HybridTaus() device = default;
|
||||
HybridTaus() constant = default;
|
||||
|
||||
// Generate seeds for each thread.
|
||||
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
|
||||
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
|
||||
}
|
||||
|
||||
// Tausworthe generator.
|
||||
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
|
||||
uint b = (((z << s.x) ^ z) >> s.y);
|
||||
return (((z & M) << s.z) ^ b);
|
||||
}
|
||||
|
||||
// LCG generator.
|
||||
METAL_FUNC static uint lcg(const uint z) {
|
||||
return (1664525 * z + 1013904223UL);
|
||||
}
|
||||
|
||||
// Initialize the RNG state.
|
||||
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
|
||||
uint4 seed = seed_per_thread(seeds);
|
||||
|
||||
// Seed #1
|
||||
uint z1 = taus(seed.x, S1, 4294967294UL);
|
||||
uint z2 = taus(seed.y, S2, 4294967288UL);
|
||||
uint z3 = taus(seed.z, S3, 4294967280UL);
|
||||
uint z4 = lcg(seed.x);
|
||||
|
||||
// Seed #2
|
||||
uint r1 = (z1^z2^z3^z4^seed.y);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
// Seed #3
|
||||
r1 = (z1^z2^z3^z4^seed.z);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
// Seed #4
|
||||
r1 = (z1^z2^z3^z4^seed.w);
|
||||
z1 = taus(r1, S1, 429496729UL);
|
||||
z2 = taus(r1, S2, 4294967288UL);
|
||||
z3 = taus(r1, S3, 429496280UL);
|
||||
z4 = lcg(r1);
|
||||
|
||||
HybridTaus rng;
|
||||
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||
return rng;
|
||||
}
|
||||
|
||||
METAL_FUNC float rand() {
|
||||
uint seed = this->state * UNIF01_NORM32;
|
||||
uint z1 = taus(seed, S1, 429496729UL);
|
||||
uint z2 = taus(seed, S2, 4294967288UL);
|
||||
uint z3 = taus(seed, S3, 429496280UL);
|
||||
uint z4 = lcg(seed);
|
||||
|
||||
thread float result = this->state;
|
||||
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T> METAL_FUNC void rand_uniform(
|
||||
constant size_t &size,
|
||||
constant float &min,
|
||||
constant float &max,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
float diff = abs(min - max);
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
}
|
||||
|
||||
// Create Gaussian normal distribution using Box-Muller transform:
|
||||
// https://en.wikipedia.org/wiki/Box–Muller_transform
|
||||
template<typename T> METAL_FUNC void normal(
|
||||
constant size_t &size,
|
||||
constant float &mean,
|
||||
constant float &stddev,
|
||||
device atomic_uint *seed,
|
||||
device T *out,
|
||||
uint tid [[thread_position_in_grid]]
|
||||
) {
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
float u1 = rng.rand();
|
||||
float u2 = rng.rand();
|
||||
|
||||
float cosval;
|
||||
float sinval = sincos(TWO_PI * u2, cosval);
|
||||
float mag = stddev * sqrt(-2.0 * log(u1));
|
||||
float z0 = mag * cosval + mean;
|
||||
float z1 = mag * sinval + mean;
|
||||
|
||||
out[tid] = static_cast<T>(z0);
|
||||
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - tid] = static_cast<T>(z1);
|
||||
}
|
||||
|
||||
#define UNIFORM_OP(NAME, T) \
|
||||
kernel void rand_uniform_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant float &min, \
|
||||
constant float &max, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
rand_uniform<T>(size, min, max, seed, out, tid); \
|
||||
} \
|
||||
|
||||
#define NORMAL_OP(NAME, T) \
|
||||
kernel void rand_normal_##NAME( \
|
||||
constant size_t &size, \
|
||||
constant float &mean, \
|
||||
constant float &stddev, \
|
||||
device atomic_uint *seed, \
|
||||
device T *out, \
|
||||
uint tid [[thread_position_in_grid]] \
|
||||
) { \
|
||||
normal<T>(size, mean, stddev, seed, out, tid); \
|
||||
} \
|
||||
|
||||
|
||||
#define RANDOM_OPS(NAME, T) \
|
||||
UNIFORM_OP(NAME, T) \
|
||||
NORMAL_OP(NAME, T) \
|
||||
|
||||
RANDOM_OPS(f32, float)
|
||||
RANDOM_OPS(f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
RANDOM_OPS(bf16, bfloat)
|
||||
#endif
|
@ -1,18 +1,16 @@
|
||||
#include <metal_stdlib>
|
||||
#include <metal_limits>
|
||||
using namespace metal;
|
||||
|
||||
// TODO: Load multiple values per thread to improve memory bandwidth utilization
|
||||
// static constant constexpr uint VALUES_PER_THREAD = 1;
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant const size_t &num_dims,
|
||||
constant const size_t *dims,
|
||||
constant const size_t *strides
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
@ -21,637 +19,288 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template <typename V>
|
||||
struct Indexed {
|
||||
uint i;
|
||||
V val;
|
||||
typedef V type;
|
||||
|
||||
constexpr Indexed<V>() thread = default;
|
||||
constexpr Indexed<V>() threadgroup = default;
|
||||
constexpr Indexed<V>() device = default;
|
||||
constexpr Indexed<V>() constant = default;
|
||||
|
||||
constexpr Indexed<V>(uint _i, V _val) : i(_i), val(_val) {}
|
||||
|
||||
template <typename U, typename = typename enable_if<is_convertible_v<U, V>>::type>
|
||||
constexpr Indexed<V>(uint _i, U _val) : i(_i), val(static_cast<U>(_val)) {}
|
||||
|
||||
template <typename U>
|
||||
constexpr Indexed<V>(const thread Indexed<U> &iv): Indexed<V>(iv.i, iv.val) {}
|
||||
|
||||
template <typename U>
|
||||
constexpr Indexed<V>(const threadgroup Indexed<V> &iv): Indexed<V>(iv.i, iv.val) {}
|
||||
|
||||
Indexed<V> operator=(const thread Indexed<V> &iv) thread {
|
||||
this->i = iv.i;
|
||||
this->val = iv.val;
|
||||
return *this;
|
||||
}
|
||||
Indexed<V> operator=(const thread Indexed<V> &iv) threadgroup {
|
||||
this->i = iv.i;
|
||||
this->val = iv.val;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename V>
|
||||
constexpr METAL_FUNC bool operator<(Indexed<V> lhs, Indexed<V> rhs) {
|
||||
return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i);
|
||||
}
|
||||
|
||||
template<typename V>
|
||||
constexpr METAL_FUNC bool operator>(Indexed<V> lhs, Indexed<V> rhs) {
|
||||
return lhs.val > rhs.val || (lhs.val == rhs.val && lhs.i > rhs.i);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct _numeric_limits_impl<Indexed<T>> {
|
||||
static constexpr Indexed<T> lowest() {
|
||||
return Indexed<T>(0, numeric_limits<T>::lowest());
|
||||
}
|
||||
|
||||
static constexpr Indexed<T> max() {
|
||||
return Indexed<T>(0, numeric_limits<T>::max());
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
// Metal does not have simd_shuffle_down for bfloat16
|
||||
// TODO: Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat
|
||||
bfloat simd_shuffle_down(bfloat value, ushort delta) {
|
||||
return static_cast<bfloat>(__metal_simd_shuffle_down(static_cast<float>(value), delta));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename V>
|
||||
Indexed<V> simd_shuffle_down(Indexed<V> iv, ushort delta) {
|
||||
return Indexed<V>(
|
||||
simd_shuffle_down(iv.i, delta),
|
||||
simd_shuffle_down(iv.val, delta)
|
||||
);
|
||||
}
|
||||
|
||||
#define impl_reduction_op_helper(name, op, init_val, __result_type__) \
|
||||
template<typename T, typename R = __result_type__> \
|
||||
struct name { \
|
||||
static constexpr T init() { \
|
||||
return init_val; \
|
||||
} \
|
||||
METAL_FUNC R operator()(T a, T b) { \
|
||||
return op; \
|
||||
} \
|
||||
METAL_FUNC R operator()(thread const T& a, thread const T& b) const { \
|
||||
return op; \
|
||||
} \
|
||||
METAL_FUNC R operator()(threadgroup const T& a, threadgroup const T& b) const { \
|
||||
return op; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define impl_reduction_op(name, op, init_val) \
|
||||
impl_reduction_op_helper(name, op, init_val, T);
|
||||
|
||||
#define impl_arg_reduction_op(name, op, init_val) \
|
||||
impl_reduction_op_helper(name, op, init_val, tuple<bool, Indexed<T>>);
|
||||
|
||||
impl_reduction_op(Sum, a + b, 0);
|
||||
impl_reduction_op(Mul, a * b, 1);
|
||||
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
|
||||
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::lowest());
|
||||
#undef impl_reduction_op
|
||||
|
||||
// These are used when loading elements from global memory into shared memory.
|
||||
// They let us use the same code for both indexed and non-indexed types.
|
||||
template<typename Op, typename T, typename U>
|
||||
METAL_FUNC T apply_operator(Op op, size_t _idx, T a, U b) {
|
||||
return op(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
template<typename Op, typename T, typename U>
|
||||
METAL_FUNC Indexed<T> apply_operator(Op op, size_t idx, Indexed<T> a, U b) {
|
||||
return op(a, Indexed<T>(idx, b));
|
||||
}
|
||||
|
||||
// Load elements from global memory into shared memory.
|
||||
// Handles both indexed and non-indexed types by using apply_operator.
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED = false
|
||||
>
|
||||
METAL_FUNC R load_from_global(
|
||||
R value,
|
||||
constant size_t &num_elements,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
const ushort offset,
|
||||
threadgroup R shared[BLOCKSIZE],
|
||||
const ushort tid
|
||||
) {
|
||||
ReductionOp op;
|
||||
|
||||
size_t stop_idx = offset + el_to_sum_per_block;
|
||||
size_t idx = offset + tid;
|
||||
|
||||
while (idx < stop_idx) {
|
||||
if (STRIDED) {
|
||||
idx = get_strided_index(idx, num_dims, dims, strides);
|
||||
}
|
||||
value = apply_operator(op, idx, value, src[idx]);
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
|
||||
// Convenience function for when we don't need to sum over multiple dimensions.
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE
|
||||
>
|
||||
METAL_FUNC R load_from_global(
|
||||
R value,
|
||||
constant size_t &num_elements,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
const size_t offset,
|
||||
threadgroup R shared[BLOCKSIZE],
|
||||
const ushort tid
|
||||
) {
|
||||
return load_from_global<T, R, ReductionOp, BLOCKSIZE, false>(
|
||||
value,
|
||||
num_elements,
|
||||
// Dummy values for num_dims, dims, and strides
|
||||
num_elements,
|
||||
nullptr,
|
||||
nullptr,
|
||||
// end dummy values
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
}
|
||||
|
||||
// Since we are using simd_shuffle_down with a BLOCKSIZE guard we don't need any barriers.
|
||||
template<typename ReductionOp, ushort BLOCKSIZE, typename T>
|
||||
METAL_FUNC T simdgroup_reduce(T value) {
|
||||
ReductionOp op;
|
||||
if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16));
|
||||
if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value, 8));
|
||||
if (BLOCKSIZE >= 8) value = op(value, simd_shuffle_down(value, 4));
|
||||
if (BLOCKSIZE >= 4) value = op(value, simd_shuffle_down(value, 2));
|
||||
if (BLOCKSIZE >= 2) value = op(value, simd_shuffle_down(value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
template<
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
typename T
|
||||
>
|
||||
METAL_FUNC T threadgroup_reduce(
|
||||
threadgroup T shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]]
|
||||
) {
|
||||
ReductionOp op;
|
||||
|
||||
// Fully unrolled reduction loop from BLOCKSIZE down to 64.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint s = BLOCKSIZE / 2; s >= 64; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared[tid] = op(shared[tid], shared[tid + s]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid < 32) {
|
||||
// Last shared memory reduce can be done without tid < s check.
|
||||
if (BLOCKSIZE >= 64) {
|
||||
shared[tid] = op(shared[tid], shared[tid + 32]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
// Remaining 32 threads can be reduced with simdgroup_reduce.
|
||||
shared[tid] = simdgroup_reduce<ReductionOp, BLOCKSIZE>(shared[tid]);
|
||||
}
|
||||
|
||||
return shared[tid];
|
||||
}
|
||||
|
||||
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED = false
|
||||
>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device R *dst,
|
||||
constant size_t &num_elements,
|
||||
threadgroup T shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to correct value for reduction operation
|
||||
shared[tid] = ReductionOp::init();
|
||||
|
||||
// Calcluate offset for the threadgroup of current thread
|
||||
ushort offset = dst_id * el_to_sum_per_block;
|
||||
R initial = ReductionOp::init();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, R, ReductionOp, BLOCKSIZE, STRIDED>(
|
||||
initial,
|
||||
num_elements,
|
||||
num_dims,
|
||||
dims,
|
||||
strides,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Complete reduction
|
||||
R value = threadgroup_reduce<ReductionOp, BLOCKSIZE>(shared, tid);
|
||||
|
||||
if (tid == 0) dst[dst_id] = value;
|
||||
}
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define reduce_case(OP, T, R, N) \
|
||||
case N: { \
|
||||
threadgroup R shared[N]; \
|
||||
reduce<T, R, OP<R>, N, STRIDED>( \
|
||||
num_dims, \
|
||||
dims, \
|
||||
strides, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
num_elements, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define impl_reduce(OP, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
constant size_t *dims = {}; \
|
||||
constant size_t *strides = {}; \
|
||||
const bool STRIDED = false; \
|
||||
switch (block_dim) { \
|
||||
reduce_case(OP, T, T, 2048); \
|
||||
reduce_case(OP, T, T, 1024); \
|
||||
reduce_case(OP, T, T, 512); \
|
||||
reduce_case(OP, T, T, 256); \
|
||||
reduce_case(OP, T, T, 128); \
|
||||
reduce_case(OP, T, T, 64); \
|
||||
reduce_case(OP, T, T, 32); \
|
||||
reduce_case(OP, T, T, 16); \
|
||||
reduce_case(OP, T, T, 8); \
|
||||
reduce_case(OP, T, T, 4); \
|
||||
reduce_case(OP, T, T, 2); \
|
||||
reduce_case(OP, T, T, 1); \
|
||||
} \
|
||||
} \
|
||||
kernel void NAME##_strided( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
const bool STRIDED = true; \
|
||||
switch (block_dim) { \
|
||||
reduce_case(OP, T, T, 2048); \
|
||||
reduce_case(OP, T, T, 1024); \
|
||||
reduce_case(OP, T, T, 512); \
|
||||
reduce_case(OP, T, T, 256); \
|
||||
reduce_case(OP, T, T, 128); \
|
||||
reduce_case(OP, T, T, 64); \
|
||||
reduce_case(OP, T, T, 32); \
|
||||
reduce_case(OP, T, T, 16); \
|
||||
reduce_case(OP, T, T, 8); \
|
||||
reduce_case(OP, T, T, 4); \
|
||||
reduce_case(OP, T, T, 2); \
|
||||
reduce_case(OP, T, T, 1); \
|
||||
} \
|
||||
}
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED
|
||||
>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device uint *dst,
|
||||
constant size_t &num_elements,
|
||||
threadgroup Indexed<T> shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to correct value for reduction operation
|
||||
shared[tid] = ReductionOp::init();
|
||||
|
||||
// Calcluate offset for the threadgroup of current thread
|
||||
ushort offset = dst_id * el_to_sum_per_block;
|
||||
Indexed<T> initial = ReductionOp::init();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, Indexed<T>, ReductionOp, BLOCKSIZE, STRIDED>(
|
||||
initial,
|
||||
num_elements,
|
||||
num_dims,
|
||||
dims,
|
||||
strides,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
// Complete reduction
|
||||
Indexed<T> value = threadgroup_reduce<ReductionOp, BLOCKSIZE, Indexed<T>>(shared, tid);
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
|
||||
// Return index of reduce result
|
||||
if (tid == 0) dst[dst_id] = value.i;
|
||||
}
|
||||
|
||||
#define arg_reduce_case(OP, T, N) \
|
||||
case N: { \
|
||||
threadgroup Indexed<T> shared[N]; \
|
||||
reduce<T, OP<Indexed<T>>, N, STRIDED>( \
|
||||
num_dims, \
|
||||
dims, \
|
||||
strides, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
num_elements, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define impl_arg_reduce(OP, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
constant size_t *dims = {}; \
|
||||
constant size_t *strides = {}; \
|
||||
const bool STRIDED = false; \
|
||||
switch (block_dim) { \
|
||||
arg_reduce_case(OP, T, 2048); \
|
||||
arg_reduce_case(OP, T, 1024); \
|
||||
arg_reduce_case(OP, T, 512); \
|
||||
arg_reduce_case(OP, T, 256); \
|
||||
arg_reduce_case(OP, T, 128); \
|
||||
arg_reduce_case(OP, T, 64); \
|
||||
arg_reduce_case(OP, T, 32); \
|
||||
arg_reduce_case(OP, T, 16); \
|
||||
arg_reduce_case(OP, T, 8); \
|
||||
arg_reduce_case(OP, T, 4); \
|
||||
arg_reduce_case(OP, T, 2); \
|
||||
arg_reduce_case(OP, T, 1); \
|
||||
} \
|
||||
} \
|
||||
kernel void NAME##_strided( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
const bool STRIDED = true; \
|
||||
switch (block_dim) { \
|
||||
arg_reduce_case(OP, T, 2048); \
|
||||
arg_reduce_case(OP, T, 1024); \
|
||||
arg_reduce_case(OP, T, 512); \
|
||||
arg_reduce_case(OP, T, 256); \
|
||||
arg_reduce_case(OP, T, 128); \
|
||||
arg_reduce_case(OP, T, 64); \
|
||||
arg_reduce_case(OP, T, 32); \
|
||||
arg_reduce_case(OP, T, 16); \
|
||||
arg_reduce_case(OP, T, 8); \
|
||||
arg_reduce_case(OP, T, 4); \
|
||||
arg_reduce_case(OP, T, 2); \
|
||||
arg_reduce_case(OP, T, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename ACC = float,
|
||||
ushort BLOCKSIZE
|
||||
>
|
||||
METAL_FUNC void softmax(
|
||||
constant size_t &src_numel,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
device T *dst,
|
||||
threadgroup ACC shared[BLOCKSIZE],
|
||||
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to lowest value
|
||||
shared[tid] = numeric_limits<ACC>::lowest();
|
||||
|
||||
// Calcluate offset for the threadgroup of current thread
|
||||
size_t offset = dst_id * el_to_sum_per_block;
|
||||
ACC initial = numeric_limits<ACC>::lowest();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, ACC, Max<ACC>, BLOCKSIZE>(
|
||||
initial,
|
||||
src_numel,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Reduce shared memory to find max value
|
||||
threadgroup_reduce<Max<ACC>, BLOCKSIZE>(shared, tid);
|
||||
ACC max_result = shared[0];
|
||||
|
||||
// Ensure all threads have max_result = shared[0] before we set shared[0] = 0.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
shared[tid] = 0;
|
||||
|
||||
// Calculate softmax values
|
||||
size_t stop_idx = min(offset + el_to_sum_per_block, src_numel);
|
||||
size_t idx = offset + tid;
|
||||
while (idx < stop_idx) {
|
||||
const ACC val = exp(ACC(src[idx]) - max_result);
|
||||
dst[idx] = T(val);
|
||||
shared[tid] += val;
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
threadgroup_reduce<Sum<ACC>, BLOCKSIZE>(shared, tid);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
const T inv_acc = T(1.0/shared[0]);
|
||||
idx = offset + tid;
|
||||
while (idx < stop_idx) {
|
||||
dst[idx] *= inv_acc;
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
}
|
||||
|
||||
#define softmax_case(T, ACC, N) \
|
||||
case N: { \
|
||||
threadgroup ACC shared[N]; \
|
||||
softmax<T, ACC, N>( \
|
||||
src_numel, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define impl_softmax(NAME, T, ACC) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
switch (block_dim) { \
|
||||
softmax_case(T, ACC, 2048); \
|
||||
softmax_case(T, ACC, 1024); \
|
||||
softmax_case(T, ACC, 512); \
|
||||
softmax_case(T, ACC, 256); \
|
||||
softmax_case(T, ACC, 128); \
|
||||
softmax_case(T, ACC, 64); \
|
||||
softmax_case(T, ACC, 32); \
|
||||
softmax_case(T, ACC, 16); \
|
||||
softmax_case(T, ACC, 8); \
|
||||
softmax_case(T, ACC, 4); \
|
||||
softmax_case(T, ACC, 2); \
|
||||
softmax_case(T, ACC, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
impl_reduce(Sum, fast_sum_f32, float)
|
||||
impl_reduce(Sum, fast_sum_u32, uint)
|
||||
impl_reduce(Sum, fast_sum_f16, half)
|
||||
impl_reduce(Sum, fast_sum_u8, uint8_t)
|
||||
|
||||
impl_reduce(Mul, fast_mul_f32, float)
|
||||
impl_reduce(Mul, fast_mul_u32, uint)
|
||||
impl_reduce(Mul, fast_mul_f16, half)
|
||||
impl_reduce(Mul, fast_mul_u8, uint8_t)
|
||||
|
||||
impl_reduce(Max, fast_max_f32, float)
|
||||
impl_reduce(Max, fast_max_u32, uint)
|
||||
impl_reduce(Max, fast_max_f16, half)
|
||||
impl_reduce(Max, fast_max_u8, uint8_t)
|
||||
|
||||
impl_reduce(Min, fast_min_f32, float)
|
||||
impl_reduce(Min, fast_min_u32, uint)
|
||||
impl_reduce(Min, fast_min_f16, half)
|
||||
impl_reduce(Min, fast_min_u8, uint8_t)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_f32, float)
|
||||
impl_arg_reduce(Min, fast_argmin_f16, half)
|
||||
impl_arg_reduce(Min, fast_argmin_u32, uint)
|
||||
impl_arg_reduce(Min, fast_argmin_u8, uint8_t)
|
||||
|
||||
impl_arg_reduce(Max, fast_argmax_f32, float)
|
||||
impl_arg_reduce(Max, fast_argmax_f16, half)
|
||||
impl_arg_reduce(Max, fast_argmax_u32, uint)
|
||||
impl_arg_reduce(Max, fast_argmax_u8, uint8_t)
|
||||
|
||||
impl_softmax(softmax_f32, float, float)
|
||||
impl_softmax(softmax_f16, half, float)
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
impl_reduce(Sum, fast_sum_i64, int64_t)
|
||||
impl_reduce(Mul, fast_mul_i64, int64_t)
|
||||
impl_reduce(Min, fast_min_i64, int64_t)
|
||||
impl_reduce(Max, fast_max_i64, int64_t)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_i64, int64_t)
|
||||
impl_arg_reduce(Max, fast_argmax_i64, int64_t)
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
impl_reduce(Sum, fast_sum_bf16, bfloat)
|
||||
impl_reduce(Mul, fast_mul_bf16, bfloat)
|
||||
impl_reduce(Max, fast_max_bf16, bfloat)
|
||||
impl_reduce(Min, fast_min_bf16, bfloat)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_bf16, bfloat)
|
||||
impl_arg_reduce(Max, fast_argmax_bf16, bfloat)
|
||||
|
||||
impl_softmax(softmax_bf16, bfloat, float)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
||||
|
@ -1,346 +0,0 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_f32, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32, uint, 0)
|
||||
ARGMAX(fast_argmax_u8, uint8_t, 0)
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_i64, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
@ -17,45 +17,29 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
template<typename T, typename ID>
|
||||
METAL_FUNC void where_cond(
|
||||
constant size_t &numel,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t *strides_t,
|
||||
constant size_t *strides_f,
|
||||
device const ID *ids,
|
||||
device const T *t,
|
||||
device const T *f,
|
||||
device T *out,
|
||||
uint i [[ thread_position_in_grid ]]
|
||||
) {
|
||||
if (i >= numel){
|
||||
return;
|
||||
}
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
||||
}
|
||||
|
||||
#define WHERE_OP(T, ID, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID *ids, \
|
||||
device const T *t, \
|
||||
device const T *f, \
|
||||
device T *out, \
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||
} \
|
||||
#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t *strides_t, \
|
||||
constant size_t *strides_f, \
|
||||
device const ID_TYPENAME *ids, \
|
||||
device const TYPENAME *t, \
|
||||
device const TYPENAME *f, \
|
||||
device TYPENAME *out ,\
|
||||
uint i [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (i >= numel){ \
|
||||
return; \
|
||||
} \
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
|
||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
|
||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
@ -70,14 +54,10 @@ kernel void FN_NAME(
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||
// WHERE_OP(double, uint8_t, where_u8_f64)
|
||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||
#endif
|
@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{Buffer, Device, MTLResourceOptions};
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
|
||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let ptr = data.as_ptr() as *const c_void;
|
||||
let ptr = data.as_ptr() as *const core::ffi::c_void;
|
||||
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
||||
device.new_buffer_with_data(ptr, size, options)
|
||||
}
|
||||
@ -37,7 +37,8 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
|
||||
|
||||
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -59,7 +60,8 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
||||
|
||||
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -94,7 +96,8 @@ fn run_strided<T: Clone>(
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, v);
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
call_unary_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
@ -245,37 +248,10 @@ fn binary_add_f32() {
|
||||
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_ops_bf16() {
|
||||
let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
|
||||
let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect();
|
||||
|
||||
macro_rules! binary_op {
|
||||
($opname:ident, $opexpr:expr) => {{
|
||||
let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
|
||||
let expected: Vec<bf16> = lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
|
||||
.collect();
|
||||
assert_eq!(results, expected);
|
||||
}};
|
||||
}
|
||||
|
||||
binary_op!(add, |x, y| x + y);
|
||||
binary_op!(sub, |x, y| x - y);
|
||||
binary_op!(mul, |x, y| x * y);
|
||||
binary_op!(div, |x, y| x / y);
|
||||
binary_op!(min, |x: bf16, y| x.min(y));
|
||||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -320,92 +296,10 @@ fn cast_u32_f32() {
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u32_bf16() {
|
||||
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f32_bf16() {
|
||||
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u8() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
|
||||
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f16() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
|
||||
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f16_bf16() {
|
||||
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
@ -440,7 +334,8 @@ fn run_affine_strided<T: Clone>(
|
||||
add: f64,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
@ -501,14 +396,14 @@ fn index_select() {
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
||||
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [2, 5];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
@ -524,46 +419,20 @@ fn index_select_f16() {
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(
|
||||
approx_f16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u32_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u8_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u8, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_dim1() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 1;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||
@ -575,7 +444,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape: &[usize],
|
||||
ids: &[I],
|
||||
dim: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
||||
@ -589,7 +457,14 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
let kernels = Kernels::new();
|
||||
let name = match core::mem::size_of::<T>() {
|
||||
4 => "is_u32_f32",
|
||||
2 => "is_u32_f16",
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
call_index_select(
|
||||
&device,
|
||||
&command_buffer,
|
||||
@ -622,18 +497,19 @@ fn cos_f16() {
|
||||
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
|
||||
}
|
||||
|
||||
fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> {
|
||||
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||
let dims = vec![v.len()];
|
||||
let strides = vec![1];
|
||||
match call_reduce_strided(
|
||||
call_reduce_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
@ -644,13 +520,8 @@ fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Ve
|
||||
&input,
|
||||
0,
|
||||
&output,
|
||||
) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
println!("Error: {}", e);
|
||||
panic!();
|
||||
}
|
||||
}
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
@ -659,7 +530,8 @@ fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Ve
|
||||
|
||||
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let input = new_buffer(&device, v);
|
||||
@ -682,114 +554,22 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
|
||||
read_to_vec(&output, v.len())
|
||||
}
|
||||
|
||||
const fn create_array<const N: usize>() -> [f32; N] {
|
||||
let mut array: [f32; N] = [0.0; N];
|
||||
let mut i = 1;
|
||||
while i <= N {
|
||||
array[i - 1] = i as f32;
|
||||
i += 1;
|
||||
}
|
||||
array
|
||||
}
|
||||
|
||||
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
|
||||
let mut sum = 0;
|
||||
let mut results: [f32; D] = [0.0; D];
|
||||
let mut i = 1;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
sum += i;
|
||||
i += 1;
|
||||
if i > j * N / D {
|
||||
results[j - 1] = sum as f32;
|
||||
j += 1;
|
||||
sum = 0;
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
|
||||
let mut max = 0.0;
|
||||
let mut max_index: u32 = 0;
|
||||
let mut results: [u32; D] = [0; D];
|
||||
let mut i = 0;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
if i >= (j * N / D) {
|
||||
results[j - 1] = max_index;
|
||||
max = 0.0;
|
||||
max_index = 0;
|
||||
j += 1;
|
||||
}
|
||||
if i == N {
|
||||
break;
|
||||
}
|
||||
if arr[i] > max {
|
||||
max = arr[i];
|
||||
max_index = i as u32;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn reduce_sum_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results = run_reduce(&v, D, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), correct_sum::<N, D>());
|
||||
}
|
||||
|
||||
fn reduce_argmax_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
|
||||
assert_eq!(results, correct_argmax::<N, D>(v));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_sum() {
|
||||
reduce_sum_case::<6, 1>();
|
||||
reduce_sum_case::<10, 1>();
|
||||
reduce_sum_case::<64, 1>();
|
||||
reduce_sum_case::<128, 1>();
|
||||
reduce_sum_case::<256, 1>();
|
||||
reduce_sum_case::<512, 1>();
|
||||
reduce_sum_case::<1024, 1>();
|
||||
reduce_sum_case::<2048, 1>();
|
||||
reduce_sum_case::<4096, 1>();
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 1;
|
||||
|
||||
reduce_sum_case::<6, 2>();
|
||||
reduce_sum_case::<10, 2>();
|
||||
reduce_sum_case::<64, 2>();
|
||||
reduce_sum_case::<128, 2>();
|
||||
reduce_sum_case::<256, 2>();
|
||||
reduce_sum_case::<512, 2>();
|
||||
reduce_sum_case::<1024, 2>();
|
||||
reduce_sum_case::<2048, 2>();
|
||||
reduce_sum_case::<4096, 2>();
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![21.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_argmax() {
|
||||
reduce_argmax_case::<6, 1>();
|
||||
reduce_argmax_case::<10, 1>();
|
||||
reduce_argmax_case::<64, 1>();
|
||||
reduce_argmax_case::<128, 1>();
|
||||
reduce_argmax_case::<256, 1>();
|
||||
reduce_argmax_case::<512, 1>();
|
||||
reduce_argmax_case::<1024, 1>();
|
||||
reduce_argmax_case::<2048, 1>();
|
||||
reduce_argmax_case::<4096, 1>();
|
||||
fn reduce_sum2() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 2;
|
||||
|
||||
reduce_argmax_case::<6, 2>();
|
||||
reduce_argmax_case::<10, 2>();
|
||||
reduce_argmax_case::<64, 2>();
|
||||
reduce_argmax_case::<128, 2>();
|
||||
reduce_argmax_case::<256, 2>();
|
||||
reduce_argmax_case::<512, 2>();
|
||||
reduce_argmax_case::<1024, 2>();
|
||||
reduce_argmax_case::<2048, 2>();
|
||||
reduce_argmax_case::<4096, 2>();
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -810,6 +590,7 @@ fn softmax() {
|
||||
}
|
||||
let results = run_softmax(&v, last_dim, "softmax_f32");
|
||||
let results = approx(results, 4);
|
||||
println!("{results:?}");
|
||||
assert_eq!(
|
||||
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
|
||||
n
|
||||
@ -869,7 +650,8 @@ fn run_where_cond<I: Clone, T: Clone>(
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -945,7 +727,8 @@ fn run_gemm<T: Clone>(
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
@ -1023,124 +806,3 @@ fn gemm() {
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
|
||||
|
||||
let seed = device.new_buffer_with_data(
|
||||
&seed as *const u32 as *const core::ffi::c_void,
|
||||
std::mem::size_of::<u32>() as NSUInteger,
|
||||
options,
|
||||
);
|
||||
|
||||
if name.starts_with("rand_uniform") {
|
||||
call_random_uniform(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
call_random_normal(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
a,
|
||||
b,
|
||||
length,
|
||||
&seed,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random() {
|
||||
fn calc_mean(data: &[f32]) -> f32 {
|
||||
let sum = data.iter().sum::<f32>() as f32;
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
sum / count as f32
|
||||
}
|
||||
|
||||
fn calc_stddev(data: &[f32]) -> f32 {
|
||||
let mean = calc_mean(data);
|
||||
let count = data.len();
|
||||
assert!(count > 0);
|
||||
|
||||
let variance = data
|
||||
.iter()
|
||||
.map(|value| {
|
||||
let diff = mean - (*value as f32);
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f32>()
|
||||
/ count as f32;
|
||||
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
let shape = vec![1024, 10];
|
||||
|
||||
let length = shape.iter().product::<usize>();
|
||||
let seed = 299792458;
|
||||
|
||||
let min = -30.0;
|
||||
let max = 30.0;
|
||||
let mean = 100.0;
|
||||
let stddev = 50.0;
|
||||
|
||||
macro_rules! validate_random {
|
||||
($type:ty) => {
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_uniform_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
min,
|
||||
max,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
results.iter().for_each(|v| {
|
||||
assert!(*v >= min && *v <= max);
|
||||
});
|
||||
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
|
||||
|
||||
let results: Vec<f32> = run_random::<$type>(
|
||||
concat!("rand_normal_", stringify!($type)),
|
||||
seed,
|
||||
length,
|
||||
mean,
|
||||
stddev,
|
||||
)
|
||||
.into_iter()
|
||||
.map(f32::from)
|
||||
.collect();
|
||||
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
|
||||
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
|
||||
};
|
||||
}
|
||||
|
||||
validate_random!(f32);
|
||||
validate_random!(f16);
|
||||
validate_random!(bf16);
|
||||
}
|
||||
|
@ -58,12 +58,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
}
|
||||
template <typename T> METAL_FUNC T relu(T in){
|
||||
if (in < 0) {
|
||||
return 0;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -116,7 +110,6 @@ UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -126,7 +119,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
@ -143,7 +136,6 @@ BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
#endif
|
||||
|
@ -11,7 +11,7 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
half = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
@ -20,7 +20,7 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
metal = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
@ -6,7 +6,6 @@ use serde::Deserialize;
|
||||
pub enum Activation {
|
||||
#[default]
|
||||
Gelu,
|
||||
#[serde(alias = "gelu_new")]
|
||||
NewGelu,
|
||||
Relu,
|
||||
Relu2,
|
||||
|
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn" }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
@ -20,3 +20,4 @@ prost-build = "0.12.1"
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
|
||||
|
@ -254,12 +254,6 @@ pub fn simple_eval(
|
||||
let output = input0.broadcast_div(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Pow" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
|
@ -15,9 +15,9 @@ crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
|
@ -12,9 +12,9 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
@ -112,6 +112,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
#[allow(dead_code)]
|
||||
pr: f64,
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
@ -11,6 +11,11 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
@ -27,6 +27,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -136,6 +136,11 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
@ -404,7 +409,7 @@ impl Llama {
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
|
@ -1,333 +0,0 @@
|
||||
//! MobileOne inference implementation based on timm and candle-repvgg
|
||||
//!
|
||||
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
|
||||
//! https://arxiv.org/abs/2206.04040
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
|
||||
Func, VarBuilder,
|
||||
};
|
||||
|
||||
struct StageConfig {
|
||||
blocks: usize,
|
||||
channels: usize,
|
||||
}
|
||||
|
||||
// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
|
||||
// by concatenating the 5th stage (starts with stride 1) to the previous one.
|
||||
const STAGES: [StageConfig; 5] = [
|
||||
StageConfig {
|
||||
blocks: 1,
|
||||
channels: 64,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 2,
|
||||
channels: 64,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 8,
|
||||
channels: 128,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 10,
|
||||
channels: 256,
|
||||
},
|
||||
StageConfig {
|
||||
blocks: 1,
|
||||
channels: 512,
|
||||
},
|
||||
];
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
/// overparameterization factor
|
||||
k: usize,
|
||||
/// per-stage channel number multipliers
|
||||
alphas: [f32; 5],
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn s0() -> Self {
|
||||
Self {
|
||||
k: 4,
|
||||
alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
|
||||
}
|
||||
}
|
||||
pub fn s1() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
|
||||
}
|
||||
}
|
||||
pub fn s2() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
|
||||
}
|
||||
}
|
||||
pub fn s3() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
|
||||
}
|
||||
}
|
||||
pub fn s4() -> Self {
|
||||
Self {
|
||||
k: 1,
|
||||
alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SE blocks are used in the last stages of the s4 variant.
|
||||
fn squeeze_and_excitation(
|
||||
in_channels: usize,
|
||||
squeeze_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
||||
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
|
||||
|
||||
residual.broadcast_mul(&xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
// A mobileone block has a different training time and inference time architecture.
|
||||
// The latter is a simple and efficient equivalent transformation of the former
|
||||
// realized by a structural reparameterization technique, where convolutions
|
||||
// along with identity branches and batchnorm layers are fused into a single convolution.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn mobileone_block(
|
||||
has_identity: bool,
|
||||
k: usize,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
groups: usize,
|
||||
kernel: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding,
|
||||
groups,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut w = Tensor::zeros(
|
||||
(out_channels, in_channels / groups, kernel, kernel),
|
||||
DType::F32,
|
||||
vb.device(),
|
||||
)?;
|
||||
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
|
||||
|
||||
// k is the training-time overparameterization factor, larger than 1 only in the s0 variant
|
||||
for i in 0..k {
|
||||
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
|
||||
let conv_kxk = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
conv2d_cfg,
|
||||
vb.pp(format!("conv_kxk.{i}.conv")),
|
||||
)?;
|
||||
let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
|
||||
w = (w + wk)?;
|
||||
b = (b + bk)?;
|
||||
}
|
||||
|
||||
if kernel > 1 {
|
||||
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
|
||||
let conv_scale = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_scale.conv"),
|
||||
)?;
|
||||
|
||||
let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
|
||||
// resize to 3x3
|
||||
ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
|
||||
ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
w = (w + ws)?;
|
||||
b = (b + bs)?;
|
||||
}
|
||||
|
||||
// Use SE blocks if present (last layers of the s4 variant)
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
if has_identity {
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
|
||||
|
||||
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
|
||||
|
||||
let id = in_channels / groups;
|
||||
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
|
||||
for i in 0..in_channels {
|
||||
if kernel > 1 {
|
||||
weights[i * kernel * kernel + 4] = 1.0;
|
||||
} else {
|
||||
weights[i * (id + 1)] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
}
|
||||
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&reparam_conv)?;
|
||||
if let Ok(f) = &se {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
xs = xs.relu()?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Get the number of output channels per stage taking into account the multipliers
|
||||
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
|
||||
let channels = STAGES[stage].channels as f32;
|
||||
let alpha = cfg.alphas[stage];
|
||||
|
||||
match stage {
|
||||
0 => std::cmp::min(64, (channels * alpha) as usize),
|
||||
_ => (channels * alpha) as usize,
|
||||
}
|
||||
}
|
||||
|
||||
// Each stage is made of blocks. The first layer always downsamples with stride 2.
|
||||
// All but the first block have a residual connection.
|
||||
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nblocks = STAGES[idx].blocks;
|
||||
let mut blocks = Vec::with_capacity(nblocks);
|
||||
|
||||
let mut in_channels = output_channels_per_stage(cfg, idx - 1);
|
||||
|
||||
for block_idx in 0..nblocks {
|
||||
let out_channels = output_channels_per_stage(cfg, idx);
|
||||
let (has_identity, stride) = if block_idx == 0 {
|
||||
(false, 2)
|
||||
} else {
|
||||
(true, 1)
|
||||
};
|
||||
|
||||
// depthwise convolution layer
|
||||
blocks.push(mobileone_block(
|
||||
has_identity,
|
||||
cfg.k,
|
||||
in_channels,
|
||||
stride,
|
||||
1,
|
||||
in_channels,
|
||||
3,
|
||||
in_channels,
|
||||
in_channels,
|
||||
vb.pp(block_idx * 2),
|
||||
)?);
|
||||
|
||||
// pointwise convolution layer
|
||||
blocks.push(mobileone_block(
|
||||
has_identity,
|
||||
cfg.k,
|
||||
out_channels,
|
||||
1, // stride
|
||||
0, // padding
|
||||
1, // groups
|
||||
1, // kernel
|
||||
in_channels,
|
||||
out_channels,
|
||||
vb.pp(block_idx * 2 + 1),
|
||||
)?);
|
||||
|
||||
in_channels = out_channels;
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a mobileone model for a given configuration.
|
||||
fn mobileone_model(
|
||||
config: &Config,
|
||||
nclasses: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let outputs = output_channels_per_stage(config, 4);
|
||||
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem_dim = output_channels_per_stage(config, 0);
|
||||
let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
|
||||
let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
|
||||
let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
|
||||
let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.mean(D::Minus2)?
|
||||
.mean(D::Minus1)?;
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
mobileone_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
mobileone_model(cfg, None, vb)
|
||||
}
|
@ -15,10 +15,8 @@ pub mod marian;
|
||||
pub mod mistral;
|
||||
pub mod mixformer;
|
||||
pub mod mixtral;
|
||||
pub mod mobileone;
|
||||
pub mod mpt;
|
||||
pub mod persimmon;
|
||||
pub mod phi;
|
||||
pub mod quantized_blip;
|
||||
pub mod quantized_blip_text;
|
||||
pub mod quantized_llama;
|
||||
@ -28,7 +26,6 @@ pub mod quantized_mixformer;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod repvgg;
|
||||
pub mod resnet;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
|
@ -1,363 +0,0 @@
|
||||
use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear};
|
||||
/// Phi model.
|
||||
/// https://huggingface.co/microsoft/phi-2
|
||||
/// There is an alternative implementation of the phi model in mixformers.rs.
|
||||
/// This corresponds to the model update made with the following commit:
|
||||
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub(crate) vocab_size: usize,
|
||||
pub(crate) hidden_size: usize,
|
||||
pub(crate) intermediate_size: usize,
|
||||
pub(crate) num_hidden_layers: usize,
|
||||
pub(crate) num_attention_heads: usize,
|
||||
pub(crate) num_key_value_heads: Option<usize>,
|
||||
pub(crate) hidden_act: Activation,
|
||||
pub(crate) max_position_embeddings: usize,
|
||||
pub(crate) layer_norm_eps: f64,
|
||||
pub(crate) tie_word_embeddings: bool,
|
||||
pub(crate) rope_theta: f32,
|
||||
pub(crate) partial_rotary_factor: f64,
|
||||
pub(crate) qk_layernorm: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn num_key_value_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
|
||||
}
|
||||
|
||||
fn head_dim(&self) -> usize {
|
||||
self.hidden_size / self.num_attention_heads
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
dim: usize,
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((cfg.max_position_embeddings, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
dim,
|
||||
sin: emb.sin()?,
|
||||
cos: emb.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
|
||||
let xs_rot = xs.i((.., .., .., ..self.dim))?;
|
||||
let xs_pass = xs.i((.., .., .., self.dim..))?;
|
||||
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
||||
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
|
||||
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
act: Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
|
||||
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
|
||||
Ok(Self {
|
||||
fc1,
|
||||
fc2,
|
||||
// This does not match the mixformers implementation where Gelu is used rather than
|
||||
// GeluNew.
|
||||
act: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
dense: Linear,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
q_layernorm: Option<LayerNorm>,
|
||||
k_layernorm: Option<LayerNorm>,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
softmax_scale: f64,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads();
|
||||
let head_dim = cfg.head_dim();
|
||||
let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?;
|
||||
let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?;
|
||||
// Alternative rope scalings are not supported.
|
||||
let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?;
|
||||
let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
|
||||
let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
|
||||
let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
|
||||
(Some(q_layernorm), Some(k_layernorm))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
dense,
|
||||
kv_cache: None,
|
||||
q_layernorm,
|
||||
k_layernorm,
|
||||
rotary_emb,
|
||||
softmax_scale,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||
})
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_heads / self.num_kv_heads;
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||
xs.unsqueeze(2)?
|
||||
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = match &self.q_layernorm {
|
||||
None => query_states,
|
||||
Some(ln) => query_states.apply(ln)?,
|
||||
};
|
||||
let key_states = match &self.k_layernorm {
|
||||
None => key_states,
|
||||
Some(ln) => key_states.apply(ln)?,
|
||||
};
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_size, seq_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
// Rotary embeddings.
|
||||
let seqlen_offset = match &self.kv_cache {
|
||||
None => 0,
|
||||
Some((prev_k, _)) => prev_k.dim(2)?,
|
||||
};
|
||||
let query_states = self
|
||||
.rotary_emb
|
||||
.apply_rotary_emb(&query_states, seqlen_offset)?;
|
||||
let key_states = self
|
||||
.rotary_emb
|
||||
.apply_rotary_emb(&key_states, seqlen_offset)?;
|
||||
|
||||
// KV cache.
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let v = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
// Repeat kv.
|
||||
let key_states = self.repeat_kv(key_states)?.contiguous()?;
|
||||
let value_states = self.repeat_kv(value_states)?.contiguous()?;
|
||||
|
||||
let attn_weights = (query_states
|
||||
.to_dtype(DType::F32)?
|
||||
.contiguous()?
|
||||
.matmul(&key_states.to_dtype(DType::F32)?.t()?)?
|
||||
* self.softmax_scale)?;
|
||||
let attn_weights = match mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => masked_fill(
|
||||
&attn_weights,
|
||||
&mask.broadcast_left((b_size, self.num_heads))?,
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
let attn_weights =
|
||||
candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?;
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_size, seq_len, ()))?;
|
||||
attn_output.apply(&self.dense)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MLP,
|
||||
input_layernorm: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_eps,
|
||||
vb.pp("input_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "block"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.input_layernorm)?;
|
||||
let attn_outputs = self.self_attn.forward(&xs, mask)?;
|
||||
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||
attn_outputs + feed_forward_hidden_states + residual
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
final_layernorm: LayerNorm,
|
||||
lm_head: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let final_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_eps,
|
||||
vb_m.pp("final_layernorm"),
|
||||
)?;
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_m = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
final_layernorm,
|
||||
lm_head,
|
||||
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.embed_tokens)?;
|
||||
let mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, mask.as_ref())?;
|
||||
}
|
||||
xs.apply(&self.final_layernorm)?
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.lm_head)?
|
||||
.squeeze(1)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.layers.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||
}
|
||||
}
|
@ -165,9 +165,13 @@ impl Attention {
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
// let (query_states1, key_states1) =
|
||||
// self.rotary_emb
|
||||
// .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
// println!("{query_states:?} {query_states1:?}");
|
||||
// println!("{key_states:?} {key_states1:?}");
|
||||
let query_states = query_states.contiguous()?;
|
||||
let key_states = key_states.contiguous()?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
|
@ -199,7 +199,10 @@ impl MHA {
|
||||
Some((prev_k, _)) => prev_k.dim(1)?,
|
||||
};
|
||||
// In the python implementation, a single tensor is returned with the third axis of size 3.
|
||||
let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
|
||||
// let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
|
||||
let q = qkv.i((.., .., 0))?;
|
||||
let k = qkv.i((.., .., 1))?;
|
||||
let v = qkv.i((.., .., 2))?;
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((prev_k, prev_v)) => {
|
||||
|
@ -1,306 +0,0 @@
|
||||
//! RepVGG inference implementation
|
||||
//!
|
||||
//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
|
||||
//! https://arxiv.org/abs/2101.03697
|
||||
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
a: f32,
|
||||
b: f32,
|
||||
groups: usize,
|
||||
stages: [usize; 4],
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn a0() -> Self {
|
||||
Self {
|
||||
a: 0.75,
|
||||
b: 2.5,
|
||||
groups: 1,
|
||||
stages: [2, 4, 14, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn a1() -> Self {
|
||||
Self {
|
||||
a: 1.0,
|
||||
b: 2.5,
|
||||
groups: 1,
|
||||
stages: [2, 4, 14, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn a2() -> Self {
|
||||
Self {
|
||||
a: 1.5,
|
||||
b: 2.75,
|
||||
groups: 1,
|
||||
stages: [2, 4, 14, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b0() -> Self {
|
||||
Self {
|
||||
a: 1.0,
|
||||
b: 2.5,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b1() -> Self {
|
||||
Self {
|
||||
a: 2.0,
|
||||
b: 4.0,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b2() -> Self {
|
||||
Self {
|
||||
a: 2.5,
|
||||
b: 5.0,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b3() -> Self {
|
||||
Self {
|
||||
a: 3.0,
|
||||
b: 5.0,
|
||||
groups: 1,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b1g4() -> Self {
|
||||
Self {
|
||||
a: 2.0,
|
||||
b: 4.0,
|
||||
groups: 4,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b2g4() -> Self {
|
||||
Self {
|
||||
a: 2.5,
|
||||
b: 5.0,
|
||||
groups: 4,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn b3g4() -> Self {
|
||||
Self {
|
||||
a: 3.0,
|
||||
b: 5.0,
|
||||
groups: 4,
|
||||
stages: [4, 6, 16, 1],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
// A RepVGG layer has a different training time and inference time architecture.
|
||||
// The latter is a simple and efficient equivalent transformation of the former
|
||||
// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
|
||||
// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
|
||||
fn repvgg_layer(
|
||||
has_identity: bool,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
groups: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// read and reparameterize the 1x1 conv and bn into w1 and b1
|
||||
// based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
|
||||
|
||||
let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
|
||||
let conv1x1 = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_1x1.conv"),
|
||||
)?;
|
||||
|
||||
let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
|
||||
|
||||
// resize to 3x3
|
||||
w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
|
||||
w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
// read and reparameterize the 3x3 conv and bn into w3 and b3
|
||||
let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
|
||||
let conv3x3 = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_kxk.conv"),
|
||||
)?;
|
||||
|
||||
let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
|
||||
|
||||
let mut w = (w1 + w3)?;
|
||||
let mut b = (b1 + b3)?;
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
if has_identity {
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
|
||||
|
||||
// create a 3x3 convolution equivalent to the identity branch
|
||||
let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
|
||||
|
||||
// https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
|
||||
let in_dim = in_channels / groups;
|
||||
for i in 0..in_channels {
|
||||
weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
}
|
||||
|
||||
// create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&reparam_conv)?.relu()?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Get the number of output channels per stage taking into account the multipliers
|
||||
fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
|
||||
let channels = CHANNELS_PER_STAGE[stage] as f32;
|
||||
|
||||
match stage {
|
||||
0 => std::cmp::min(64, (channels * a) as usize),
|
||||
4 => (channels * b) as usize,
|
||||
_ => (channels * a) as usize,
|
||||
}
|
||||
}
|
||||
|
||||
// Each stage is made of layers. The first layer always downsamples with stride 2.
|
||||
// All but the first layer have a residual connection.
|
||||
// The G4 variants have a groupwise convolution instead of a dense one on odd layers
|
||||
// counted across stage boundaries, so we keep track of which layer we are in the
|
||||
// full model.
|
||||
fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nlayers = cfg.stages[idx - 1];
|
||||
let mut layers = Vec::with_capacity(nlayers);
|
||||
let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
|
||||
let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
|
||||
let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
|
||||
|
||||
for layer_idx in 0..nlayers {
|
||||
let (has_identity, stride, in_channels) = if layer_idx == 0 {
|
||||
(false, 2, out_channels_prev)
|
||||
} else {
|
||||
(true, 1, out_channels)
|
||||
};
|
||||
|
||||
let groups = if (prev_layers + layer_idx) % 2 == 1 {
|
||||
cfg.groups
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
layers.push(repvgg_layer(
|
||||
has_identity,
|
||||
out_channels,
|
||||
stride,
|
||||
in_channels,
|
||||
out_channels,
|
||||
groups,
|
||||
vb.pp(layer_idx),
|
||||
)?)
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for layer in layers.iter() {
|
||||
xs = xs.apply(layer)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a RepVGG model for a given configuration.
|
||||
fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let outputs = output_channels_per_stage(config.a, config.b, 4);
|
||||
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem_dim = output_channels_per_stage(config.a, config.b, 0);
|
||||
let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
|
||||
let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
|
||||
let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
|
||||
let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.mean(D::Minus1)?
|
||||
.mean(D::Minus1)?;
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
repvgg_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
repvgg_model(cfg, None, vb)
|
||||
}
|
@ -1,7 +1,12 @@
|
||||
use super::Config;
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
in_channels: usize,
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -27,7 +27,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.11"
|
||||
gloo = "0.8"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
serde-wasm-bindgen = "0.6.0"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -26,7 +26,7 @@ serde_json = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.11"
|
||||
gloo = "0.8"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user