mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Merge branch 'main' into ivarflakstad/metal-prng
This commit is contained in:
7
.github/dependabot.yml
vendored
Normal file
7
.github/dependabot.yml
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
version: 2
updates:
  - package-ecosystem: "cargo"
    directory: "/"
    schedule:
      interval: "weekly"
    open-pull-requests-limit: 5
|
14
Cargo.toml
14
Cargo.toml
@ -31,9 +31,17 @@ license = "MIT OR Apache-2.0"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core" }
|
||||
candle-datasets = { path = "./candle-datasets" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn" }
|
||||
candle-kernels = { path = "./candle-kernels" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels" }
|
||||
candle-nn = { path = "./candle-nn" }
|
||||
candle-onnx = { path = "./candle-onnx" }
|
||||
candle-transformers = { path = "./candle-transformers" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.9.14", features = ["f16"] }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
@ -42,7 +50,7 @@ imageproc = { version = "0.23.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
libc = { version = "0.2.147" }
|
||||
log = "0.4"
|
||||
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
|
||||
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
|
||||
num_cpus = "1.15.0"
|
||||
num-traits = "0.2.15"
|
||||
parquet = { version = "45.0.0" }
|
||||
@ -55,7 +63,7 @@ serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_plain = "1.0.2"
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.13.4", default-features = false }
|
||||
tokenizers = { version = "0.15.0", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
|
@ -11,11 +11,11 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
@ -12,8 +12,8 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
|
||||
candle-kernels = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::benchmarks::{bench_name, device, BenchDevice};
|
||||
use candle_core::{DType, Tensor};
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
@ -7,20 +7,19 @@ fn run(a: &Tensor, b: &Tensor) {
|
||||
a.matmul(&b.t().unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
fn run_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
let m = 1;
|
||||
let n = 2048;
|
||||
let k = 2048;
|
||||
|
||||
let device = device().unwrap();
|
||||
let dtype = DType::F32;
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
|
||||
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
|
||||
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
|
||||
|
||||
let flops = b * m * n * k;
|
||||
|
||||
let mut group = c.benchmark_group(bench_name("matmul"));
|
||||
let mut group = c.benchmark_group(device.bench_name("matmul"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
@ -35,4 +34,11 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
|
@ -5,6 +5,8 @@ use candle_core::{Device, Result};
|
||||
|
||||
pub(crate) trait BenchDevice {
|
||||
fn sync(&self) -> Result<()>;
|
||||
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String;
|
||||
}
|
||||
|
||||
impl BenchDevice for Device {
|
||||
@ -25,32 +27,38 @@ impl BenchDevice for Device {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn device() -> Result<Device> {
|
||||
if cfg!(feature = "metal") {
|
||||
Device::new_metal(0)
|
||||
} else if cfg!(feature = "cuda") {
|
||||
Device::new_cuda(0)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
fn bench_name<S: Into<String>>(&self, name: S) -> String {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let cpu_type = if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
};
|
||||
format!("{}_{}", cpu_type, name.into())
|
||||
}
|
||||
Device::Cuda(_) => format!("cuda_{}", name.into()),
|
||||
Device::Metal(_) => format!("metal_{}", name.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
|
||||
format!("{}_{}", device_variant(), name.into())
|
||||
struct BenchDeviceHandler {
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
const fn device_variant() -> &'static str {
|
||||
if cfg!(feature = "metal") {
|
||||
"metal"
|
||||
} else if cfg!(feature = "cuda") {
|
||||
"cuda"
|
||||
} else if cfg!(feature = "accelerate") {
|
||||
"accelerate"
|
||||
} else if cfg!(feature = "mkl") {
|
||||
"mkl"
|
||||
} else {
|
||||
"cpu"
|
||||
impl BenchDeviceHandler {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut devices = Vec::new();
|
||||
if cfg!(feature = "metal") {
|
||||
devices.push(Device::new_metal(0)?);
|
||||
} else if cfg!(feature = "cuda") {
|
||||
devices.push(Device::new_cuda(0)?);
|
||||
}
|
||||
devices.push(Device::Cpu);
|
||||
Ok(Self { devices })
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::benchmarks::{bench_name, device, BenchDevice};
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
@ -11,19 +11,18 @@ fn rand_normal(a: &Tensor) {
|
||||
a.randn_like(100.0, 15.0).unwrap();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
fn run_random_bench(c: &mut Criterion, device: &Device) {
|
||||
let b = 1;
|
||||
|
||||
let rows = 2048;
|
||||
let cols = 2048;
|
||||
|
||||
let d = device().unwrap();
|
||||
let dtype = DType::F32;
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap();
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let flops = b * rows * cols * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(bench_name("random_uniform"));
|
||||
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
@ -31,16 +30,15 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
for _i in 0..iters {
|
||||
rand_uniform(black_box(&tensor));
|
||||
}
|
||||
d.sync().unwrap();
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
|
||||
let d = device().unwrap();
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap();
|
||||
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group(bench_name("random_normal"));
|
||||
let mut group = c.benchmark_group(device.bench_name("random_normal"));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |benches| {
|
||||
benches.iter_custom(|iters| {
|
||||
@ -48,11 +46,18 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
for _i in 0..iters {
|
||||
rand_normal(black_box(&tensor));
|
||||
}
|
||||
d.sync().unwrap();
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
run_random_bench(c, &device);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
|
@ -102,7 +102,7 @@ enum Command {
|
||||
},
|
||||
|
||||
Quantize {
|
||||
/// The input file, in gguf format.
|
||||
/// The input file(s), in safetensors format.
|
||||
in_file: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The output file, in gguf format.
|
||||
@ -117,6 +117,15 @@ enum Command {
|
||||
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
|
||||
mode: QuantizationMode,
|
||||
},
|
||||
|
||||
Dequantize {
|
||||
/// The input file, in gguf format.
|
||||
in_file: std::path::PathBuf,
|
||||
|
||||
/// The output file, in safetensors format.
|
||||
#[arg(long)]
|
||||
out_file: std::path::PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -285,6 +294,19 @@ fn run_quantize_safetensors(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
|
||||
let mut in_file = std::fs::File::open(in_file)?;
|
||||
let content = gguf_file::Content::read(&mut in_file)?;
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (tensor_name, _) in content.tensor_infos.iter() {
|
||||
let tensor = content.tensor(&mut in_file, tensor_name)?;
|
||||
let tensor = tensor.dequantize(&Device::Cpu)?;
|
||||
tensors.insert(tensor_name.to_string(), tensor);
|
||||
}
|
||||
candle_core::safetensors::save(&tensors, out_file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_quantize(
|
||||
in_files: &[std::path::PathBuf],
|
||||
out_file: std::path::PathBuf,
|
||||
@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
|
||||
quantization,
|
||||
mode,
|
||||
} => run_quantize(&in_file, out_file, quantization, mode)?,
|
||||
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -592,14 +592,26 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -677,6 +689,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
@ -693,6 +706,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -723,6 +737,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
@ -737,6 +752,7 @@ impl BackendStorage for MetalStorage {
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
@ -1129,8 +1145,12 @@ impl BackendStorage for MetalStorage {
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(DType::U32, DType::BF16) => "is_u32_bf16",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -1320,6 +1340,7 @@ impl MetalStorage {
|
||||
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||
|
||||
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||
@ -1330,6 +1351,18 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
|
||||
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
|
||||
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
|
||||
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
|
||||
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
|
||||
|
||||
("add", DType::I64) => (contiguous::add::I64, self.dtype),
|
||||
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
|
||||
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
|
||||
@ -1340,6 +1373,7 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||
|
||||
("add", DType::U32) => (contiguous::add::U32, self.dtype),
|
||||
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
|
||||
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
|
||||
@ -1350,6 +1384,7 @@ impl MetalStorage {
|
||||
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
|
||||
|
||||
("add", DType::U8) => (contiguous::add::U8, self.dtype),
|
||||
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
|
||||
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
|
||||
@ -1360,6 +1395,7 @@ impl MetalStorage {
|
||||
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1393,6 +1429,7 @@ impl MetalStorage {
|
||||
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||
|
||||
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||
@ -1405,6 +1442,20 @@ impl MetalStorage {
|
||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
|
||||
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
|
||||
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
|
||||
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
|
||||
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
|
||||
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
|
||||
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
|
||||
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
|
||||
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
|
||||
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
|
||||
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
|
||||
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
|
||||
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
|
||||
|
||||
("badd", DType::I64) => (strided::add::I64, self.dtype),
|
||||
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
|
||||
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
|
||||
@ -1417,6 +1468,7 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||
|
||||
("badd", DType::U32) => (strided::add::U32, self.dtype),
|
||||
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
|
||||
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
|
||||
@ -1429,6 +1481,7 @@ impl MetalStorage {
|
||||
("lt", DType::U32) => (strided::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (strided::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (strided::gt::U32, DType::U8),
|
||||
|
||||
("badd", DType::U8) => (strided::add::U8, self.dtype),
|
||||
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
|
||||
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
|
||||
@ -1441,6 +1494,7 @@ impl MetalStorage {
|
||||
("lt", DType::U8) => (strided::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (strided::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (strided::gt::U8, DType::U8),
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
@ -703,6 +703,7 @@ impl PthTensors {
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
use std::io::Read;
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
@ -712,14 +713,21 @@ impl PthTensors {
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case.
|
||||
if !tensor_info.layout.is_contiguous() {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let start_offset = tensor_info.layout.start_offset();
|
||||
if start_offset > 0 {
|
||||
std::io::copy(
|
||||
&mut reader.by_ref().take(start_offset as u64),
|
||||
&mut std::io::sink(),
|
||||
)?;
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
|
@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
|
||||
let d2 = d * sc as f32;
|
||||
let m2 = min * m as f32;
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u1 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
|
||||
let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
|
||||
ys_index += 1;
|
||||
}
|
||||
for (ql, qh) in ql.iter().zip(qh) {
|
||||
let to_add = if qh & u2 != 0 { 16 } else { 1 };
|
||||
y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
|
||||
let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
|
||||
y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
|
||||
ys_index += 1;
|
||||
}
|
||||
is += 2;
|
||||
|
@ -12,6 +12,14 @@ use core::arch::arm::*;
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use core::arch::aarch64::*;
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
|
||||
// TODO: dotprod
|
||||
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
|
||||
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
||||
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
|
||||
let qk = QK8_0;
|
||||
@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
|
||||
let v1_0l = vld1q_s8(y0.qs.as_ptr());
|
||||
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
// TODO: Support dotprod when it's available outside of nightly.
|
||||
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
|
||||
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
|
||||
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
||||
let pl0 = vdotq_s32(v0_0ls, v1_0l);
|
||||
let ph0 = vdotq_s32(v0_0hs, v1_0h);
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
|
||||
@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
|
||||
let y0_0 = vld1q_s8(y0.qs.as_ptr());
|
||||
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
|
||||
|
||||
// TODO: use dotprod once the intrinsics are available.
|
||||
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
|
||||
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
|
||||
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||
|
||||
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||
let p0 = vdotq_s32(x0_0, y0_0);
|
||||
let p1 = vdotq_s32(x0_1, y0_1);
|
||||
|
||||
sumv0 = vmlaq_n_f32(
|
||||
sumv0,
|
||||
@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
|
||||
for i in (0..QK_K).step_by(16) {
|
||||
let xs = vld1q_s8(xs.add(i));
|
||||
let ys = vld1q_s8(ys.add(i));
|
||||
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
|
||||
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
|
||||
|
||||
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
|
||||
let xy = vdotq_s32(xs, ys);
|
||||
sum_i = vaddq_s32(sum_i, xy)
|
||||
}
|
||||
sumf += vaddvq_s32(sum_i) as f32 * scale
|
||||
@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let q8bytes = vld1q_s8_x4(q8);
|
||||
@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
|
||||
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
|
||||
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
|
||||
|
||||
// TODO: dotprod case.
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
|
||||
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
|
||||
scale = scale.add(2);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
|
||||
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
|
||||
isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
|
||||
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
|
||||
scale = scale.add(2);
|
||||
}
|
||||
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
|
||||
@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
|
||||
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
|
||||
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
|
||||
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
|
||||
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
|
||||
vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
|
||||
vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
|
||||
);
|
||||
sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
|
||||
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
|
||||
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
|
||||
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
|
||||
scales = scales.add(1);
|
||||
}
|
||||
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
|
||||
@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
for j in 0..QK_K / 64 {
|
||||
let q4bits = vld1q_u8_x2(q4);
|
||||
q4 = q4.add(32);
|
||||
// TODO: dotprod
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
let q4bytes = int8x16x2_t(
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
|
||||
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
|
||||
);
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
|
||||
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
|
||||
|
||||
let q8bytes = vld1q_s8_x2(q8);
|
||||
q8 = q8.add(32);
|
||||
@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
|
||||
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
|
||||
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
|
||||
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
|
||||
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2) as f32;
|
||||
}
|
||||
@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
let q3h_0 = vbicq_u8(m2, qhbits.0);
|
||||
@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
|
||||
vreinterpretq_s8_u8(q3h_3),
|
||||
);
|
||||
|
||||
// TODO: dotprod
|
||||
let p0 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
|
||||
vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
|
||||
);
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
|
||||
vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
|
||||
vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
|
||||
);
|
||||
let p3 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
|
||||
vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
|
||||
);
|
||||
isum += vaddvq_s16(p0) as i32 * *scale as i32
|
||||
+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
|
||||
+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
|
||||
+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
|
||||
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
|
||||
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
|
||||
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
|
||||
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
|
||||
isum += vaddvq_s32(p0) * *scale as i32
|
||||
+ vaddvq_s32(p1) * *scale.add(1) as i32
|
||||
+ vaddvq_s32(p2) * *scale.add(2) as i32
|
||||
+ vaddvq_s32(p3) * *scale.add(3) as i32;
|
||||
scale = scale.add(4);
|
||||
|
||||
if j == 0 {
|
||||
@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
|
||||
let mut is = 0usize;
|
||||
|
||||
// TODO: dotprod
|
||||
|
||||
for _j in 0..QK_K / 128 {
|
||||
let q2bits = vld1q_u8_x2(q2);
|
||||
q2 = q2.add(32);
|
||||
@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
|
||||
q2bytes: int8x16x2_t,
|
||||
q8bytes: int8x16x2_t,
|
||||
) -> i32 {
|
||||
let p1 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
|
||||
vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
|
||||
);
|
||||
let p2 = vaddq_s16(
|
||||
vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
|
||||
vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
|
||||
);
|
||||
vaddvq_s16(p1) as i32 * aux[is + index] as i32
|
||||
+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
|
||||
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
|
||||
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
|
||||
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
use candle_core::{
|
||||
bail,
|
||||
quantized::{self, GgmlDType},
|
||||
test_utils::to_vec2_round,
|
||||
Device, Module, Result, Tensor,
|
||||
@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
/// Creates a vector similar to the ones used in GGML unit tests:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
(0..GGML_TEST_SIZE)
|
||||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
|
||||
@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
|
||||
sum / a.len() as f32
|
||||
}
|
||||
|
||||
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
/// Similar to the GGML quantization unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
|
||||
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
|
||||
let src = create_ggml_like_vector(0.0);
|
||||
let mut dst = vec![0.0; GGML_TEST_SIZE];
|
||||
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
|
||||
let error = calculate_rmse(src.as_slice(), dst.as_slice());
|
||||
if error > max_error {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"Quantization error {} exceeds max error {}",
|
||||
error,
|
||||
max_error
|
||||
@ -404,7 +407,7 @@ fn quantize_q5k() -> Result<()> {
|
||||
let dst = round_vector(&dst);
|
||||
assert_eq!(
|
||||
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
|
||||
[-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
|
||||
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
|
||||
);
|
||||
|
||||
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
|
||||
@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
|
||||
GgmlDType::Q5K => 0.000740,
|
||||
GgmlDType::Q6K => 0.000952,
|
||||
GgmlDType::Q4_0 => 0.001143,
|
||||
GgmlDType::Q4_1 => 0.007784,
|
||||
GgmlDType::Q4_1 => 0.008,
|
||||
GgmlDType::Q5_0 => 0.001353,
|
||||
GgmlDType::Q5_1 => 0.001363,
|
||||
GgmlDType::Q5_1 => 0.00149,
|
||||
GgmlDType::Q8_0 => 0.000092,
|
||||
|
||||
// Not from the ggml repo.
|
||||
GgmlDType::Q8K => 0.00065,
|
||||
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
|
||||
_ => bail!("No GGML results for quantization type {dtype:?}",),
|
||||
};
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
/// Similar to the GGML matmul unit test:
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
|
||||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
let a = create_ggml_like_vector(0.0);
|
||||
let b = create_ggml_like_vector(1.0);
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
|
||||
// Another example that is more likely to trigger the overflow reported in #1526
|
||||
let a = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
let b = (0..GGML_TEST_SIZE)
|
||||
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
|
||||
.collect::<Vec<_>>();
|
||||
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
|
||||
let length = a.len();
|
||||
|
||||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
|
||||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
|
||||
T::from_float(&a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(&b, &mut b_quant)?;
|
||||
T::from_float(a, &mut a_quant)?;
|
||||
T::VecDotType::from_float(b, &mut b_quant)?;
|
||||
|
||||
let result = T::vec_dot(length, &a_quant, &b_quant)?;
|
||||
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
|
||||
let reference_result = vec_dot_reference(&a, &b);
|
||||
let reference_result = vec_dot_reference(a, b);
|
||||
|
||||
if (result - result_unopt).abs() / length as f32 > 1e-6 {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
|
||||
)
|
||||
}
|
||||
|
||||
let error = (result - reference_result).abs() / length as f32;
|
||||
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
|
||||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
|
||||
|
||||
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
|
||||
candle_core::bail!(
|
||||
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
|
||||
);
|
||||
bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
|
||||
}
|
||||
|
||||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
|
||||
// => we use a slightly higher error threshold
|
||||
const ERROR_LENIENCY: f32 = 0.00001;
|
||||
if error - ERROR_LENIENCY > ggml_error {
|
||||
candle_core::bail!(
|
||||
bail!(
|
||||
"Dot product error {} exceeds ggml reference error {}",
|
||||
error,
|
||||
ggml_error
|
||||
@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs the vec-dot accuracy check on the non-k block formats
/// (Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0), stopping at the first failure.
#[test]
fn quantized_mm() -> Result<()> {
    let checks: [fn() -> Result<()>; 5] = [
        ggml_matmul_error_test::<k_quants::BlockQ4_0>,
        ggml_matmul_error_test::<k_quants::BlockQ4_1>,
        ggml_matmul_error_test::<k_quants::BlockQ5_0>,
        ggml_matmul_error_test::<k_quants::BlockQ5_1>,
        ggml_matmul_error_test::<k_quants::BlockQ8_0>,
    ];
    // Same order as before; `try_for_each` short-circuits like a chain of `?`.
    checks.iter().try_for_each(|check| check())
}
|
||||
|
||||
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
|
||||
fn get_random_tensors(
|
||||
m: usize,
|
||||
|
@ -11,8 +11,8 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
hf-hub = { workspace = true}
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true }
|
||||
|
@ -11,12 +11,12 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
|
||||
csv = "1.3.0"
|
||||
cudarc = { workspace = true, optional = true }
|
||||
@ -49,11 +49,12 @@ tokio = "1.29.1"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
|
@ -4,251 +4,34 @@ use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct KernelDirectories {
|
||||
kernel_dir: &'static str,
|
||||
kernel_glob: &'static str,
|
||||
rust_target: &'static str,
|
||||
include_dirs: &'static [&'static str],
|
||||
}
|
||||
|
||||
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_dir: "examples/custom-ops/kernels/",
|
||||
const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
|
||||
kernel_glob: "examples/custom-ops/kernels/*.cu",
|
||||
rust_target: "examples/custom-ops/cuda_kernels.rs",
|
||||
include_dirs: &[],
|
||||
}];
|
||||
|
||||
impl KernelDirectories {
|
||||
fn maybe_build_ptx(
|
||||
&self,
|
||||
cu_file: &std::path::Path,
|
||||
ptx_file: &std::path::Path,
|
||||
compute_cap: usize,
|
||||
) -> Result<()> {
|
||||
let should_compile = if ptx_file.exists() {
|
||||
let ptx_modified = ptx_file.metadata()?.modified()?;
|
||||
let cu_modified = cu_file.metadata()?.modified()?;
|
||||
cu_modified.duration_since(ptx_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
|
||||
let include_dirs: Vec<String> =
|
||||
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
|
||||
command
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", out_dir.to_str().unwrap()])
|
||||
.arg(format!("-I/{}", self.kernel_dir))
|
||||
.args(include_dirs)
|
||||
.arg(cu_file);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.open(ptx_file)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
|
||||
println!("cargo:rerun-if-changed={}", self.kernel_dir);
|
||||
let kernel_dir = PathBuf::from(self.kernel_dir);
|
||||
let out_dir = out_dir.join(self.kernel_dir);
|
||||
if !out_dir.exists() {
|
||||
std::fs::create_dir_all(&out_dir)?;
|
||||
}
|
||||
let mut cu_files = vec![];
|
||||
let mut cuh_files = vec![];
|
||||
for file in std::fs::read_dir(kernel_dir)?.flatten() {
|
||||
let file = file.path();
|
||||
match file.extension().and_then(|v| v.to_str()) {
|
||||
Some("cu") => cu_files.push(file),
|
||||
Some("cuh") => cuh_files.push(file),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut ptx_paths = vec![];
|
||||
for cu_file in cu_files.iter() {
|
||||
let file_stem = cu_file
|
||||
.file_stem()
|
||||
.with_context(|| format!("no stem {cu_file:?}"))?;
|
||||
let file_stem = file_stem.to_string_lossy().into_owned();
|
||||
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
|
||||
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
|
||||
ptx_paths.push(ptx_file);
|
||||
}
|
||||
|
||||
let regenerate_rs_file = true;
|
||||
if regenerate_rs_file {
|
||||
let mut file = std::fs::File::create(self.rust_target)?;
|
||||
for ptx_path in ptx_paths {
|
||||
let name = ptx_path
|
||||
.file_stem()
|
||||
.context("empty stem")?
|
||||
.to_string_lossy();
|
||||
file.write_all(b"#[rustfmt::skip]\n")?;
|
||||
let const_definition = format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
self.kernel_dir,
|
||||
);
|
||||
file.write_all(const_definition.as_bytes())?;
|
||||
file.write_all(b"\n")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
|
||||
let out_dir = PathBuf::from(out_dir);
|
||||
#[cfg(feature = "cuda")]
|
||||
set_cuda_include_dir()?;
|
||||
#[cfg(feature = "cuda")]
|
||||
let compute_cap = compute_cap()?;
|
||||
{
|
||||
for kdir in KERNEL_DIRS.iter() {
|
||||
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write(kdir.rust_target).unwrap()
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let compute_cap = 0;
|
||||
for d in DIRS {
|
||||
d.process(&out_dir, compute_cap)?
|
||||
{
|
||||
for kdir in KERNEL_DIRS.iter() {
|
||||
let _file = std::fs::File::create(kdir.rust_target)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute cap from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse code")?
|
||||
} else {
|
||||
// Grab compute cap from nvidia-smi
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let max_nvcc_code = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
if !codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
|
||||
);
|
||||
}
|
||||
*codes.last().unwrap()
|
||||
};
|
||||
|
||||
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
|
||||
// then choose the highest gpu code in nvcc
|
||||
if compute_cap > max_nvcc_code {
|
||||
println!(
|
||||
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
|
||||
);
|
||||
compute_cap = max_nvcc_code;
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
compute_cap = compute_cap_str
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
|
||||
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
|
||||
}
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -1,2 +1 @@
|
||||
#[rustfmt::skip]
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
|
||||
|
@ -6,7 +6,8 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[allow(unused)]
|
||||
#[rustfmt::skip]
|
||||
#[cfg(feature = "cuda")]
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
20
candle-examples/examples/repvgg/README.md
Normal file
20
candle-examples/examples/repvgg/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-repvgg
|
||||
|
||||
A candle implementation of inference using a pre-trained [RepVGG](https://arxiv.org/abs/2101.03697) model.
|
||||
This uses a classification head trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
loaded image Tensor[dims 3, 224, 224; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 61.70%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 33.14%
|
||||
unicycle, monocycle : 4.88%
|
||||
crash helmet : 0.15%
|
||||
moped : 0.04%
|
||||
|
||||
```
|
111
candle-examples/examples/repvgg/main.rs
Normal file
111
candle-examples/examples/repvgg/main.rs
Normal file
@ -0,0 +1,111 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::repvgg;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
A0,
|
||||
A1,
|
||||
A2,
|
||||
B0,
|
||||
B1,
|
||||
B2,
|
||||
B3,
|
||||
B1G4,
|
||||
B2G4,
|
||||
B3G4,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::A0 => "a0",
|
||||
Self::A1 => "a1",
|
||||
Self::A2 => "a2",
|
||||
Self::B0 => "b0",
|
||||
Self::B1 => "b1",
|
||||
Self::B2 => "b2",
|
||||
Self::B3 => "b3",
|
||||
Self::B1G4 => "b1g4",
|
||||
Self::B2G4 => "b2g4",
|
||||
Self::B3G4 => "b3g4",
|
||||
};
|
||||
format!("timm/repvgg_{}.rvgg_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> repvgg::Config {
|
||||
match self {
|
||||
Self::A0 => repvgg::Config::a0(),
|
||||
Self::A1 => repvgg::Config::a1(),
|
||||
Self::A2 => repvgg::Config::a2(),
|
||||
Self::B0 => repvgg::Config::b0(),
|
||||
Self::B1 => repvgg::Config::b1(),
|
||||
Self::B2 => repvgg::Config::b2(),
|
||||
Self::B3 => repvgg::Config::b3(),
|
||||
Self::B1G4 => repvgg::Config::b1g4(),
|
||||
Self::B2G4 => repvgg::Config::b2g4(),
|
||||
Self::B3G4 => repvgg::Config::b3g4(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::A0)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
bindgen_cuda = "0.1.1"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
num_cpus = "1.15.0"
|
||||
rayon = "1.7.0"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
|
||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||
|
@ -2,44 +2,32 @@
|
||||
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
|
||||
// variable in order to cache the compiled artifacts and avoid recompiling too often.
|
||||
use anyhow::{Context, Result};
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const KERNEL_FILES: [&str; 17] = [
|
||||
"flash_api.cu",
|
||||
"flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"flash_fwd_hdim96_bf16_sm80.cu",
|
||||
"kernels/flash_api.cu",
|
||||
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
|
||||
];
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|
||||
|_| num_cpus::get_physical(),
|
||||
|s| usize::from_str(&s).unwrap(),
|
||||
);
|
||||
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(num_cpus)
|
||||
.build_global()
|
||||
.unwrap();
|
||||
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
for kernel_file in KERNEL_FILES.iter() {
|
||||
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
|
||||
println!("cargo:rerun-if-changed={kernel_file}");
|
||||
}
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
|
||||
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
|
||||
@ -66,223 +54,30 @@ fn main() -> Result<()> {
|
||||
))
|
||||
}
|
||||
};
|
||||
set_cuda_include_dir()?;
|
||||
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
|
||||
let compute_cap = compute_cap()?;
|
||||
let kernels = KERNEL_FILES.iter().collect();
|
||||
let builder = bindgen_cuda::Builder::default()
|
||||
.kernel_paths(kernels)
|
||||
.out_dir(build_dir.clone())
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
|
||||
let out_file = build_dir.join("libflashattention.a");
|
||||
builder.build_lib(out_file);
|
||||
|
||||
let kernel_dir = PathBuf::from("kernels");
|
||||
let cu_files: Vec<_> = KERNEL_FILES
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let mut obj_file = out_dir.join(f);
|
||||
obj_file.set_extension("o");
|
||||
(kernel_dir.join(f), obj_file)
|
||||
})
|
||||
.collect();
|
||||
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
|
||||
let should_compile = if out_file.exists() {
|
||||
kernel_dir
|
||||
.read_dir()
|
||||
.expect("kernels folder should exist")
|
||||
.any(|entry| {
|
||||
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
|
||||
let in_modified = entry.metadata().unwrap().modified().unwrap();
|
||||
in_modified.duration_since(*out_modified).is_ok()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
} else {
|
||||
true
|
||||
};
|
||||
if should_compile {
|
||||
cu_files
|
||||
.par_iter()
|
||||
.map(|(cu_file, obj_file)| {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("-std=c++17")
|
||||
.arg("-O3")
|
||||
.arg("-U__CUDA_NO_HALF_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
|
||||
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
|
||||
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
|
||||
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("-c")
|
||||
.args(["-o", obj_file.to_str().unwrap()])
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.arg("-Icutlass/include")
|
||||
.arg("--expt-relaxed-constexpr")
|
||||
.arg("--expt-extended-lambda")
|
||||
.arg("--use_fast_math")
|
||||
.arg("--verbose");
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(cu_file);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
&command,
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<()>>()?;
|
||||
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command
|
||||
.arg("--lib")
|
||||
.args(["-o", out_file.to_str().unwrap()])
|
||||
.args(obj_files);
|
||||
let output = command
|
||||
.spawn()
|
||||
.context("failed spawning nvcc")?
|
||||
.wait_with_output()?;
|
||||
if !output.status.success() {
|
||||
anyhow::bail!(
|
||||
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
&command,
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
}
|
||||
}
|
||||
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||
println!("cargo:rustc-link-lib=flashattention");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||
|
||||
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
|
||||
finishing to run for some reason. Calling nvcc manually worked fine.
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
.include("cutlass/include")
|
||||
.flag("--expt-relaxed-constexpr")
|
||||
.flag("--default-stream")
|
||||
.flag("per-thread")
|
||||
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
|
||||
.compile("flashattn");
|
||||
*/
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_cuda_include_dir() -> Result<()> {
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.context("cannot find include/cuda.h")?;
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse compute cap")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(compute_cap)
|
||||
}
|
||||
|
@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
|
||||
[dependencies]
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
glob = "0.3.1"
|
||||
rayon = "1.7.0"
|
||||
bindgen_cuda = "0.1.1"
|
||||
|
@ -1,243 +1,8 @@
|
||||
use std::io::Write;
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
cuda::set_include_dir();
|
||||
let (write, kernel_paths) = cuda::build_ptx();
|
||||
if write {
|
||||
let mut file = std::fs::File::create("src/lib.rs").unwrap();
|
||||
for kernel_path in kernel_paths {
|
||||
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
|
||||
file.write_all(
|
||||
format!(
|
||||
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
|
||||
name.to_uppercase().replace('.', "_"),
|
||||
name
|
||||
)
|
||||
.as_bytes(),
|
||||
)
|
||||
.unwrap();
|
||||
file.write_all(&[b'\n']).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod cuda {
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
pub fn set_include_dir() {
|
||||
use std::path::PathBuf;
|
||||
// NOTE: copied from cudarc build.rs.
|
||||
// We can't actually set a env!() value from another crate,
|
||||
// so we have to do that here.
|
||||
|
||||
// use PathBuf;
|
||||
|
||||
let env_vars = [
|
||||
"CUDA_PATH",
|
||||
"CUDA_ROOT",
|
||||
"CUDA_TOOLKIT_ROOT_DIR",
|
||||
"CUDNN_LIB",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let env_vars = env_vars
|
||||
.into_iter()
|
||||
.map(std::env::var)
|
||||
.filter_map(Result::ok)
|
||||
.map(Into::<PathBuf>::into);
|
||||
|
||||
let roots = [
|
||||
"/usr",
|
||||
"/usr/local/cuda",
|
||||
"/opt/cuda",
|
||||
"/usr/lib/cuda",
|
||||
"C:/Program Files/NVIDIA GPU Computing Toolkit",
|
||||
"C:/CUDA",
|
||||
];
|
||||
#[allow(unused)]
|
||||
let roots = roots.into_iter().map(Into::<PathBuf>::into);
|
||||
|
||||
#[cfg(feature = "ci-check")]
|
||||
let root: PathBuf = "ci".into();
|
||||
|
||||
#[cfg(not(feature = "ci-check"))]
|
||||
let root = env_vars
|
||||
.chain(roots)
|
||||
.find(|path| path.join("include").join("cuda.h").is_file())
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
|
||||
root.join("include").display()
|
||||
);
|
||||
}
|
||||
|
||||
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
|
||||
use rayon::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
||||
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
|
||||
println!("cargo:rerun-if-changed=src/");
|
||||
// for path in &kernel_paths {
|
||||
// println!("cargo:rerun-if-changed={}", path.display());
|
||||
// }
|
||||
|
||||
for path in &mut include_directories {
|
||||
// println!("cargo:rerun-if-changed={}", path.display());
|
||||
let destination =
|
||||
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
|
||||
std::fs::copy(path.clone(), destination).unwrap();
|
||||
// remove the filename from the path so it's just the directory
|
||||
path.pop();
|
||||
}
|
||||
|
||||
include_directories.sort();
|
||||
include_directories.dedup();
|
||||
|
||||
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
|
||||
|
||||
#[allow(unused)]
|
||||
let include_options: Vec<String> = include_directories
|
||||
.into_iter()
|
||||
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let children = kernel_paths
|
||||
.par_iter()
|
||||
.flat_map(|p| {
|
||||
let mut output = p.clone();
|
||||
output.set_extension("ptx");
|
||||
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
|
||||
|
||||
let ignore = if output_filename.exists() {
|
||||
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
|
||||
let in_modified = p.metadata().unwrap().modified().unwrap();
|
||||
out_modified.duration_since(in_modified).is_ok()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if ignore {
|
||||
None
|
||||
} else {
|
||||
let mut command = std::process::Command::new("nvcc");
|
||||
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
|
||||
.arg("--ptx")
|
||||
.args(["--default-stream", "per-thread"])
|
||||
.args(["--output-directory", &out_dir])
|
||||
// Flash attention only
|
||||
// .arg("--expt-relaxed-constexpr")
|
||||
.args(&include_options);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(p);
|
||||
Some((p, command.spawn()
|
||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
|
||||
.unwrap()
|
||||
.map(|p| p.unwrap())
|
||||
.collect();
|
||||
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
|
||||
// some old ones
|
||||
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
|
||||
for (kernel_path, child) in children {
|
||||
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
(write, kernel_paths)
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn compute_cap() -> Result<usize> {
|
||||
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
|
||||
|
||||
// Try to parse compute caps from env
|
||||
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
|
||||
compute_cap_str
|
||||
.parse::<usize>()
|
||||
.context("Could not parse code")?
|
||||
} else {
|
||||
// Use nvidia-smi to get the current compute cap
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=compute_cap")
|
||||
.arg("--format=csv")
|
||||
.output()
|
||||
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
|
||||
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
|
||||
let mut lines = out.lines();
|
||||
assert_eq!(
|
||||
lines.next().context("missing line in stdout")?,
|
||||
"compute_cap"
|
||||
);
|
||||
let cap = lines
|
||||
.next()
|
||||
.context("missing line in stdout")?
|
||||
.replace('.', "");
|
||||
let cap = cap
|
||||
.parse::<usize>()
|
||||
.with_context(|| format!("cannot parse as int {cap}"))?;
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
|
||||
cap
|
||||
};
|
||||
|
||||
// Grab available GPU codes from nvcc and select the highest one
|
||||
let (supported_nvcc_codes, max_nvcc_code) = {
|
||||
let out = std::process::Command::new("nvcc")
|
||||
.arg("--list-gpu-code")
|
||||
.output()
|
||||
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
|
||||
let out = std::str::from_utf8(&out.stdout).unwrap();
|
||||
|
||||
let out = out.lines().collect::<Vec<&str>>();
|
||||
let mut codes = Vec::with_capacity(out.len());
|
||||
for code in out {
|
||||
let code = code.split('_').collect::<Vec<&str>>();
|
||||
if !code.is_empty() && code.contains(&"sm") {
|
||||
if let Ok(num) = code[1].parse::<usize>() {
|
||||
codes.push(num);
|
||||
}
|
||||
}
|
||||
}
|
||||
codes.sort();
|
||||
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
|
||||
(codes, max_nvcc_code)
|
||||
};
|
||||
|
||||
// Check that nvcc supports the asked compute caps
|
||||
if !supported_nvcc_codes.contains(&compute_cap) {
|
||||
anyhow::bail!(
|
||||
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
|
||||
);
|
||||
}
|
||||
if compute_cap > max_nvcc_code {
|
||||
anyhow::bail!(
|
||||
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(compute_cap)
|
||||
}
|
||||
let builder = bindgen_cuda::Builder::default();
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write("src/lib.rs").unwrap();
|
||||
}
|
||||
|
@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
|
||||
[dependencies]
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
metal = { version = "0.27.0", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
half = { version = "2.3.1", features = [
|
||||
"num-traits",
|
||||
"use-intrinsics",
|
||||
"rand_distr",
|
||||
] }
|
||||
rand = "0.8.5"
|
||||
|
@ -117,7 +117,7 @@ ELU(elu_f32, float)
|
||||
ELU(elu_f16, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
AFFINE(affine_bf16, bfloat);
|
||||
POWF(powf_bf16, bfloat);
|
||||
ELU(elu_bf16, bfloat);
|
||||
|
@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
|
||||
INT64_BINARY_OP_OUT(gt, x > y)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
|
@ -28,7 +28,7 @@ kernel void FN_NAME( \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = RIGHT_TYPENAME(input[tid]); \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
|
||||
} \
|
||||
|
||||
#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
|
||||
} \
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
device const LEFT_TYPENAME *input, \
|
||||
device RIGHT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
|
||||
} \
|
||||
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
||||
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||
|
||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||
|
@ -178,8 +178,8 @@ macro_rules! ops{
|
||||
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
|
||||
recip
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
|
@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::{Device, MTLResourceOptions};
|
||||
use metal::{Buffer, Device, MTLResourceOptions};
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
@ -248,6 +248,34 @@ fn binary_add_f32() {
|
||||
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_ops_bf16() {
|
||||
let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
|
||||
let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
|
||||
.into_iter()
|
||||
.map(bf16::from_f32)
|
||||
.collect();
|
||||
|
||||
macro_rules! binary_op {
|
||||
($opname:ident, $opexpr:expr) => {{
|
||||
let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
|
||||
let expected: Vec<bf16> = lhs
|
||||
.iter()
|
||||
.zip(rhs.iter())
|
||||
.map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
|
||||
.collect();
|
||||
assert_eq!(results, expected);
|
||||
}};
|
||||
}
|
||||
|
||||
binary_op!(add, |x, y| x + y);
|
||||
binary_op!(sub, |x, y| x - y);
|
||||
binary_op!(mul, |x, y| x * y);
|
||||
binary_op!(div, |x, y| x / y);
|
||||
binary_op!(min, |x: bf16, y| x.min(y));
|
||||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
@ -296,6 +324,89 @@ fn cast_u32_f32() {
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_u32_bf16() {
|
||||
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f32_bf16() {
|
||||
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_u8() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
|
||||
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f16() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
|
||||
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f16_bf16() {
|
||||
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
||||
let device = device();
|
||||
let fence = device.new_fence();
|
||||
@ -396,14 +507,14 @@ fn index_select() {
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
|
||||
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [2, 5];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
@ -419,20 +530,46 @@ fn index_select_f16() {
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
|
||||
assert_eq!(
|
||||
approx_f16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u32_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_is_u8_bf16() {
|
||||
let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
|
||||
let shape = [5, 2];
|
||||
let ids = [0u8, 4, 2];
|
||||
let dim = 0;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
|
||||
assert_eq!(
|
||||
approx_bf16(result, 4),
|
||||
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_dim1() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 1, 0];
|
||||
let dim = 1;
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||
@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
shape: &[usize],
|
||||
ids: &[I],
|
||||
dim: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
||||
@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
||||
let dst_el = ids.len() * left_size * right_size;
|
||||
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
|
||||
let name = match core::mem::size_of::<T>() {
|
||||
4 => "is_u32_f32",
|
||||
2 => "is_u32_f16",
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
let fence = device.new_fence();
|
||||
let kernels = Kernels::new(fence);
|
||||
call_index_select(
|
||||
|
@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
}
|
||||
template <typename T> METAL_FUNC T relu(T in){
|
||||
if (in < 0) {
|
||||
return 0;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -110,6 +116,7 @@ UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
@ -120,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
#endif
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
BFLOAT_UNARY_OP(cos)
|
||||
BFLOAT_UNARY_OP(sin)
|
||||
BFLOAT_UNARY_OP(sqr)
|
||||
@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
#endif
|
||||
|
@ -11,7 +11,7 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle = { workspace = true }
|
||||
half = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
@ -20,7 +20,7 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
metal = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
candle-metal-kernels = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle = { path = "../candle-core", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
@ -20,4 +20,3 @@ prost-build = "0.12.1"
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
|
||||
|
@ -15,9 +15,9 @@ crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-onnx = { workspace = true, optional = true }
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
|
@ -12,9 +12,9 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
candle-nn = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
@ -112,11 +112,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
#[allow(dead_code)]
|
||||
pr: f64,
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
|
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
|
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -136,11 +136,6 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::RmsNorm,
|
||||
span: tracing::Span,
|
||||
@ -409,7 +404,7 @@ impl Llama {
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
|
@ -26,6 +26,7 @@ pub mod quantized_mixformer;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod repvgg;
|
||||
pub mod resnet;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
|
306
candle-transformers/src/models/repvgg.rs
Normal file
306
candle-transformers/src/models/repvgg.rs
Normal file
@ -0,0 +1,306 @@
|
||||
//! RepVGG inference implementation
|
||||
//!
|
||||
//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
|
||||
//! https://arxiv.org/abs/2101.03697
|
||||
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
|
||||
|
||||
/// Hyper-parameters selecting one RepVGG variant.
///
/// Instances are obtained through the named constructors (`a0`, `b1g4`, ...).
#[derive(Clone, Debug)]
pub struct Config {
    /// Width multiplier applied to the stem and to every stage except the last.
    a: f32,
    /// Width multiplier applied to the last stage.
    b: f32,
    /// Group count used by the grouped convolutions (1 for the dense variants).
    groups: usize,
    /// Number of layers in each of the four stages.
    stages: [usize; 4],
}

impl Config {
    /// Layers per stage shared by the A-series variants.
    const A_STAGES: [usize; 4] = [2, 4, 14, 1];
    /// Layers per stage shared by the B-series variants.
    const B_STAGES: [usize; 4] = [4, 6, 16, 1];

    /// Single place where a variant is assembled from its four hyper-parameters.
    fn variant(a: f32, b: f32, groups: usize, stages: [usize; 4]) -> Self {
        Self {
            a,
            b,
            groups,
            stages,
        }
    }

    /// RepVGG-A0.
    pub fn a0() -> Self {
        Self::variant(0.75, 2.5, 1, Self::A_STAGES)
    }

    /// RepVGG-A1.
    pub fn a1() -> Self {
        Self::variant(1.0, 2.5, 1, Self::A_STAGES)
    }

    /// RepVGG-A2.
    pub fn a2() -> Self {
        Self::variant(1.5, 2.75, 1, Self::A_STAGES)
    }

    /// RepVGG-B0.
    pub fn b0() -> Self {
        Self::variant(1.0, 2.5, 1, Self::B_STAGES)
    }

    /// RepVGG-B1.
    pub fn b1() -> Self {
        Self::variant(2.0, 4.0, 1, Self::B_STAGES)
    }

    /// RepVGG-B2.
    pub fn b2() -> Self {
        Self::variant(2.5, 5.0, 1, Self::B_STAGES)
    }

    /// RepVGG-B3.
    pub fn b3() -> Self {
        Self::variant(3.0, 5.0, 1, Self::B_STAGES)
    }

    /// RepVGG-B1g4: B1 widths with 4 convolution groups.
    pub fn b1g4() -> Self {
        Self::variant(2.0, 4.0, 4, Self::B_STAGES)
    }

    /// RepVGG-B2g4: B2 widths with 4 convolution groups.
    pub fn b2g4() -> Self {
        Self::variant(2.5, 5.0, 4, Self::B_STAGES)
    }

    /// RepVGG-B3g4: B3 widths with 4 convolution groups.
    pub fn b3g4() -> Self {
        Self::variant(3.0, 5.0, 4, Self::B_STAGES)
    }
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
// A RepVGG layer has a different training time and inference time architecture.
|
||||
// The latter is a simple and efficient equivalent transformation of the former
|
||||
// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
|
||||
// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
|
||||
fn repvgg_layer(
|
||||
has_identity: bool,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
groups: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// read and reparameterize the 1x1 conv and bn into w1 and b1
|
||||
// based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
|
||||
|
||||
let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
|
||||
let conv1x1 = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_1x1.conv"),
|
||||
)?;
|
||||
|
||||
let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
|
||||
|
||||
// resize to 3x3
|
||||
w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
|
||||
w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
// read and reparameterize the 3x3 conv and bn into w3 and b3
|
||||
let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
|
||||
let conv3x3 = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_kxk.conv"),
|
||||
)?;
|
||||
|
||||
let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
|
||||
|
||||
let mut w = (w1 + w3)?;
|
||||
let mut b = (b1 + b3)?;
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
if has_identity {
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
|
||||
|
||||
// create a 3x3 convolution equivalent to the identity branch
|
||||
let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
|
||||
|
||||
// https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
|
||||
let in_dim = in_channels / groups;
|
||||
for i in 0..in_channels {
|
||||
weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
}
|
||||
|
||||
// create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&reparam_conv)?.relu()?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Get the number of output channels per stage taking into account the multipliers
|
||||
fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
|
||||
let channels = CHANNELS_PER_STAGE[stage] as f32;
|
||||
|
||||
match stage {
|
||||
0 => std::cmp::min(64, (channels * a) as usize),
|
||||
4 => (channels * b) as usize,
|
||||
_ => (channels * a) as usize,
|
||||
}
|
||||
}
|
||||
|
||||
// Each stage is made of layers. The first layer always downsamples with stride 2.
|
||||
// All but the first layer have a residual connection.
|
||||
// The G4 variants have a groupwise convolution instead of a dense one on odd layers
|
||||
// counted across stage boundaries, so we keep track of which layer we are in the
|
||||
// full model.
|
||||
fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nlayers = cfg.stages[idx - 1];
|
||||
let mut layers = Vec::with_capacity(nlayers);
|
||||
let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
|
||||
let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
|
||||
let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
|
||||
|
||||
for layer_idx in 0..nlayers {
|
||||
let (has_identity, stride, in_channels) = if layer_idx == 0 {
|
||||
(false, 2, out_channels_prev)
|
||||
} else {
|
||||
(true, 1, out_channels)
|
||||
};
|
||||
|
||||
let groups = if (prev_layers + layer_idx) % 2 == 1 {
|
||||
cfg.groups
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
layers.push(repvgg_layer(
|
||||
has_identity,
|
||||
out_channels,
|
||||
stride,
|
||||
in_channels,
|
||||
out_channels,
|
||||
groups,
|
||||
vb.pp(layer_idx),
|
||||
)?)
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for layer in layers.iter() {
|
||||
xs = xs.apply(layer)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a RepVGG model for a given configuration.
|
||||
fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let outputs = output_channels_per_stage(config.a, config.b, 4);
|
||||
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem_dim = output_channels_per_stage(config.a, config.b, 0);
|
||||
let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
|
||||
let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
|
||||
let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
|
||||
let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.mean(D::Minus1)?
|
||||
.mean(D::Minus1)?;
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
repvgg_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
repvgg_model(cfg, None, vb)
|
||||
}
|
@ -1,12 +1,7 @@
|
||||
use super::Config;
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
|
||||
use candle::{Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
|
||||
fn conv1d(
|
||||
in_channels: usize,
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -27,7 +27,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
serde-wasm-bindgen = "0.6.0"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -26,7 +26,7 @@ serde_json = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
# App crates.
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -27,7 +27,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
serde-wasm-bindgen = "0.6.0"
|
||||
|
@ -9,9 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
@ -26,7 +26,7 @@ safetensors = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -9,8 +9,8 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
|
||||
candle = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@ -26,7 +26,7 @@ safetensors = { workspace = true }
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
gloo = "0.11"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
wasm-bindgen-futures = "0.4.37"
|
||||
|
@ -7,7 +7,7 @@ keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
|
||||
candle = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
|
Reference in New Issue
Block a user