Merge branch 'main' into ivarflakstad/metal-prng

This commit is contained in:
Ivar Flakstad
2024-01-12 07:19:58 +01:00
53 changed files with 1035 additions and 1051 deletions

7
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,7 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5

View File

@ -31,9 +31,17 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.9.14", features = ["f16"] }
cudarc = { version = "0.10.0", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
@ -42,7 +50,7 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "45.0.0" }
@ -55,7 +63,7 @@ serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false }
tokenizers = { version = "0.15.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@ -12,8 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }

View File

@ -1,5 +1,5 @@
use crate::benchmarks::{bench_name, device, BenchDevice};
use candle_core::{DType, Tensor};
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
@ -7,20 +7,19 @@ fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
fn run_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let device = device().unwrap();
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(bench_name("matmul"));
let mut group = c.benchmark_group(device.bench_name("matmul"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
@ -35,4 +34,11 @@ fn criterion_benchmark(c: &mut Criterion) {
group.finish();
}
/// Criterion entry point for the matmul benchmark.
///
/// Builds a `BenchDeviceHandler` and runs `run_bench` once per device it holds, in order.
/// Panics if the handler cannot be constructed.
fn criterion_benchmark(c: &mut Criterion) {
    BenchDeviceHandler::new()
        .unwrap()
        .devices
        .into_iter()
        .for_each(|device| run_bench(c, &device));
}
criterion_group!(benches, criterion_benchmark);

View File

@ -5,6 +5,8 @@ use candle_core::{Device, Result};
/// Benchmark-side helpers implemented for `Device`.
pub(crate) trait BenchDevice {
/// Called by the benchmarks after the timed loop, right before the elapsed time is read.
/// NOTE(review): presumably waits for outstanding device work to finish - confirm against
/// the `Device` implementation.
fn sync(&self) -> Result<()>;
/// Returns the benchmark group name for `name` on this device; the `Device`
/// implementation prefixes `name` with a backend tag followed by `_`.
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
@ -25,32 +27,38 @@ impl BenchDevice for Device {
}
}
}
}
pub(crate) fn device() -> Result<Device> {
if cfg!(feature = "metal") {
Device::new_metal(0)
} else if cfg!(feature = "cuda") {
Device::new_cuda(0)
} else {
Ok(Device::Cpu)
/// Builds the benchmark group name `"<backend>_<name>"`.
///
/// The backend tag is `cuda` or `metal` for the accelerator variants. For the CPU it is
/// `accelerate` when that feature is enabled, otherwise `mkl` when that feature is enabled,
/// otherwise plain `cpu`.
fn bench_name<S: Into<String>>(&self, name: S) -> String {
    let backend = match self {
        // Feature checks are ordered: `accelerate` wins over `mkl` when both are on.
        Device::Cpu if cfg!(feature = "accelerate") => "accelerate",
        Device::Cpu if cfg!(feature = "mkl") => "mkl",
        Device::Cpu => "cpu",
        Device::Cuda(_) => "cuda",
        Device::Metal(_) => "metal",
    };
    format!("{}_{}", backend, name.into())
}
}
pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
format!("{}_{}", device_variant(), name.into())
/// Collects the devices the benchmarks should be run on.
struct BenchDeviceHandler {
// Devices in the order the benchmarks iterate over them.
devices: Vec<Device>,
}
const fn device_variant() -> &'static str {
if cfg!(feature = "metal") {
"metal"
} else if cfg!(feature = "cuda") {
"cuda"
} else if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
impl BenchDeviceHandler {
    /// Builds the list of devices to benchmark.
    ///
    /// With the `metal` feature the list starts with Metal device 0; otherwise, with the
    /// `cuda` feature, it starts with CUDA device 0. `Device::Cpu` is always appended last.
    ///
    /// # Errors
    /// Propagates the error from creating the accelerator device.
    pub fn new() -> Result<Self> {
        // At most one accelerator is selected; `metal` takes precedence over `cuda`.
        let accelerator = if cfg!(feature = "metal") {
            Some(Device::new_metal(0)?)
        } else if cfg!(feature = "cuda") {
            Some(Device::new_cuda(0)?)
        } else {
            None
        };
        let devices = accelerator
            .into_iter()
            .chain(std::iter::once(Device::Cpu))
            .collect();
        Ok(Self { devices })
    }
}

View File

@ -1,4 +1,4 @@
use crate::benchmarks::{bench_name, device, BenchDevice};
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
@ -11,19 +11,18 @@ fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let d = device().unwrap();
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(bench_name("random_uniform"));
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
@ -31,16 +30,15 @@ fn criterion_benchmark(c: &mut Criterion) {
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
d.sync().unwrap();
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let d = device().unwrap();
let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(bench_name("random_normal"));
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
@ -48,11 +46,18 @@ fn criterion_benchmark(c: &mut Criterion) {
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
d.sync().unwrap();
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
/// Criterion entry point for the random-number benchmarks.
///
/// Builds a `BenchDeviceHandler` and runs `run_random_bench` once per device it holds, in
/// order. Panics if the handler cannot be constructed.
fn criterion_benchmark(c: &mut Criterion) {
    BenchDeviceHandler::new()
        .unwrap()
        .devices
        .into_iter()
        .for_each(|device| run_random_bench(c, &device));
}
criterion_group!(benches, criterion_benchmark);

View File

@ -102,7 +102,7 @@ enum Command {
},
Quantize {
/// The input file, in gguf format.
/// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
@ -117,6 +117,15 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
Dequantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in safetensors format.
#[arg(long)]
out_file: std::path::PathBuf,
},
}
#[derive(Parser, Debug, Clone)]
@ -285,6 +294,19 @@ fn run_quantize_safetensors(
Ok(())
}
/// Reads the gguf file at `in_file`, dequantizes every tensor it lists onto the CPU and
/// saves the results to `out_file` through `candle_core::safetensors::save`.
///
/// # Errors
/// Propagates any failure from opening or parsing the input, loading or dequantizing a
/// tensor, or writing the output.
fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
    let mut reader = std::fs::File::open(in_file)?;
    let content = gguf_file::Content::read(&mut reader)?;
    // Stops at the first tensor that fails to load or dequantize.
    let tensors = content
        .tensor_infos
        .iter()
        .map(|(name, _)| -> Result<_> {
            let dequantized = content
                .tensor(&mut reader, name)?
                .dequantize(&Device::Cpu)?;
            Ok((name.to_string(), dequantized))
        })
        .collect::<Result<std::collections::HashMap<_, _>>>()?;
    candle_core::safetensors::save(&tensors, out_file)?;
    Ok(())
}
fn run_quantize(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
}
Ok(())
}

View File

@ -592,14 +592,26 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
@ -677,6 +689,7 @@ impl BackendStorage for MetalStorage {
("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
@ -693,6 +706,7 @@ impl BackendStorage for MetalStorage {
("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
@ -723,6 +737,7 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
@ -737,6 +752,7 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
@ -1129,8 +1145,12 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
@ -1320,6 +1340,7 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@ -1330,6 +1351,18 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@ -1340,6 +1373,7 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@ -1350,6 +1384,7 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@ -1360,6 +1395,7 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
@ -1393,6 +1429,7 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@ -1405,6 +1442,20 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@ -1417,6 +1468,7 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),
("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@ -1429,6 +1481,7 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),
("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@ -1441,6 +1494,7 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}

View File

@ -703,6 +703,7 @@ impl PthTensors {
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
@ -712,14 +713,21 @@ impl PthTensors {
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
// For now only support the basic case.
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
// Reading the data is a bit tricky as it can be strided; for now, only the basic
// (contiguous) case is supported.
if !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,

View File

@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
let d2 = d * sc as f32;
let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u1 != 0 { 16 } else { 1 };
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
ys_index += 1;
}
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u2 != 0 { 16 } else { 1 };
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
ys_index += 1;
}
is += 2;

View File

@ -12,6 +12,14 @@ use core::arch::arm::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
/// Widening `i8` dot-product helper built from plain NEON multiply and pairwise-add
/// intrinsics.
///
/// The low and high halves of `a` and `b` are multiplied lane-wise into `i16` products, and
/// adjacent products are then added into `i32` lanes. Output lane `i` holds
/// `a[2i]*b[2i] + a[2i+1]*b[2i+1] + a[8+2i]*b[8+2i] + a[9+2i]*b[9+2i]`, so the sum of the
/// four output lanes is the full 16-element dot product.
///
/// # Safety
/// Calls NEON intrinsics; the caller must ensure the target supports them.
#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
    // TODO: dotprod
    let lower = vpaddlq_s16(vmull_s8(vget_low_s8(a), vget_low_s8(b)));
    let upper = vpaddlq_s16(vmull_s8(vget_high_s8(a), vget_high_s8(b)));
    vaddq_s32(lower, upper)
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl0 = vdotq_s32(v0_0ls, v1_0l);
let ph0 = vdotq_s32(v0_0hs, v1_0h);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p0 = vdotq_s32(x0_0, y0_0);
let p1 = vdotq_s32(x0_1, y0_1);
sumv0 = vmlaq_n_f32(
sumv0,
@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
let xy = vdotq_s32(xs, ys);
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
);
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
// TODO: dotprod case.
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
);
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
);
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
scales = scales.add(1);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
);
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
// TODO: dotprod
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
);
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
);
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
);
isum += vaddvq_s16(p0) as i32 * *scale as i32
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
// TODO: dotprod
let p0 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
);
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
);
let p3 = vaddq_s16(
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
);
isum += vaddvq_s16(p0) as i32 * *scale as i32
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
let mut is = 0usize;
// TODO: dotprod
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
let p1 = vaddq_s16(
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
);
let p2 = vaddq_s16(
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
);
vaddvq_s16(p1) as i32 * aux[is + index] as i32
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}

View File

@ -1,4 +1,5 @@
use candle_core::{
bail,
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
}
}
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
/// Creates a vector similar to the ones used in GGML unit tests:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
(0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
sum / a.len() as f32
}
/// Mirrors the GGML quantization unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
/// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error {
candle_core::bail!(
bail!(
"Quantization error {} exceeds max error {}",
error,
max_error
@ -404,7 +407,7 @@ fn quantize_q5k() -> Result<()> {
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143,
GgmlDType::Q4_1 => 0.007784,
GgmlDType::Q4_1 => 0.008,
GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363,
GgmlDType::Q5_1 => 0.00149,
GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
_ => bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
}
/// Mirrors the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
/// Runs the quantized dot-product error check for block type `T` on two inputs.
///
/// The first input pair follows the GGML matmul unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
/// The second one is a linear ramp in `[0, 1)` multiplied with itself, checked with a
/// doubled error budget.
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    let lhs = create_ggml_like_vector(0.0);
    let rhs = create_ggml_like_vector(1.0);
    ggml_matmul_error_test_::<T>(&lhs, &rhs, 1.0)?;

    // Another example that is more likely to trigger the overflow reported in #1526.
    // Both operands hold the same values, so a single ramp is shared between them.
    let ramp: Vec<f32> = (0..GGML_TEST_SIZE)
        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
        .collect();
    ggml_matmul_error_test_::<T>(&ramp, &ramp, 2.0)
}
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
T::from_float(&a, &mut a_quant)?;
T::VecDotType::from_float(&b, &mut b_quant)?;
T::from_float(a, &mut a_quant)?;
T::VecDotType::from_float(b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b);
let reference_result = vec_dot_reference(a, b);
if (result - result_unopt).abs() / length as f32 > 1e-6 {
candle_core::bail!(
bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
)
}
let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!(
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
);
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
candle_core::bail!(
bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
Ok(())
}
/// Runs the matmul error check for the Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0 block types,
/// stopping at the first one that fails.
#[test]
fn quantized_mm() -> Result<()> {
    ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
    ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
    ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
    ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
    ggml_matmul_error_test::<k_quants::BlockQ8_0>()
}
/// Generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors(
m: usize,

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }

View File

@ -11,12 +11,12 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle-onnx = { workspace = true, optional = true }
csv = "1.3.0"
cudarc = { workspace = true, optional = true }
@ -49,11 +49,12 @@ tokio = "1.29.1"
[build-dependencies]
anyhow = { workspace = true }
bindgen_cuda = { version = "0.1.1", optional = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

View File

@ -4,251 +4,34 @@ use std::io::Write;
use std::path::PathBuf;
struct KernelDirectories {
kernel_dir: &'static str,
kernel_glob: &'static str,
rust_target: &'static str,
include_dirs: &'static [&'static str],
}
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_dir: "examples/custom-ops/kernels/",
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_glob: "examples/custom-ops/kernels/*.cu",
rust_target: "examples/custom-ops/cuda_kernels.rs",
include_dirs: &[],
}];
impl KernelDirectories {
fn maybe_build_ptx(
&self,
cu_file: &std::path::Path,
ptx_file: &std::path::Path,
compute_cap: usize,
) -> Result<()> {
let should_compile = if ptx_file.exists() {
let ptx_modified = ptx_file.metadata()?.modified()?;
let cu_modified = cu_file.metadata()?.modified()?;
cu_modified.duration_since(ptx_modified).is_ok()
} else {
true
};
if should_compile {
#[cfg(feature = "cuda")]
{
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", out_dir.to_str().unwrap()])
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
#[cfg(not(feature = "cuda"))]
std::fs::OpenOptions::new()
.create(true)
.write(true)
.open(ptx_file)?;
}
Ok(())
}
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
println!("cargo:rerun-if-changed={}", self.kernel_dir);
let kernel_dir = PathBuf::from(self.kernel_dir);
let out_dir = out_dir.join(self.kernel_dir);
if !out_dir.exists() {
std::fs::create_dir_all(&out_dir)?;
}
let mut cu_files = vec![];
let mut cuh_files = vec![];
for file in std::fs::read_dir(kernel_dir)?.flatten() {
let file = file.path();
match file.extension().and_then(|v| v.to_str()) {
Some("cu") => cu_files.push(file),
Some("cuh") => cuh_files.push(file),
_ => {}
}
}
let mut ptx_paths = vec![];
for cu_file in cu_files.iter() {
let file_stem = cu_file
.file_stem()
.with_context(|| format!("no stem {cu_file:?}"))?;
let file_stem = file_stem.to_string_lossy().into_owned();
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
ptx_paths.push(ptx_file);
}
let regenerate_rs_file = true;
if regenerate_rs_file {
let mut file = std::fs::File::create(self.rust_target)?;
for ptx_path in ptx_paths {
let name = ptx_path
.file_stem()
.context("empty stem")?
.to_string_lossy();
file.write_all(b"#[rustfmt::skip]\n")?;
let const_definition = format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
self.kernel_dir,
);
file.write_all(const_definition.as_bytes())?;
file.write_all(b"\n")?;
}
}
Ok(())
}
}
fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
let out_dir = PathBuf::from(out_dir);
#[cfg(feature = "cuda")]
set_cuda_include_dir()?;
#[cfg(feature = "cuda")]
let compute_cap = compute_cap()?;
{
for kdir in KERNEL_DIRS.iter() {
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write(kdir.rust_target).unwrap()
}
}
#[cfg(not(feature = "cuda"))]
let compute_cap = 0;
for d in DIRS {
d.process(&out_dir, compute_cap)?
{
for kdir in KERNEL_DIRS.iter() {
let _file = std::fs::File::create(kdir.rust_target)?;
}
}
Ok(())
}
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute cap from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Grab compute cap from nvidia-smi
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}

View File

@ -1,2 +1 @@
#[rustfmt::skip]
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -6,7 +6,8 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[allow(unused)]
#[rustfmt::skip]
#[cfg(feature = "cuda")]
mod cuda_kernels;
use clap::Parser;

View File

@ -0,0 +1,20 @@
# candle-repvgg
A candle implementation of inference using a pre-trained [RepVGG](https://arxiv.org/abs/2101.03697) model.
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 61.70%
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
unicycle, monocycle : 4.88%
crash helmet : 0.15%
moped : 0.04%
```

View File

@ -0,0 +1,111 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::repvgg;
/// RepVGG model variants selectable through the `--which` command line flag.
///
/// Each variant maps to a pre-trained checkpoint name and to a `repvgg::Config`
/// constructor (see the `impl Which` block).
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
// Plain `//` comments are used for the variants on purpose: `///` on a `ValueEnum`
// variant would be picked up by clap as help text for that value.
A0,
A1,
A2,
B0,
B1,
B2,
B3,
B1G4,
B2G4,
B3G4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::A0 => "a0",
Self::A1 => "a1",
Self::A2 => "a2",
Self::B0 => "b0",
Self::B1 => "b1",
Self::B2 => "b2",
Self::B3 => "b3",
Self::B1G4 => "b1g4",
Self::B2G4 => "b2g4",
Self::B3G4 => "b3g4",
};
format!("timm/repvgg_{}.rvgg_in1k", name)
}
fn config(&self) -> repvgg::Config {
match self {
Self::A0 => repvgg::Config::a0(),
Self::A1 => repvgg::Config::a1(),
Self::A2 => repvgg::Config::a2(),
Self::B0 => repvgg::Config::b0(),
Self::B1 => repvgg::Config::b1(),
Self::B2 => repvgg::Config::b2(),
Self::B3 => repvgg::Config::b3(),
Self::B1G4 => repvgg::Config::b1g4(),
Self::B2G4 => repvgg::Config::b2g4(),
Self::B3G4 => repvgg::Config::b3g4(),
}
}
}
/// Classify an image with a pre-trained RepVGG model and print the top-5 classes.
#[derive(Parser)]
struct Args {
    /// Path to a local safetensors weight file. When omitted, the weights are
    /// downloaded from the Hugging Face hub for the selected variant.
    #[arg(long)]
    model: Option<String>,

    /// Path to the image to classify.
    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// RepVGG variant to use.
    #[arg(value_enum, long, default_value_t=Which::A0)]
    which: Which,
}
/// Entry point: loads an image, builds the selected RepVGG model and prints the five
/// most probable ImageNet classes with their probabilities.
pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    // Prefer an explicitly supplied weight file; otherwise fetch the checkpoint of the
    // selected variant from the hub.
    let weights_path = if let Some(local_path) = args.model {
        std::path::PathBuf::from(local_path)
    } else {
        let repo_name = args.which.model_filename();
        hf_hub::api::sync::Api::new()?
            .model(repo_name)
            .get("model.safetensors")?
    };

    // SAFETY: the weights are memory-mapped from `weights_path`.
    // NOTE(review): presumably the file must not be modified while mapped — confirm.
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)? };
    let config = args.which.config();
    let model = repvgg::repvgg(&config, 1000, vb)?;
    println!("model built");

    // Add a leading batch dimension, run the model, then take the softmax of row 0.
    let batch = image.unsqueeze(0)?;
    let logits = model.forward(&batch)?;
    let probabilities = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;

    // Pair each probability with its class index and order by descending probability.
    let mut ranked: Vec<(usize, f32)> = probabilities.into_iter().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));

    for (category_idx, pr) in ranked.into_iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}

View File

@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }

View File

@ -2,44 +2,32 @@
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
const KERNEL_FILES: [&str; 17] = [
"flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu",
"flash_fwd_hdim192_fp16_sm80.cu",
"flash_fwd_hdim224_fp16_sm80.cu",
"flash_fwd_hdim256_fp16_sm80.cu",
"flash_fwd_hdim32_fp16_sm80.cu",
"flash_fwd_hdim64_fp16_sm80.cu",
"flash_fwd_hdim96_fp16_sm80.cu",
"flash_fwd_hdim128_bf16_sm80.cu",
"flash_fwd_hdim160_bf16_sm80.cu",
"flash_fwd_hdim192_bf16_sm80.cu",
"flash_fwd_hdim224_bf16_sm80.cu",
"flash_fwd_hdim256_bf16_sm80.cu",
"flash_fwd_hdim32_bf16_sm80.cu",
"flash_fwd_hdim64_bf16_sm80.cu",
"flash_fwd_hdim96_bf16_sm80.cu",
"kernels/flash_api.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|_| num_cpus::get_physical(),
|s| usize::from_str(&s).unwrap(),
);
rayon::ThreadPoolBuilder::new()
.num_threads(num_cpus)
.build_global()
.unwrap();
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
println!("cargo:rerun-if-changed={kernel_file}");
}
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
@ -66,223 +54,30 @@ fn main() -> Result<()> {
))
}
};
set_cuda_include_dir()?;
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let compute_cap = compute_cap()?;
let kernels = KERNEL_FILES.iter().collect();
let builder = bindgen_cuda::Builder::default()
.kernel_paths(kernels)
.out_dir(build_dir.clone())
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
let out_file = build_dir.join("libflashattention.a");
builder.build_lib(out_file);
let kernel_dir = PathBuf::from("kernels");
let cu_files: Vec<_> = KERNEL_FILES
.iter()
.map(|f| {
let mut obj_file = out_dir.join(f);
obj_file.set_extension("o");
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
} else {
true
};
if should_compile {
cu_files
.par_iter()
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
Ok(())
})
.collect::<Result<()>>()?;
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
let mut command = std::process::Command::new("nvcc");
command
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
/* laurent: I tried using the cc cuda integration as below but this led to ptxas never
finishing for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
Ok(())
}
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse compute cap")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}

View File

@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
bindgen_cuda = "0.1.1"

View File

@ -1,243 +1,8 @@
use std::io::Write;
fn main() {
println!("cargo:rerun-if-changed=build.rs");
cuda::set_include_dir();
let (write, kernel_paths) = cuda::build_ptx();
if write {
let mut file = std::fs::File::create("src/lib.rs").unwrap();
for kernel_path in kernel_paths {
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
file.write_all(
format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
name
)
.as_bytes(),
)
.unwrap();
file.write_all(&[b'\n']).unwrap();
}
}
}
mod cuda {
use anyhow::{Context, Result};
pub fn set_include_dir() {
use std::path::PathBuf;
// NOTE: copied from cudarc build.rs.
// We can't actually set a env!() value from another crate,
// so we have to do that here.
// use PathBuf;
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
#[allow(unused)]
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
#[allow(unused)]
let roots = roots.into_iter().map(Into::<PathBuf>::into);
#[cfg(feature = "ci-check")]
let root: PathBuf = "ci".into();
#[cfg(not(feature = "ci-check"))]
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.unwrap();
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
}
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
use rayon::prelude::*;
use std::path::PathBuf;
let out_dir = std::env::var("OUT_DIR").unwrap();
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
.unwrap()
.map(|p| p.unwrap())
.collect();
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
.unwrap()
.map(|p| p.unwrap())
.collect();
println!("cargo:rerun-if-changed=src/");
// for path in &kernel_paths {
// println!("cargo:rerun-if-changed={}", path.display());
// }
for path in &mut include_directories {
// println!("cargo:rerun-if-changed={}", path.display());
let destination =
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
std::fs::copy(path.clone(), destination).unwrap();
// remove the filename from the path so it's just the directory
path.pop();
}
include_directories.sort();
include_directories.dedup();
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
#[allow(unused)]
let include_options: Vec<String> = include_directories
.into_iter()
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}
})
.collect::<Vec<_>>();
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
.unwrap()
.map(|p| p.unwrap())
.collect();
// We should rewrite `src/lib.rs` only if some kernels were newly compiled, or if some
// old ones were removed
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
for (kernel_path, child) in children {
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
assert!(
output.status.success(),
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
(write, kernel_paths)
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write("src/lib.rs").unwrap();
}

View File

@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"]}
metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
rand = "0.8.5"

View File

@ -117,7 +117,7 @@ ELU(elu_f32, float)
ELU(elu_f16, half)
#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);

View File

@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y)
#endif
#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)

View File

@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
output[tid] = RIGHT_TYPENAME(input[tid]); \
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
// CAST_THROUGH defines a contiguous kernel FN_NAME and a strided kernel FN_NAME_STRIDED that
// convert LEFT_TYPENAME elements to RIGHT_TYPENAME in two steps: each element is first
// converted to the intermediate type IR_TYPENAME and only then to RIGHT_TYPENAME.
// Threads whose id is >= dim return without writing.
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif

View File

@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)

View File

@ -178,8 +178,8 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
recip
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip
);
}
pub mod binary {

View File

@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)

View File

@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
use metal::{Device, MTLResourceOptions};
use metal::{Buffer, Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@ -248,6 +248,34 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
/// Runs every contiguous bf16 binary kernel on a small input and compares the result
/// element-wise with the same operation evaluated on the host.
#[test]
fn binary_ops_bf16() {
    let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3]
        .into_iter()
        .map(bf16::from_f32)
        .collect();
    let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32].into_iter().map(bf16::from_f32).collect();

    // `$kernel` names the module under `binary::contiguous`; `$reference` is the host-side
    // closure producing the expected value for one pair of elements.
    macro_rules! check_kernel {
        ($kernel:ident, $reference:expr) => {{
            let actual = run_binary(&lhs, &rhs, binary::contiguous::$kernel::BFLOAT);
            let mut wanted: Vec<bf16> = Vec::with_capacity(lhs.len());
            for (a, b) in lhs.iter().zip(rhs.iter()) {
                let pair: (&bf16, &bf16) = (a, b);
                wanted.push($reference(*pair.0, *pair.1));
            }
            assert_eq!(actual, wanted);
        }};
    }

    check_kernel!(add, |x, y| x + y);
    check_kernel!(sub, |x, y| x - y);
    check_kernel!(mul, |x, y| x * y);
    check_kernel!(div, |x, y| x / y);
    check_kernel!(min, |x: bf16, y| x.min(y));
    check_kernel!(max, |x: bf16, y| x.max(y));
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let fence = device.new_fence();
@ -296,6 +324,89 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]);
}
/// The `cast_bf16_u32` kernel maps bf16 1, 2, 3 to the u32 values 1, 2, 3.
#[test]
fn it_cast_bf16_u32() {
    let source: Vec<bf16> = [1f32, 2f32, 3f32].into_iter().map(bf16::from_f32).collect();
    let converted: Vec<u32> = cast(&source, "cast_bf16_u32");
    let wanted: Vec<u32> = vec![1, 2, 3];
    assert_eq!(converted, wanted);
}
/// The `cast_bf16_f32` kernel maps bf16 1, 2, 3 to the f32 values 1, 2, 3.
#[test]
fn it_cast_bf16_f32() {
    let source: Vec<bf16> = [1f32, 2f32, 3f32].into_iter().map(bf16::from_f32).collect();
    let converted: Vec<f32> = cast(&source, "cast_bf16_f32");
    let wanted: Vec<f32> = vec![1.0, 2.0, 3.0];
    assert_eq!(converted, wanted);
}
/// The `cast_u8_bf16` kernel must agree with converting each u8 to f32 and then to bf16
/// on the host.
#[test]
fn it_cast_u8_bf16() {
    let source: Vec<u8> = vec![1, 2, 3];
    let converted: Vec<bf16> = cast(&source, "cast_u8_bf16");
    let mut wanted: Vec<bf16> = Vec::with_capacity(source.len());
    for &byte in source.iter() {
        wanted.push(bf16::from_f32(byte as f32));
    }
    assert_eq!(converted, wanted);
}
/// The `cast_u32_bf16` kernel must agree with converting each u32 to f32 and then to bf16
/// on the host.
#[test]
fn it_cast_u32_bf16() {
    let source: Vec<u32> = vec![1, 2, 3];
    let converted: Vec<bf16> = cast(&source, "cast_u32_bf16");
    let wanted: Vec<bf16> = source
        .iter()
        .map(|&word| bf16::from_f32(word as f32))
        .collect();
    assert_eq!(converted, wanted);
}
/// The `cast_f32_bf16` kernel must agree with `bf16::from_f32` evaluated on the host.
#[test]
fn it_cast_f32_bf16() {
    let source: Vec<f32> = vec![1.0, 2.0, 3.0];
    let converted: Vec<bf16> = cast(&source, "cast_f32_bf16");
    let wanted: Vec<bf16> = source.iter().copied().map(bf16::from_f32).collect();
    assert_eq!(converted, wanted);
}
/// The `cast_bf16_u8` kernel must agree with going through f32 and truncating to u8
/// on the host.
#[test]
fn it_cast_bf16_u8() {
    let source: Vec<bf16> = [1f32, 2f32, 3f32].into_iter().map(bf16::from_f32).collect();
    let converted: Vec<u8> = cast(&source, "cast_bf16_u8");
    let mut wanted: Vec<u8> = Vec::with_capacity(source.len());
    for value in source.iter() {
        wanted.push(value.to_f32() as u8);
    }
    assert_eq!(converted, wanted);
}
/// The `cast_bf16_f16` kernel must agree with converting bf16 -> f32 -> f16 on the host.
#[test]
fn it_cast_bf16_f16() {
    let source: Vec<bf16> = [1f32, 2f32, 3f32].into_iter().map(bf16::from_f32).collect();
    let converted: Vec<f16> = cast(&source, "cast_bf16_f16");
    let wanted: Vec<f16> = source
        .iter()
        .map(|value| value.to_f32())
        .map(f16::from_f32)
        .collect();
    assert_eq!(converted, wanted);
}
/// The `cast_f16_bf16` kernel must agree with converting f16 -> f32 -> bf16 on the host.
#[test]
fn it_cast_f16_bf16() {
    let source: Vec<f16> = [1f32, 2f32, 3f32].into_iter().map(f16::from_f32).collect();
    let converted: Vec<bf16> = cast(&source, "cast_f16_bf16");
    let wanted: Vec<bf16> = source
        .iter()
        .map(|value| value.to_f32())
        .map(bf16::from_f32)
        .collect();
    assert_eq!(converted, wanted);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let fence = device.new_fence();
@ -396,14 +507,14 @@ fn index_select() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@ -419,20 +530,46 @@ fn index_select_f16() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
let result = run_index_select(&embedding, &shape, &ids, dim);
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
);
}
/// Selects rows 0, 4 and 2 of a 5x2 bf16 table through the `is_u32_bf16` kernel
/// (u32 indices) and checks the gathered values.
#[test]
fn index_select_is_u32_bf16() {
    let table: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
    let table_shape = [5, 2];
    let row_ids = [0u32, 4, 2];
    let gathered = run_index_select(&table, &table_shape, &row_ids, 0, "is_u32_bf16");
    let wanted = vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0];
    assert_eq!(approx_bf16(gathered, 4), wanted);
}
/// Selects rows 0, 4 and 2 of a 5x2 bf16 table through the `is_u8_bf16` kernel
/// (u8 indices) and checks the gathered values.
#[test]
fn index_select_is_u8_bf16() {
    let table: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
    let table_shape = [5, 2];
    let row_ids = [0u8, 4, 2];
    let gathered = run_index_select(&table, &table_shape, &row_ids, 0, "is_u8_bf16");
    let wanted = vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0];
    assert_eq!(approx_bf16(gathered, 4), wanted);
}
#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
let result = run_index_select(&embedding, &shape, &ids, dim);
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize],
ids: &[I],
dim: usize,
name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
let name = match core::mem::size_of::<T>() {
4 => "is_u32_f32",
2 => "is_u32_f16",
_ => unimplemented!(),
};
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(

View File

@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
// Rectified linear unit: yields zero for inputs below zero and the input itself otherwise.
template <typename T> METAL_FUNC T relu(T in) {
    return (in < 0) ? static_cast<T>(0) : in;
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -110,6 +116,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
@ -120,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif
#if __METAL_VERSION__ >= 310
#if defined(__HAVE_BFLOAT__)
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif

View File

@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle = { workspace = true }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@ -20,7 +20,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }

View File

@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { path = "../candle-core", package = "candle-core" }
candle-nn = { path = "../candle-nn" }
prost = "0.12.1"
[build-dependencies]
@ -20,4 +20,3 @@ prost-build = "0.12.1"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }

View File

@ -15,9 +15,9 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
candle = { workspace = true }
candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }

View File

@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

View File

@ -1,6 +1,6 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
@ -112,11 +112,6 @@ impl Config {
}
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
struct Dropout {
#[allow(dead_code)]
pr: f64,

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
Ok(Linear::new(weight, bias))
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
#[derive(Debug)]
pub struct Config {

View File

@ -1,6 +1,6 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -136,11 +136,6 @@ impl Cache {
}
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
@ -409,7 +404,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)

View File

@ -26,6 +26,7 @@ pub mod quantized_mixformer;
pub mod quantized_mpt;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod repvgg;
pub mod resnet;
pub mod segment_anything;
pub mod stable_diffusion;

View File

@ -0,0 +1,306 @@
//! RepVGG inference implementation
//!
//! See "RepVGG: Making VGG-style ConvNets Great Again", Ding et al., 2021:
//! <https://arxiv.org/abs/2101.03697>
use candle::{Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
};
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
/// Hyper-parameters that select one RepVGG variant.
///
/// Instances are obtained through the named constructors ([`Config::a0`], [`Config::b1g4`], ...).
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Width multiplier applied to the stem and to every stage except the last one.
    a: f32,
    /// Width multiplier applied to the last stage.
    b: f32,
    /// Group count used by the grouped convolutions of the `g4` variants; `1` means dense.
    groups: usize,
    /// Number of layers in each of the four stages.
    stages: [usize; 4],
}

impl Config {
    /// Layer counts shared by the `A*` variants.
    const A_STAGES: [usize; 4] = [2, 4, 14, 1];
    /// Layer counts shared by the `B*` variants.
    const B_STAGES: [usize; 4] = [4, 6, 16, 1];

    // Single place where a configuration is assembled, so the named constructors
    // below only state what differs between variants.
    fn variant(a: f32, b: f32, groups: usize, stages: [usize; 4]) -> Self {
        Self {
            a,
            b,
            groups,
            stages,
        }
    }

    /// RepVGG-A0: `a = 0.75`, `b = 2.5`, dense convolutions.
    pub fn a0() -> Self {
        Self::variant(0.75, 2.5, 1, Self::A_STAGES)
    }

    /// RepVGG-A1: `a = 1.0`, `b = 2.5`, dense convolutions.
    pub fn a1() -> Self {
        Self::variant(1.0, 2.5, 1, Self::A_STAGES)
    }

    /// RepVGG-A2: `a = 1.5`, `b = 2.75`, dense convolutions.
    pub fn a2() -> Self {
        Self::variant(1.5, 2.75, 1, Self::A_STAGES)
    }

    /// RepVGG-B0: `a = 1.0`, `b = 2.5`, dense convolutions.
    pub fn b0() -> Self {
        Self::variant(1.0, 2.5, 1, Self::B_STAGES)
    }

    /// RepVGG-B1: `a = 2.0`, `b = 4.0`, dense convolutions.
    pub fn b1() -> Self {
        Self::variant(2.0, 4.0, 1, Self::B_STAGES)
    }

    /// RepVGG-B2: `a = 2.5`, `b = 5.0`, dense convolutions.
    pub fn b2() -> Self {
        Self::variant(2.5, 5.0, 1, Self::B_STAGES)
    }

    /// RepVGG-B3: `a = 3.0`, `b = 5.0`, dense convolutions.
    pub fn b3() -> Self {
        Self::variant(3.0, 5.0, 1, Self::B_STAGES)
    }

    /// RepVGG-B1g4: same widths as B1 with `groups = 4`.
    pub fn b1g4() -> Self {
        Self::variant(2.0, 4.0, 4, Self::B_STAGES)
    }

    /// RepVGG-B2g4: same widths as B2 with `groups = 4`.
    pub fn b2g4() -> Self {
        Self::variant(2.5, 5.0, 4, Self::B_STAGES)
    }

    /// RepVGG-B3g4: same widths as B3 with `groups = 4`.
    pub fn b3g4() -> Self {
        Self::variant(3.0, 5.0, 4, Self::B_STAGES)
    }
}
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
let bias = (beta - mu * &gps)?;
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
Ok((weights, bias))
}
// Builds one RepVGG layer in its inference-time form.
//
// At training time the layer is a sum of a 3x3 conv + bn branch, a 1x1 conv + bn branch and
// (optionally) an identity + bn branch. Here each branch is folded into an equivalent 3x3
// kernel and bias with `fuse_conv_bn`, the kernels and biases are summed, and the layer is
// returned as a single 3x3 convolution followed by a relu.
fn repvgg_layer(
    has_identity: bool,
    dim: usize,
    stride: usize,
    in_channels: usize,
    out_channels: usize,
    groups: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cfg = Conv2dConfig {
        stride,
        groups,
        padding: 1,
        ..Default::default()
    };

    // 1x1 branch: fold the bn into the conv, then zero-pad the kernel to 3x3.
    // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
    let bn_1x1 = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
    let conv_1x1 = conv2d_no_bias(in_channels, out_channels, 1, cfg, vb.pp("conv_1x1.conv"))?;
    let (kernel_1x1, bias_1x1) = fuse_conv_bn(conv_1x1.weight(), bn_1x1)?;
    let kernel_1x1 = kernel_1x1
        .pad_with_zeros(D::Minus1, 1, 1)?
        .pad_with_zeros(D::Minus2, 1, 1)?;

    // 3x3 branch: fold the bn into the conv.
    let bn_3x3 = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
    let conv_3x3 = conv2d_no_bias(in_channels, out_channels, 3, cfg, vb.pp("conv_kxk.conv"))?;
    let (kernel_3x3, bias_3x3) = fuse_conv_bn(conv_3x3.weight(), bn_3x3)?;

    let conv_kernel = (kernel_1x1 + kernel_3x3)?;
    let conv_bias = (bias_1x1 + bias_3x3)?;

    let (kernel, bias) = if has_identity {
        // Identity branch: express the identity as a 3x3 kernel, then fold its bn.
        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
        let mut identity = vec![0f32; conv_3x3.weight().elem_count()];
        // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
        let in_dim = in_channels / groups;
        for channel in 0..in_channels {
            // Offset 4 is the middle of a flattened 3x3 window (presumably row-major).
            let center = channel * in_dim * 3 * 3 + (channel % in_dim) * 3 * 3 + 4;
            identity[center] = 1.0;
        }
        let identity = Tensor::from_vec(identity, conv_kernel.shape(), conv_kernel.device())?;
        let (kernel_id, bias_id) = fuse_conv_bn(&identity, identity_bn)?;
        ((conv_kernel + kernel_id)?, (conv_bias + bias_id)?)
    } else {
        (conv_kernel, conv_bias)
    };

    // The single 3x3 conv equivalent to the sum of all branches.
    let fused = Conv2d::new(kernel, Some(bias), cfg);
    Ok(Func::new(move |xs| xs.apply(&fused)?.relu()))
}
// Width (number of output channels) of a stage once the multipliers are applied.
// Stage 4 is scaled by `b`, every other stage by `a`; stage 0 is additionally capped at 64.
fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
    let base = CHANNELS_PER_STAGE[stage] as f32;
    let multiplier = if stage == 4 { b } else { a };
    let scaled = (base * multiplier) as usize;
    if stage == 0 {
        scaled.min(64)
    } else {
        scaled
    }
}
// Builds stage `idx` (1-based) as a sequence of RepVGG layers.
// The first layer of a stage uses stride 2, takes the previous stage's width as input and
// has no identity branch; all following layers use stride 1 and have one.
// Layers whose index counted across all preceding stages is odd use `cfg.groups` groups
// (the G4 variants); the others are dense.
fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let nlayers = cfg.stages[idx - 1];
    let layers_before: usize = cfg.stages[..idx - 1].iter().sum();
    let prev_width = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
    let width = output_channels_per_stage(cfg.a, cfg.b, idx);

    let layers = (0..nlayers)
        .map(|layer_idx| {
            let is_first = layer_idx == 0;
            let stride = if is_first { 2 } else { 1 };
            let in_channels = if is_first { prev_width } else { width };
            let groups = match (layers_before + layer_idx) % 2 {
                1 => cfg.groups,
                _ => 1,
            };
            repvgg_layer(
                !is_first,
                width,
                stride,
                in_channels,
                width,
                groups,
                vb.pp(layer_idx),
            )
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(Func::new(move |xs| {
        layers
            .iter()
            .try_fold(xs.clone(), |acc, layer| acc.apply(layer))
    }))
}
// Assembles a full RepVGG network: stem, four stages, two `mean` reductions over the last
// axis and, when `nclasses` is given, a final `head.fc` linear layer.
fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
    let head = nclasses
        .map(|nclasses| {
            let features = output_channels_per_stage(config.a, config.b, 4);
            linear(features, nclasses, vb.pp("head.fc"))
        })
        .transpose()?;

    let stem_dim = output_channels_per_stage(config.a, config.b, 0);
    let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;

    // Stage `idx` (1-based) reads its weights from `stages.{idx - 1}`.
    let stages_vb = vb.pp("stages");
    let stages = (1..=4usize)
        .map(|idx| repvgg_stage(config, idx, stages_vb.pp(idx - 1)))
        .collect::<Result<Vec<_>>>()?;

    Ok(Func::new(move |xs| {
        let mut ys = xs.apply(&stem)?;
        for stage in stages.iter() {
            ys = ys.apply(stage)?;
        }
        let pooled = ys.mean(D::Minus1)?.mean(D::Minus1)?;
        match &head {
            Some(fc) => pooled.apply(fc),
            None => Ok(pooled),
        }
    }))
}
/// Builds a RepVGG model for `cfg` that ends with a `head.fc` linear layer producing
/// `nclasses` outputs. Weights are read from `vb`.
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
repvgg_model(cfg, Some(nclasses), vb)
}
/// Builds a RepVGG model for `cfg` without the final linear layer: the returned function
/// yields the output of the last stage after the two trailing `mean` reductions.
/// Weights are read from `vb`.
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
repvgg_model(cfg, None, vb)
}

View File

@ -1,12 +1,7 @@
use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
fn conv1d(
in_channels: usize,

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@ -27,7 +27,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@ -26,7 +26,7 @@ serde_json = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
# App crates.

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@ -27,7 +27,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@ -26,7 +26,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"

View File

@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -26,7 +26,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"

View File

@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true
[dependencies]
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle = { workspace = true }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }