From 955e63c8033af247c51b7ada1ab2c12fa7170cf5 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 5 Jan 2024 13:27:59 +0100
Subject: [PATCH 01/46] Implement hybrid Tausworthe + LCG pseudo random number generator in metal

---
 candle-core/Cargo.toml                |   4 +
 candle-core/benches/random.rs         |  41 ++++++++
 candle-core/src/metal_backend.rs      |  40 ++++++--
 candle-metal-kernels/src/lib.rs       |  73 +++++++++++++-
 candle-metal-kernels/src/random.metal | 139 ++++++++++++++++++++++++++
 candle-metal-kernels/src/tests.rs     |  56 ++++++++++-
 6 files changed, 341 insertions(+), 12 deletions(-)
 create mode 100644 candle-core/benches/random.rs
 create mode 100644 candle-metal-kernels/src/random.metal

diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 91655f57..8edfef5a 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -49,3 +49,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
 name = "matmul"
 harness = false
 
+[[bench]]
+name = "random"
+harness = false
+
diff --git a/candle-core/benches/random.rs b/candle-core/benches/random.rs
new file mode 100644
index 00000000..ce42fd4e
--- /dev/null
+++ b/candle-core/benches/random.rs
@@ -0,0 +1,41 @@
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor) {
+    a.rand_like(0.0, 1.0).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let b = 1;
+
+    let rows = 2048;
+    let cols = 2048;
+
+    let device = Device::new_metal(0).unwrap();
+    let dtype = DType::F32;
+    let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
+
+    let flops = b * rows * cols;
+
+    let mut group = c.benchmark_group("random_metal");
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |benches| {
+        benches.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&tensor));
+            }
+            if let Device::Metal(device) = &device {
+                device.wait_until_completed().unwrap();
+            } else {
+                panic!("Expected metal device");
+            }
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 6d8afab1..059cf24b 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -8,7 +8,7 @@ use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::path::Path;
-use std::sync::{Arc, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, RwLock, TryLockError};
 
 /// Simple way to catch lock error without
 /// depending on T
@@ -106,6 +106,8 @@ pub struct MetalDevice {
     /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
     /// (strong_count = 1).
     buffers: AllocatedBuffers,
+
+    seed: Arc<Mutex<u64>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -1373,6 +1375,7 @@ impl BackendDevice for MetalDevice {
             Ok(val) => val.parse()?,
             _ => 20,
         };
+        let seed = Arc::new(Mutex::new(299792458));
         Ok(Self {
             device,
             fence,
@@ -1382,11 +1385,14 @@ impl BackendDevice for MetalDevice {
             compute_per_buffer,
             buffers,
             kernels,
+            seed
         })
     }
 
-    fn set_seed(&self, _seed: u64) -> Result<()> {
-        crate::bail!("set_seed")
+    fn set_seed(&self, seed: u64) -> Result<()> {
+        let mut s = self.seed.try_lock().map_err(MetalError::from)?;
+        *s = seed;
+        Ok(())
     }
 
     fn location(&self) -> crate::DeviceLocation {
@@ -1441,12 +1447,30 @@ impl BackendDevice for MetalDevice {
         &self,
         shape: &Shape,
         dtype: DType,
-        mean: f64,
-        stddev: f64,
+        min: f64,
+        max: f64,
     ) -> Result<Self::Storage> {
-        // TODO is there a better way ?
-        let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
-        self.storage_from_cpu_storage(&cpu_storage)
+        let name = match dtype {
+            DType::F32 => "rand_uniform_f32",
+            DType::F16 => "rand_uniform_f16",
+            DType::BF16 => "rand_uniform_bf16",
+            dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
+        };
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
+        let command_buffer = self.command_buffer()?;
+        candle_metal_kernels::call_random_uniform(
+            &self.device,
+            &command_buffer,
+            &self.kernels,
+            name,
+            *self.seed.lock().unwrap(),
+            min as f32,
+            max as f32,
+            shape.elem_count(),
+            &buffer
+        ).map_err(MetalError::from)?;
+
+        Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
 
     fn rand_normal(
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index dd97a86d..04442c8a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -12,8 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
 const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
-const REDUCE: &str = include_str!("reduce.metal");
 const CONV: &str = include_str!("conv.metal");
+const REDUCE: &str = include_str!("reduce.metal");
+const RANDOM: &str = include_str!("random.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 
 /// Most kernels apply similarly across the tensors
@@ -45,7 +46,7 @@ fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64,
 /// Helper functions to create the various objects on the compute command encoder
 /// on a single line.
 /// Prevents getting wrong some arguments number and mixing length and size in bytes.
-trait EncoderParam {
+pub trait EncoderParam {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
 }
 macro_rules! primitive {
@@ -61,8 +62,10 @@ macro_rules! primitive {
         }
     };
 }
+primitive!(bool);
 primitive!(usize);
 primitive!(u32);
+primitive!(u64);
 primitive!(f32);
 
 impl<T> EncoderParam for &[T] {
@@ -117,6 +120,7 @@ pub enum Source {
     Reduce,
     Mfa,
     Conv,
+    Random,
 }
 
 macro_rules! ops{
@@ -228,6 +232,7 @@ impl Kernels {
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
             Source::Conv => CONV,
+            Source::Random => RANDOM,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -1566,5 +1571,69 @@ fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_random_uniform(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    seed: u64,
+    min: f32,
+    max: f32,
+    length: usize,
+    buffer: &Buffer,
+) -> Result<(), MetalKernelError> {
+    if min >= max {
+        return Err(MetalKernelError::LoadLibraryError(
+            "min must be less than max".to_string(),
+        ));
+    }
+
+    let size: usize = match name {
+        "rand_uniform_f32" => 4,
+        "rand_uniform_f16" | "rand_uniform_bf16" => 2,
+        _ => Err(MetalKernelError::LoadLibraryError(format!(
+            "{name} is not a valid kernel for random"
+        )))?,
+    };
+
+    let elems_per_key = length;
+    let bytes_per_key = size * elems_per_key;
+
+    let out_per_key = (bytes_per_key + 4 - 1) / 4;
+    let half_size = out_per_key / 2;
+    let odd = length % 2 != 0;
+
+    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    let thread_group_count = MTLSize {
+        width: length as u64,
+        height: half_size as u64 + odd as u64,
+        depth: 1,
+    };
+    let threads = std::cmp::min(
+        (half_size + odd as usize) as NSUInteger,
+        pipeline.max_total_threads_per_threadgroup(),
+    );
+    let thread_group_size = MTLSize {
+        width: threads,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (length, seed, min, max, buffer));
+
+    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests;
diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
new file mode 100644
index 00000000..1604123d
--- /dev/null
+++ b/candle-metal-kernels/src/random.metal
@@ -0,0 +1,139 @@
+#include <metal_stdlib>
+using namespace metal;
+
+// Constants
+// 2^32 and 1/2^32. Useful for converting between float and uint.
+static constexpr constant ulong UNIF01_NORM32 = 4294967296;
+static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
+// 2 * pi
+static constexpr constant float TWO_PI = 2.0 * M_PI_F;
+static constexpr constant int3 S1 = {13, 19, 12};
+static constexpr constant int3 S2 = {2, 25, 4};
+static constexpr constant int3 S3 = {3, 11, 17};
+
+static constexpr constant uint64_t PHI[16] = {
+    0x9E3779B97F4A7C15,
+    0xF39CC0605CEDC834,
+    0x1082276BF3A27251,
+    0xF86C6A11D0C18E95,
+    0x2767F0B153D27B7F,
+    0x0347045B5BF1827F,
+    0x01886F0928403002,
+    0xC1D64BA40F335E36,
+    0xF06AD7AE9717877E,
+    0x85839D6EFFBD7DC6,
+    0x64D325D1C5371682,
+    0xCADD0CCCFDFFBBE1,
+    0x626E33B8D04B4331,
+    0xBBF73C790D94F79D,
+    0x471C4AB3ED3D82A5,
+    0xFEC507705E4AE6E5,
+};
+
+// Combined Tausworthe and LCG Random Number Generator.
+// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
+// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
+class HybridTaus {
+private:
+    thread float seed;
+
+    // Generate seeds for each thread.
+    thread uint4 seed_per_thread(const ulong4 seeds) {
+        return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
+    }
+
+    // Tausworthe generator.
+    thread uint taus(const uint z, const int3 s, const uint M) {
+        uint b = (((z << s.x) ^ z) >> s.y);
+        return (((z & M) << s.z) ^ b);
+    }
+
+    // LCG generator.
+    thread uint lcg(const uint z) {
+        return (1664525 * z + 1013904223UL);
+    }
+
+public:
+    thread HybridTaus(const ulong4 seeds) {
+        uint4 seed = this->seed_per_thread(seeds);
+
+        // Seed #1
+        uint z1 = taus(seed.x, S1, 4294967294UL);
+        uint z2 = taus(seed.y, S2, 4294967288UL);
+        uint z3 = taus(seed.z, S3, 4294967280UL);
+        uint z4 = lcg(seed.x);
+
+        // Seed #2
+        uint r1 = (z1^z2^z3^z4^seed.y);
+        z1 = taus(r1, S1, 429496729UL);
+        z2 = taus(r1, S2, 4294967288UL);
+        z3 = taus(r1, S3, 429496280UL);
+        z4 = lcg(r1);
+
+        // Seed #3
+        r1 = (z1^z2^z3^z4^seed.z);
+        z1 = taus(r1, S1, 429496729UL);
+        z2 = taus(r1, S2, 4294967288UL);
+        z3 = taus(r1, S3, 429496280UL);
+        z4 = lcg(r1);
+
+        // Seed #4
+        r1 = (z1^z2^z3^z4^seed.w);
+        z1 = taus(r1, S1, 429496729UL);
+        z2 = taus(r1, S2, 4294967288UL);
+        z3 = taus(r1, S3, 429496280UL);
+        z4 = lcg(r1);
+
+        this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
+    }
+
+    thread float rand() {
+        uint seed = this->seed * UNIF01_NORM32;
+        uint z1 = taus(seed, S1, 429496729UL);
+        uint z2 = taus(seed, S2, 4294967288UL);
+        uint z3 = taus(seed, S3, 429496280UL);
+        uint z4 = lcg(seed);
+
+        thread float old_seed = this->seed;
+        this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
+        return old_seed;
+    }
+};
+
+template<typename T> METAL_FUNC void rand_uniform(
+    constant size_t &elem_count,
+    constant ulong &seed,
+    constant float &min,
+    constant float &max,
+    device T *out,
+    uint tid [[thread_position_in_grid]]
+) {
+    if (tid >= elem_count) {
+        return;
+    }
+    float diff = max - min;
+    HybridTaus rng = HybridTaus({seed, tid, 1, 1});
+    out[tid] = static_cast<T>(rng.rand() * diff + min);
+}
+
+#define UNIFORM_OP(NAME, T) \
+kernel void rand_uniform_##NAME( \
+    constant size_t &elem_count, \
+    constant ulong &seed, \
+    constant float &min, \
+    constant float &max, \
+    device T *out, \
+    uint tid [[thread_position_in_grid]] \
+) { \
+    rand_uniform<T>(elem_count, seed, min, max, out, tid); \
+} \
+
+#define RANDOM_OPS(NAME, T) \
+UNIFORM_OP(NAME, T) \
+
+RANDOM_OPS(f32, float)
+RANDOM_OPS(f16, half)
+
+#if __METAL_VERSION__ >= 310
+RANDOM_OPS(bf16, bfloat)
+#endif
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index c955abca..d0ca8330 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
 
 fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
     let options = MTLResourceOptions::StorageModeManaged;
-    let ptr = data.as_ptr() as *const core::ffi::c_void;
+    let ptr = data.as_ptr() as *const c_void;
     let size = (data.len() * std::mem::size_of::<T>()) as u64;
     device.new_buffer_with_data(ptr, size, options)
 }
@@ -590,7 +590,6 @@ fn softmax() {
     }
     let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
-    println!("{results:?}");
     assert_eq!(
         results.iter().map(|&s| s.round() as usize).sum::<usize>(),
         n
@@ -806,3 +805,56 @@ fn gemm() {
         vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
     );
 }
+
+fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> {
+    let device = device();
+    let fence = device.new_fence();
+    let kernels = Kernels::new(fence);
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let options = MTLResourceOptions::StorageModeManaged;
+    let length = shape.iter().product::<usize>();
+    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+
+    call_random_uniform(
+        &device,
+        command_buffer,
+        &kernels,
+        name,
+        seed,
+        min,
+        max,
+        length,
+        &output,
+    )
+    .unwrap();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    read_to_vec(&output, length)
+}
+
+#[test]
+fn random() {
+    use std::fs::File;
+    use std::io::prelude::*;
+
+    let shape = vec![1024, 4];
+    let seed = 299792458;
+    let min = -30.0;
+    let max = 30.0;
+    let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max);
+    for &v in &results {
+        assert!(v >= min && v <= max);
+    }
+
+    // Writing bytes to file for testing with ENT
+    // https://www.fourmilab.ch/random/
+    // TODO: Remove before merge
+    let (head, body, tail) = unsafe { results.align_to::<u8>() };
+    assert!(head.is_empty());
+    assert!(tail.is_empty());
+    let mut file = File::create("test").unwrap();
+    file.write_all(body).unwrap();
+}

From 6bf52b9fdf82ad775611e82924d73172660a605e Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 5 Jan 2024 21:18:12 +0100
Subject: [PATCH 02/46] Gaussian normal distribution of PRNG via Box-Muller transform

---
 candle-core/benches/random.rs         |  31 +++++++-
 candle-core/src/metal_backend.rs      |  32 ++++++--
 candle-metal-kernels/src/lib.rs       |  64 +++++++--------
 candle-metal-kernels/src/random.metal | 107 +++++++++++++++++++------
 candle-metal-kernels/src/tests.rs     | 105 +++++++++++++++++--------
 5 files changed, 238 insertions(+), 101 deletions(-)

diff --git a/candle-core/benches/random.rs b/candle-core/benches/random.rs
index ce42fd4e..781d8b39 100644
--- a/candle-core/benches/random.rs
+++ b/candle-core/benches/random.rs
@@ -2,10 +2,14 @@ use candle_core::{DType, Device, Tensor};
 use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
 use std::time::Instant;
 
-fn run(a: &Tensor) {
+fn rand_uniform(a: &Tensor) {
     a.rand_like(0.0, 1.0).unwrap();
 }
 
+fn rand_normal(a: &Tensor) {
+    a.randn_like(100.0, 15.0).unwrap();
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
     let b = 1;
 
@@ -13,18 +17,19 @@ fn criterion_benchmark(c: &mut Criterion) {
     let cols = 2048;
 
     let device = Device::new_metal(0).unwrap();
+    let device2 = device.clone();
     let dtype = DType::F32;
     let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap();
 
     let flops = b * rows * cols;
 
-    let mut group = c.benchmark_group("random_metal");
+    let mut group = c.benchmark_group("metal_random_uniform");
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |benches| {
         benches.iter_custom(|iters| {
             let start = Instant::now();
             for _i in 0..iters {
-                run(black_box(&tensor));
+                rand_uniform(black_box(&tensor));
             }
             if let Device::Metal(device) = &device {
                 device.wait_until_completed().unwrap();
@@ -35,6 +40,26 @@ fn criterion_benchmark(c: &mut Criterion) {
         })
     });
     group.finish();
+
+    let tensor = Tensor::zeros((b, rows, cols), dtype, &device2).unwrap();
+
+    let mut group = c.benchmark_group("metal_random_normal");
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |benches| {
+        benches.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                rand_normal(black_box(&tensor));
+            }
+            if let Device::Metal(device) = &device2 {
+                device.wait_until_completed().unwrap();
+            } else {
+                panic!("Expected metal device");
+            }
+            start.elapsed()
+        })
+    });
+    group.finish();
 }
 
 criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 059cf24b..73a532e6 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1385,7 +1385,7 @@ impl BackendDevice for MetalDevice {
             compute_per_buffer,
             buffers,
             kernels,
-            seed
+            seed,
         })
     }
 
@@ -1467,8 +1467,9 @@ impl BackendDevice for MetalDevice {
             min as f32,
             max as f32,
             shape.elem_count(),
-            &buffer
-        ).map_err(MetalError::from)?;
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
 
         Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
@@ -1480,9 +1481,28 @@ impl BackendDevice for MetalDevice {
         mean: f64,
         stddev: f64,
     ) -> Result<Self::Storage> {
-        // TODO is there a better way ?
-        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
-        self.storage_from_cpu_storage(&cpu_storage)
+        let name = match dtype {
+            DType::F32 => "rand_normal_f32",
+            DType::F16 => "rand_normal_f16",
+            DType::BF16 => "rand_normal_bf16",
+            dtype => crate::bail!("rand_normal not implemented for {dtype:?}"),
+        };
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
+        let command_buffer = self.command_buffer()?;
+        candle_metal_kernels::call_random_normal(
+            &self.device,
+            &command_buffer,
+            &self.kernels,
+            name,
+            *self.seed.lock().unwrap(),
+            mean as f32,
+            stddev as f32,
+            shape.elem_count(),
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+
+        Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
 }
 
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 04442c8a..e2603b3b 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1415,7 +1415,6 @@ pub fn call_gemm(
         height: 1,
         depth: 1,
     };
-    // println!("grid size {grid_size:?} group size {group_size:?}");
     encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
@@ -1588,39 +1587,11 @@ pub fn call_random_uniform(
             "min must be less than max".to_string(),
         ));
     }
-
-    let size: usize = match name {
-        "rand_uniform_f32" => 4,
-        "rand_uniform_f16" | "rand_uniform_bf16" => 2,
-        _ => Err(MetalKernelError::LoadLibraryError(format!(
-            "{name} is not a valid kernel for random"
-        )))?,
-    };
-
-    let elems_per_key = length;
-    let bytes_per_key = size * elems_per_key;
-
-    let out_per_key = (bytes_per_key + 4 - 1) / 4;
-    let half_size = out_per_key / 2;
-    let odd = length % 2 != 0;
-
     let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
     let encoder = command_buffer.new_compute_command_encoder();
 
-    let thread_group_count = MTLSize {
-        width: length as u64,
-        height: half_size as u64 + odd as u64,
-        depth: 1,
-    };
-    let threads = std::cmp::min(
-        (half_size + odd as usize) as NSUInteger,
-        pipeline.max_total_threads_per_threadgroup(),
-    );
-    let thread_group_size = MTLSize {
-        width: threads,
-        height: 1,
-        depth: 1,
-    };
+    let odd = (length % 2 != 0) as usize;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
 
     encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
@@ -1635,5 +1606,36 @@ pub fn call_random_uniform(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_random_normal(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    seed: u64,
+    mean: f32,
+    stddev: f32,
+    length: usize,
+    buffer: &Buffer,
+) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Random, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, seed, mean, stddev, buffer)); + + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index 1604123d..5369e8e2 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -33,29 +33,34 @@ static constexpr constant uint64_t PHI[16] = { // Combined Tausworthe and LCG Random Number Generator. // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application // https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf -class HybridTaus { -private: - thread float seed; +struct HybridTaus { + + float state; + + HybridTaus() thread = default; + HybridTaus() threadgroup = default; + HybridTaus() device = default; + HybridTaus() constant = default; // Generate seeds for each thread. - thread uint4 seed_per_thread(const ulong4 seeds) { + METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) { return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL)); } // Tausworthe generator. - thread uint taus(const uint z, const int3 s, const uint M) { + METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) { uint b = (((z << s.x) ^ z) >> s.y); return (((z & M) << s.z) ^ b); } // LCG generator. - thread uint lcg(const uint z) { + METAL_FUNC static uint lcg(const uint z) { return (1664525 * z + 1013904223UL); } -public: - thread HybridTaus(const ulong4 seeds) { - uint4 seed = this->seed_per_thread(seeds); + // Initialize the RNG state. 
+    METAL_FUNC static HybridTaus init(const ulong4 seeds) {
+        uint4 seed = seed_per_thread(seeds);
 
         // Seed #1
         uint z1 = taus(seed.x, S1, 4294967294UL);
@@ -84,52 +89,96 @@ public:
         z3 = taus(r1, S3, 429496280UL);
         z4 = lcg(r1);
 
-        this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
+        HybridTaus rng;
+        rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
+        return rng;
     }
 
-    thread float rand() {
-        uint seed = this->seed * UNIF01_NORM32;
+    METAL_FUNC float rand() {
+        uint seed = this->state * UNIF01_NORM32;
         uint z1 = taus(seed, S1, 429496729UL);
         uint z2 = taus(seed, S2, 4294967288UL);
         uint z3 = taus(seed, S3, 429496280UL);
         uint z4 = lcg(seed);
 
-        thread float old_seed = this->seed;
-        this->seed = (z1^z2^z3^z4) * UNIF01_INV32;
-        return old_seed;
+        thread float result = this->state;
+        this->state = (z1^z2^z3^z4) * UNIF01_INV32;
+        return result;
     }
 };
 
 template<typename T> METAL_FUNC void rand_uniform(
-    constant size_t &elem_count,
+    constant size_t &size,
     constant ulong &seed,
     constant float &min,
     constant float &max,
     device T *out,
     uint tid [[thread_position_in_grid]]
 ) {
-    if (tid >= elem_count) {
+    if (tid >= size) {
         return;
     }
     float diff = max - min;
-    HybridTaus rng = HybridTaus({seed, tid, 1, 1});
+    HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
     out[tid] = static_cast<T>(rng.rand() * diff + min);
+    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
 }
 
+// Create Gaussian normal distribution using Box-Muller transform:
+// https://en.wikipedia.org/wiki/Box–Muller_transform
+template<typename T> METAL_FUNC void normal(
+    constant size_t &size,
+    constant ulong &seed,
+    constant float &mean,
+    constant float &stddev,
+    device T *out,
+    uint tid [[thread_position_in_grid]]
+) {
+    if (tid >= size) {
+        return;
+    }
+    HybridTaus rng = HybridTaus::init({seed, tid, 1, 1});
+    float u1 = rng.rand();
+    float u2 = rng.rand();
+
+    float cosval;
+    float sinval = sincos(u1 * TWO_PI, cosval);
+    float mag = stddev * sqrt(-2.0 * log(u1));
+    float z0 = mag * cosval + mean;
+    float z1 = mag * sinval + mean;
+
+    out[tid] = static_cast<T>(z0);
+    out[size - tid] = static_cast<T>(z1);
+}
+
 #define UNIFORM_OP(NAME, T) \
 kernel void rand_uniform_##NAME( \
-    constant size_t &elem_count, \
+    constant size_t &size, \
     constant ulong &seed, \
     constant float &min, \
     constant float &max, \
     device T *out, \
     uint tid [[thread_position_in_grid]] \
 ) { \
-    rand_uniform<T>(elem_count, seed, min, max, out, tid); \
+    rand_uniform<T>(size, seed, min, max, out, tid); \
 } \
 
+#define NORMAL_OP(NAME, T) \
+kernel void rand_normal_##NAME( \
+    constant size_t &size, \
+    constant ulong &seed, \
+    constant float &mean, \
+    constant float &stddev, \
+    device T *out, \
+    uint tid [[thread_position_in_grid]] \
+) { \
+    normal<T>(size, seed, mean, stddev, out, tid); \
+} \
+
 #define RANDOM_OPS(NAME, T) \
 UNIFORM_OP(NAME, T) \
+NORMAL_OP(NAME, T) \
 
 RANDOM_OPS(f32, float)
 RANDOM_OPS(f16, half)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index d0ca8330..067dece8 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -806,28 +806,43 @@ fn gemm() {
     );
 }
 
-fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32, max: f32) -> Vec<T> {
+fn run_random<T: Clone>(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec<T> {
     let device = device();
     let fence = device.new_fence();
     let kernels = Kernels::new(fence);
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
-    let length = shape.iter().product::<usize>();
     let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
 
-    call_random_uniform(
-        &device,
-        command_buffer,
-        &kernels,
-        name,
-        seed,
-        min,
-        max,
-        length,
-        &output,
-    )
-    .unwrap();
+    if name.starts_with("rand_uniform") {
+        call_random_uniform(
+            &device,
+            command_buffer,
+            &kernels,
+            name,
+            seed,
+            a,
+            b,
+            length,
+            &output,
+        )
+        .unwrap();
+    } else {
+        call_random_normal(
+            &device,
+            command_buffer,
+            &kernels,
+            name,
+            seed,
+            a,
+            b,
+            length,
+            &output,
+        )
+        .unwrap();
+    }
 
     command_buffer.commit();
     command_buffer.wait_until_completed();
@@ -837,24 +852,50 @@ fn run_random<T: Clone>(seed: u64, shape: &[usize], name: &'static str, min: f32
 
 #[test]
 fn random() {
-    use std::fs::File;
-    use std::io::prelude::*;
-
-    let shape = vec![1024, 4];
-    let seed = 299792458;
-    let min = -30.0;
-    let max = 30.0;
-    let results = run_random::<f32>(seed, &shape, "rand_uniform_f32", min, max);
-    for &v in &results {
-        assert!(v >= min && v <= max);
+    fn calc_mean(data: &[f32]) -> f32 {
+        let sum = data.iter().sum::<f32>() as f32;
+        let count = data.len();
+        assert!(count > 0);
+        sum / count as f32
     }
 
-    // Writing bytes to file for testing with ENT
-    // https://www.fourmilab.ch/random/
-    // TODO: Remove before merge
-    let (head, body, tail) = unsafe { results.align_to::<u8>() };
-    assert!(head.is_empty());
-    assert!(tail.is_empty());
-    let mut file = File::create("test").unwrap();
-    file.write_all(body).unwrap();
-}
+    fn calc_stddev(data: &[f32]) -> f32 {
+        let mean = calc_mean(data);
+        let count = data.len();
+        assert!(count > 0);
+
+        let variance = data.iter().map(|value| {
+            let diff = mean - (*value as f32);
+            diff * diff
+        }).sum::<f32>() / count as f32;
+
+        variance.sqrt()
+    }
+
+    let shape = vec![1024, 10];
+
+    let length = shape.iter().product::<usize>();
+    let seed = 299792458;
+
+    let min = -30.0;
+    let max = 30.0;
+    let mean = 100.0;
+    let stddev = 50.0;
+
+    macro_rules! validate_random {
+        ($type:ty) => {
+            let results: Vec<f32> = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect();
+            results.iter().for_each(|v| assert!(*v >= min && *v <= max));
+            assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
+
+            let results: Vec<f32> = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect();
+            assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
+            assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
+        };
+    }
+
+    validate_random!(f32);
+    validate_random!(f16);
+    validate_random!(bf16);
+}
\ No newline at end of file

From b4cb982e498fc121992e7c03d00d04755a66001f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sun, 7 Jan 2024 12:04:14 +0100
Subject: [PATCH 03/46] Simplifying our internal cargo dependencies.
(#1529) --- Cargo.toml | 8 ++++++++ candle-book/Cargo.toml | 10 +++++----- candle-core/Cargo.toml | 4 ++-- candle-datasets/Cargo.toml | 4 ++-- candle-examples/Cargo.toml | 12 ++++++------ candle-nn/Cargo.toml | 4 ++-- candle-onnx/Cargo.toml | 1 - candle-pyo3/Cargo.toml | 6 +++--- candle-transformers/Cargo.toml | 6 +++--- candle-wasm-examples/bert/Cargo.toml | 6 +++--- candle-wasm-examples/blip/Cargo.toml | 6 +++--- candle-wasm-examples/llama2-c/Cargo.toml | 6 +++--- candle-wasm-examples/phi/Cargo.toml | 6 +++--- candle-wasm-examples/segment-anything/Cargo.toml | 6 +++--- candle-wasm-examples/t5/Cargo.toml | 6 +++--- candle-wasm-examples/whisper/Cargo.toml | 6 +++--- candle-wasm-examples/yolo/Cargo.toml | 4 ++-- candle-wasm-tests/Cargo.toml | 2 +- 18 files changed, 55 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7d61cd74..3d66a02f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,14 @@ license = "MIT OR Apache-2.0" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" +candle = { path = "./candle-core", package = "candle-core" } +candle-datasets = { path = "./candle-datasets" } +candle-flash-attn = { path = "./candle-flash-attn" } +candle-kernels = { path = "./candle-kernels" } +candle-metal-kernels = { path = "./candle-metal-kernels" } +candle-nn = { path = "./candle-nn" } +candle-onnx = { path = "./candle-onnx" } +candle-transformers = { path = "./candle-transformers" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.9.14", features = ["f16"] } diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml index e28e6623..5ccda31e 100644 --- a/candle-book/Cargo.toml +++ b/candle-book/Cargo.toml @@ -11,11 +11,11 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.3" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../candle-transformers", version = "0.3.3" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } +candle = { workspace = true } +candle-datasets = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 91655f57..97857a6b 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -12,8 +12,8 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true } +candle-kernels = { workspace = true, optional = true } +candle-metal-kernels = { workspace = true, optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 69438e0e..ccabf7ed 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] byteorder = { workspace = 
true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } hf-hub = { workspace = true} intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 7e081530..439116f8 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,12 +11,12 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.3" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../candle-transformers", version = "0.3.3" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } -candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true } +candle = { workspace = true } +candle-datasets = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } +candle-onnx = { workspace = true, optional = true } csv = "1.3.0" cudarc = { workspace = true, optional = true } diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 5e0e5c2b..214e8a59 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } +candle = { workspace = true } half = { workspace = true } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } @@ -20,7 +20,7 @@ rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } +candle-metal-kernels = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index ba33b07a..cf7add01 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -20,4 +20,3 @@ prost-build = "0.12.1" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } clap = { version = "4.2.4", features = ["derive"] } - diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index a03c7559..7c6fbd68 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -15,9 +15,9 @@ crate-type = ["cdylib"] [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true} +candle = { workspace = true } +candle-nn = { workspace = true } +candle-onnx = { workspace = true, optional = true } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 83bcff62..1a72c36a 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -12,9 +12,9 @@ readme = "README.md" [dependencies] accelerate-src 
= { workspace = true, optional = true } byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } +candle-nn = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } rand = { workspace = true } diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml index 59ce1be3..259a6102 100644 --- a/candle-wasm-examples/bert/Cargo.toml +++ b/candle-wasm-examples/bert/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/blip/Cargo.toml b/candle-wasm-examples/blip/Cargo.toml index 904e90e6..f4de054e 100644 --- a/candle-wasm-examples/blip/Cargo.toml +++ b/candle-wasm-examples/blip/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 63f8a9c5..ac89a558 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/phi/Cargo.toml b/candle-wasm-examples/phi/Cargo.toml index c4950df9..e437a937 100644 --- a/candle-wasm-examples/phi/Cargo.toml +++ b/candle-wasm-examples/phi/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/segment-anything/Cargo.toml 
b/candle-wasm-examples/segment-anything/Cargo.toml
index 4d886bc2..1840bb62 100644
--- a/candle-wasm-examples/segment-anything/Cargo.toml
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
 num-traits = { workspace = true }
 
 # App crates.
diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml
index 237f9e61..36cd9386 100644
--- a/candle-wasm-examples/t5/Cargo.toml
+++ b/candle-wasm-examples/t5/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
 
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 5d2b2a38..6c6857e4 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
 
diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml
index eb2c320b..0e5a91a8 100644
--- a/candle-wasm-examples/yolo/Cargo.toml
+++ b/candle-wasm-examples/yolo/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
 license.workspace = true
 
 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
 num-traits = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
diff --git a/candle-wasm-tests/Cargo.toml b/candle-wasm-tests/Cargo.toml
index a684f2ce..40c37acd 100644
--- a/candle-wasm-tests/Cargo.toml
+++ b/candle-wasm-tests/Cargo.toml
@@ -7,7 +7,7 @@ keywords.workspace = true
 categories.workspace = true
 
 [dependencies]
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
+candle = { workspace = true }
 rand = { workspace = true }
 getrandom = { version = "0.2", features = ["js"] }
 
From e72d52b1a2118f8773866e87237586bab762a9c6 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 7 Jan 2024 12:26:20 +0100
Subject: [PATCH 04/46] Unpin more of the workspace relative dependencies.
(#1535) --- candle-flash-attn/Cargo.toml | 4 ++-- candle-onnx/Cargo.toml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 64e690e6..0d3af91d 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" } +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] @@ -21,4 +21,4 @@ rayon = "1.7.0" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] } +candle-nn = { path = "../candle-nn", features = ["cuda"] } diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index cf7add01..de1e3350 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { path = "../candle-core", package = "candle-core" } +candle-nn = { path = "../candle-nn" } prost = "0.12.1" [build-dependencies] From 30313c308106fff7b20fc8cb2b27eb79800cb818 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 7 Jan 2024 12:29:24 +0100 Subject: [PATCH 05/46] Moving to a proper build crate `bindgen_cuda`. (#1531) * Moving to a proper build crate `bindgen_cuda`. * Fmt. --- candle-flash-attn/Cargo.toml | 4 +- candle-flash-attn/build.rs | 273 +++++------------------------------ candle-kernels/Cargo.toml | 4 +- candle-kernels/build.rs | 243 +------------------------------ 4 files changed, 41 insertions(+), 483 deletions(-) diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 0d3af91d..d8e8da82 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -15,9 +15,9 @@ candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] +bindgen_cuda = "0.1.1" anyhow = { version = "1", features = ["backtrace"] } -num_cpus = "1.15.0" -rayon = "1.7.0" + [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index fde3aeed..4002770b 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -2,44 +2,32 @@ // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment // variable in order to cache the compiled artifacts and avoid recompiling too often. 
use anyhow::{Context, Result}; -use rayon::prelude::*; use std::path::PathBuf; -use std::str::FromStr; const KERNEL_FILES: [&str; 17] = [ - "flash_api.cu", - "flash_fwd_hdim128_fp16_sm80.cu", - "flash_fwd_hdim160_fp16_sm80.cu", - "flash_fwd_hdim192_fp16_sm80.cu", - "flash_fwd_hdim224_fp16_sm80.cu", - "flash_fwd_hdim256_fp16_sm80.cu", - "flash_fwd_hdim32_fp16_sm80.cu", - "flash_fwd_hdim64_fp16_sm80.cu", - "flash_fwd_hdim96_fp16_sm80.cu", - "flash_fwd_hdim128_bf16_sm80.cu", - "flash_fwd_hdim160_bf16_sm80.cu", - "flash_fwd_hdim192_bf16_sm80.cu", - "flash_fwd_hdim224_bf16_sm80.cu", - "flash_fwd_hdim256_bf16_sm80.cu", - "flash_fwd_hdim32_bf16_sm80.cu", - "flash_fwd_hdim64_bf16_sm80.cu", - "flash_fwd_hdim96_bf16_sm80.cu", + "kernels/flash_api.cu", + "kernels/flash_fwd_hdim128_fp16_sm80.cu", + "kernels/flash_fwd_hdim160_fp16_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_sm80.cu", + "kernels/flash_fwd_hdim224_fp16_sm80.cu", + "kernels/flash_fwd_hdim256_fp16_sm80.cu", + "kernels/flash_fwd_hdim32_fp16_sm80.cu", + "kernels/flash_fwd_hdim64_fp16_sm80.cu", + "kernels/flash_fwd_hdim96_fp16_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_sm80.cu", + "kernels/flash_fwd_hdim160_bf16_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_sm80.cu", + "kernels/flash_fwd_hdim224_bf16_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_sm80.cu", + "kernels/flash_fwd_hdim32_bf16_sm80.cu", + "kernels/flash_fwd_hdim64_bf16_sm80.cu", + "kernels/flash_fwd_hdim96_bf16_sm80.cu", ]; fn main() -> Result<()> { - let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( - |_| num_cpus::get_physical(), - |s| usize::from_str(&s).unwrap(), - ); - - rayon::ThreadPoolBuilder::new() - .num_threads(num_cpus) - .build_global() - .unwrap(); - println!("cargo:rerun-if-changed=build.rs"); for kernel_file in KERNEL_FILES.iter() { - println!("cargo:rerun-if-changed=kernels/{kernel_file}"); + println!("cargo:rerun-if-changed={kernel_file}"); } println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); @@ -66,223 +54,30 @@ fn main() -> Result<()> { )) } }; - set_cuda_include_dir()?; - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - - let compute_cap = compute_cap()?; + let kernels = KERNEL_FILES.iter().collect(); + let builder = bindgen_cuda::Builder::default() + .kernel_paths(kernels) + .out_dir(build_dir.clone()) + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("-Icutlass/include") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose"); let out_file = build_dir.join("libflashattention.a"); + builder.build_lib(out_file); - let kernel_dir = PathBuf::from("kernels"); - let cu_files: Vec<_> = KERNEL_FILES - .iter() - .map(|f| { - let mut obj_file = out_dir.join(f); - obj_file.set_extension("o"); - (kernel_dir.join(f), obj_file) - }) - .collect(); - let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified()); - let should_compile = if out_file.exists() { - kernel_dir - .read_dir() - .expect("kernels folder should exist") - .any(|entry| { - if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) { - let in_modified = entry.metadata().unwrap().modified().unwrap(); - in_modified.duration_since(*out_modified).is_ok() - } else { - true - } - }) - } else { - true - }; - if should_compile 
{
-        cu_files
-            .par_iter()
-            .map(|(cu_file, obj_file)| {
-                let mut command = std::process::Command::new("nvcc");
-                command
-                    .arg("-std=c++17")
-                    .arg("-O3")
-                    .arg("-U__CUDA_NO_HALF_OPERATORS__")
-                    .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
-                    .arg("-U__CUDA_NO_HALF2_OPERATORS__")
-                    .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
-                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
-                    .arg("-c")
-                    .args(["-o", obj_file.to_str().unwrap()])
-                    .args(["--default-stream", "per-thread"])
-                    .arg("-Icutlass/include")
-                    .arg("--expt-relaxed-constexpr")
-                    .arg("--expt-extended-lambda")
-                    .arg("--use_fast_math")
-                    .arg("--verbose");
-                if let Ok(ccbin_path) = &ccbin_env {
-                    command
-                        .arg("-allow-unsupported-compiler")
-                        .args(["-ccbin", ccbin_path]);
-                }
-                command.arg(cu_file);
-                let output = command
-                    .spawn()
-                    .context("failed spawning nvcc")?
-                    .wait_with_output()?;
-                if !output.status.success() {
-                    anyhow::bail!(
-                        "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
-                        &command,
-                        String::from_utf8_lossy(&output.stdout),
-                        String::from_utf8_lossy(&output.stderr)
-                    )
-                }
-                Ok(())
-            })
-            .collect::<Result<()>>()?;
-        let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
-        let mut command = std::process::Command::new("nvcc");
-        command
-            .arg("--lib")
-            .args(["-o", out_file.to_str().unwrap()])
-            .args(obj_files);
-        let output = command
-            .spawn()
-            .context("failed spawning nvcc")?
-            .wait_with_output()?;
-        if !output.status.success() {
-            anyhow::bail!(
-                "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
-                &command,
-                String::from_utf8_lossy(&output.stdout),
-                String::from_utf8_lossy(&output.stderr)
-            )
-        }
-    }
     println!("cargo:rustc-link-search={}", build_dir.display());
     println!("cargo:rustc-link-lib=flashattention");
     println!("cargo:rustc-link-lib=dylib=cudart");
     println!("cargo:rustc-link-lib=dylib=stdc++");
 
-    /* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
-    finishing to run for some reason. Calling nvcc manually worked fine.
-    cc::Build::new()
-        .cuda(true)
-        .include("cutlass/include")
-        .flag("--expt-relaxed-constexpr")
-        .flag("--default-stream")
-        .flag("per-thread")
-        .flag(&format!("--gpu-architecture=sm_{compute_cap}"))
-        .file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
-        .compile("flashattn");
-    */
     Ok(())
 }
-
-fn set_cuda_include_dir() -> Result<()> {
-    // NOTE: copied from cudarc build.rs.
-    let env_vars = [
-        "CUDA_PATH",
-        "CUDA_ROOT",
-        "CUDA_TOOLKIT_ROOT_DIR",
-        "CUDNN_LIB",
-    ];
-    let env_vars = env_vars
-        .into_iter()
-        .map(std::env::var)
-        .filter_map(Result::ok)
-        .map(Into::<PathBuf>::into);
-
-    let roots = [
-        "/usr",
-        "/usr/local/cuda",
-        "/opt/cuda",
-        "/usr/lib/cuda",
-        "C:/Program Files/NVIDIA GPU Computing Toolkit",
-        "C:/CUDA",
-    ];
-    let roots = roots.into_iter().map(Into::<PathBuf>::into);
-    let root = env_vars
-        .chain(roots)
-        .find(|path| path.join("include").join("cuda.h").is_file())
-        .context("cannot find include/cuda.h")?;
-    println!(
-        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
-        root.join("include").display()
-    );
-    Ok(())
-}
-
-#[allow(unused)]
-fn compute_cap() -> Result<usize> {
-    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
-    // Try to parse compute caps from env
-    let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
-        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
-        compute_cap_str
-            .parse::<usize>()
-            .context("Could not parse compute cap")?
-    } else {
-        // Use nvidia-smi to get the current compute cap
-        let out = std::process::Command::new("nvidia-smi")
-            .arg("--query-gpu=compute_cap")
-            .arg("--format=csv")
-            .output()
-            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
-        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
-        let mut lines = out.lines();
-        assert_eq!(
-            lines.next().context("missing line in stdout")?,
-            "compute_cap"
-        );
-        let cap = lines
-            .next()
-            .context("missing line in stdout")?
-            .replace('.', "");
-        let cap = cap
-            .parse::<usize>()
-            .with_context(|| format!("cannot parse as int {cap}"))?;
-        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
-        cap
-    };
-
-    // Grab available GPU codes from nvcc and select the highest one
-    let (supported_nvcc_codes, max_nvcc_code) = {
-        let out = std::process::Command::new("nvcc")
-            .arg("--list-gpu-code")
-            .output()
-            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
-        let out = std::str::from_utf8(&out.stdout).unwrap();
-
-        let out = out.lines().collect::<Vec<&str>>();
-        let mut codes = Vec::with_capacity(out.len());
-        for code in out {
-            let code = code.split('_').collect::<Vec<&str>>();
-            if !code.is_empty() && code.contains(&"sm") {
-                if let Ok(num) = code[1].parse::<usize>() {
-                    codes.push(num);
-                }
-            }
-        }
-        codes.sort();
-        let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
-        (codes, max_nvcc_code)
-    };
-
-    // Check that nvcc supports the asked compute caps
-    if !supported_nvcc_codes.contains(&compute_cap) {
-        anyhow::bail!(
-            "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
-        );
-    }
-    if compute_cap > max_nvcc_code {
-        anyhow::bail!(
-            "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
-        );
-    }
-
-    Ok(compute_cap)
-}
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index e81fe39c..0cd4a14d 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
 [dependencies]
 
 [build-dependencies]
-anyhow = { version = "1", features = ["backtrace"] }
-glob = "0.3.1"
-rayon = "1.7.0"
+bindgen_cuda = "0.1.1"
diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs
index 17a0bf9c..63d744ca 100644
--- a/candle-kernels/build.rs
+++ b/candle-kernels/build.rs
@@ -1,243 +1,8 @@
-use std::io::Write;
-
 fn main() {
     println!("cargo:rerun-if-changed=build.rs");
-    cuda::set_include_dir();
-    let (write, kernel_paths) = cuda::build_ptx();
-    if write {
-        let mut file = std::fs::File::create("src/lib.rs").unwrap();
-        for kernel_path in kernel_paths {
-            let name = kernel_path.file_stem().unwrap().to_str().unwrap();
-            file.write_all(
-                format!(
-                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
-                    name.to_uppercase().replace('.', "_"),
-                    name
-                )
-                .as_bytes(),
-            )
-            .unwrap();
-            file.write_all(&[b'\n']).unwrap();
-        }
-    }
-}
-
-mod cuda {
-    use anyhow::{Context, Result};
-
-    pub fn set_include_dir() {
-        use std::path::PathBuf;
-        // NOTE: copied from cudarc build.rs.
-        // We can't actually set a env!() value from another crate,
-        // so we have to do that here.
-
-        // use PathBuf;
-
-        let env_vars = [
-            "CUDA_PATH",
-            "CUDA_ROOT",
-            "CUDA_TOOLKIT_ROOT_DIR",
-            "CUDNN_LIB",
-        ];
-        #[allow(unused)]
-        let env_vars = env_vars
-            .into_iter()
-            .map(std::env::var)
-            .filter_map(Result::ok)
-            .map(Into::<PathBuf>::into);
-
-        let roots = [
-            "/usr",
-            "/usr/local/cuda",
-            "/opt/cuda",
-            "/usr/lib/cuda",
-            "C:/Program Files/NVIDIA GPU Computing Toolkit",
-            "C:/CUDA",
-        ];
-        #[allow(unused)]
-        let roots = roots.into_iter().map(Into::<PathBuf>::into);
-
-        #[cfg(feature = "ci-check")]
-        let root: PathBuf = "ci".into();
-
-        #[cfg(not(feature = "ci-check"))]
-        let root = env_vars
-            .chain(roots)
-            .find(|path| path.join("include").join("cuda.h").is_file())
-            .unwrap();
-
-        println!(
-            "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
-            root.join("include").display()
-        );
-    }
-
-    pub fn build_ptx() -> (bool, Vec<PathBuf>) {
-        use rayon::prelude::*;
-        use std::path::PathBuf;
-        let out_dir = std::env::var("OUT_DIR").unwrap();
-        let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
-            .unwrap()
-            .map(|p| p.unwrap())
-            .collect();
-        let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
-            .unwrap()
-            .map(|p| p.unwrap())
-            .collect();
-
-        println!("cargo:rerun-if-changed=src/");
-        // for path in &kernel_paths {
-        //     println!("cargo:rerun-if-changed={}", path.display());
-        // }
-
-        for path in &mut include_directories {
-            // println!("cargo:rerun-if-changed={}", path.display());
-            let destination =
-                std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
-            std::fs::copy(path.clone(), destination).unwrap();
-            // remove the filename from the path so it's just the directory
-            path.pop();
-        }
-
-        include_directories.sort();
-        include_directories.dedup();
-
-        let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
-
-        #[allow(unused)]
-        let include_options: Vec<String> = include_directories
-            .into_iter()
-            .map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
-            .collect::<Vec<_>>();
-
-        let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
-        println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
-        let children = kernel_paths
-            .par_iter()
-            .flat_map(|p| {
-                let mut output = p.clone();
-                output.set_extension("ptx");
-                let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
-
-                let ignore = if output_filename.exists() {
-                    let out_modified = output_filename.metadata().unwrap().modified().unwrap();
-                    let in_modified = p.metadata().unwrap().modified().unwrap();
-                    out_modified.duration_since(in_modified).is_ok()
-                } else {
-                    false
-                };
-                if ignore {
-                    None
-                } else {
-                    let mut command = std::process::Command::new("nvcc");
-                    command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
-                        .arg("--ptx")
-                        .args(["--default-stream", "per-thread"])
-                        .args(["--output-directory", &out_dir])
-                        // Flash attention only
-                        // .arg("--expt-relaxed-constexpr")
-                        .args(&include_options);
-                    if let Ok(ccbin_path) = &ccbin_env {
-                        command
-                            .arg("-allow-unsupported-compiler")
-                            .args(["-ccbin", ccbin_path]);
-                    }
-                    command.arg(p);
-                    Some((p, command.spawn()
-                        .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output())) - } - }) - .collect::>(); - - let ptx_paths: Vec = glob::glob(&format!("{out_dir}/**/*.ptx")) - .unwrap() - .map(|p| p.unwrap()) - .collect(); - // We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed - // some old ones - let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len(); - for (kernel_path, child) in children { - let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - assert!( - output.status.success(), - "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ); - } - (write, kernel_paths) - } - - #[allow(unused)] - fn compute_cap() -> Result { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute caps from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::() - .context("Could not parse code")? - } else { - // Use nvidia-smi to get the current compute cap - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - let cap = cap - .parse::() - .with_context(|| format!("cannot parse as int {cap}"))?; - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap - }; - - // Grab available GPU codes from nvcc and select the highest one - let (supported_nvcc_codes, max_nvcc_code) = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::() { - codes.push(num); - } - } - } - codes.sort(); - let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?; - (codes, max_nvcc_code) - }; - - // Check that nvcc supports the asked compute caps - if !supported_nvcc_codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." - ); - } - if compute_cap > max_nvcc_code { - anyhow::bail!( - "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" - ); - } - - Ok(compute_cap) - } + let builder = bindgen_cuda::Builder::default(); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write("src/lib.rs").unwrap(); } From 89b5a068585b73193d2004a7293d5b2fa6c30bfd Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 7 Jan 2024 17:18:46 +0100 Subject: [PATCH 06/46] Use bindgen-cuda for the custom-kernel example. (#1536) * Use bindgen-cuda for the custom-kernel example. 
* Only depend on the kernels when cuda is enabled. * Skip rustfmt. --- candle-examples/Cargo.toml | 3 +- candle-examples/build.rs | 247 ++---------------- .../examples/custom-ops/cuda_kernels.rs | 3 +- candle-examples/examples/custom-ops/main.rs | 3 +- 4 files changed, 20 insertions(+), 236 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 439116f8..00340d08 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -49,11 +49,12 @@ tokio = "1.29.1" [build-dependencies] anyhow = { workspace = true } +bindgen_cuda = { version = "0.1.1", optional = true } [features] default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] -cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"] cudnn = ["candle/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/candle-examples/build.rs b/candle-examples/build.rs index 0af3a6a4..ba40aeb4 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -4,251 +4,34 @@ use std::io::Write; use std::path::PathBuf; struct KernelDirectories { - kernel_dir: &'static str, + kernel_glob: &'static str, rust_target: &'static str, include_dirs: &'static [&'static str], } -const DIRS: [KernelDirectories; 1] = [KernelDirectories { - kernel_dir: "examples/custom-ops/kernels/", +const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories { + kernel_glob: "examples/custom-ops/kernels/*.cu", rust_target: "examples/custom-ops/cuda_kernels.rs", include_dirs: &[], }]; -impl KernelDirectories { - fn maybe_build_ptx( - &self, - cu_file: &std::path::Path, - ptx_file: &std::path::Path, - compute_cap: usize, - ) -> Result<()> { - let should_compile = if ptx_file.exists() { - let ptx_modified = ptx_file.metadata()?.modified()?; - let cu_modified = cu_file.metadata()?.modified()?; - cu_modified.duration_since(ptx_modified).is_ok() - } else { - true - }; - if should_compile { - #[cfg(feature = "cuda")] - { - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - let mut command = std::process::Command::new("nvcc"); - let out_dir = ptx_file.parent().context("no parent for ptx file")?; - let include_dirs: Vec = - self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); - command - .arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("--ptx") - .args(["--default-stream", "per-thread"]) - .args(["--output-directory", out_dir.to_str().unwrap()]) - .arg(format!("-I/{}", self.kernel_dir)) - .args(include_dirs) - .arg(cu_file); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - let output = command - .spawn() - .context("failed spawning nvcc")? 
- .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - } - #[cfg(not(feature = "cuda"))] - std::fs::OpenOptions::new() - .create(true) - .write(true) - .open(ptx_file)?; - } - Ok(()) - } - fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> { - println!("cargo:rerun-if-changed={}", self.kernel_dir); - let kernel_dir = PathBuf::from(self.kernel_dir); - let out_dir = out_dir.join(self.kernel_dir); - if !out_dir.exists() { - std::fs::create_dir_all(&out_dir)?; - } - let mut cu_files = vec![]; - let mut cuh_files = vec![]; - for file in std::fs::read_dir(kernel_dir)?.flatten() { - let file = file.path(); - match file.extension().and_then(|v| v.to_str()) { - Some("cu") => cu_files.push(file), - Some("cuh") => cuh_files.push(file), - _ => {} - } - } - - let mut ptx_paths = vec![]; - for cu_file in cu_files.iter() { - let file_stem = cu_file - .file_stem() - .with_context(|| format!("no stem {cu_file:?}"))?; - let file_stem = file_stem.to_string_lossy().into_owned(); - let ptx_file = out_dir.join(&format!("{file_stem}.ptx")); - self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?; - ptx_paths.push(ptx_file); - } - - let regenerate_rs_file = true; - if regenerate_rs_file { - let mut file = std::fs::File::create(self.rust_target)?; - for ptx_path in ptx_paths { - let name = ptx_path - .file_stem() - .context("empty stem")? - .to_string_lossy(); - file.write_all(b"#[rustfmt::skip]\n")?; - let const_definition = format!( - r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#, - name.to_uppercase().replace('.', "_"), - self.kernel_dir, - ); - file.write_all(const_definition.as_bytes())?; - file.write_all(b"\n")?; - } - } - Ok(()) - } -} - fn main() -> Result<()> { println!("cargo:rerun-if-changed=build.rs"); - let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; - let out_dir = PathBuf::from(out_dir); #[cfg(feature = "cuda")] - set_cuda_include_dir()?; - #[cfg(feature = "cuda")] - let compute_cap = compute_cap()?; + { + for kdir in KERNEL_DIRS.iter() { + let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write(kdir.rust_target).unwrap() + } + } #[cfg(not(feature = "cuda"))] - let compute_cap = 0; - for d in DIRS { - d.process(&out_dir, compute_cap)? + { + for kdir in KERNEL_DIRS.iter() { + let _file = std::fs::File::create(kdir.rust_target)?; + } } Ok(()) } - -fn set_cuda_include_dir() -> Result<()> { - // NOTE: copied from cudarc build.rs. 
- let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - let roots = roots.into_iter().map(Into::::into); - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .context("cannot find include/cuda.h")?; - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - Ok(()) -} - -#[allow(unused)] -fn compute_cap() -> Result { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute cap from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::() - .context("Could not parse code")? - } else { - // Grab compute cap from nvidia-smi - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap.parse::() - .with_context(|| format!("cannot parse as int {cap}"))? - }; - - // Grab available GPU codes from nvcc and select the highest one - let max_nvcc_code = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::() { - codes.push(num); - } - } - } - codes.sort(); - if !codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}." - ); - } - *codes.last().unwrap() - }; - - // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc, - // then choose the highest gpu code in nvcc - if compute_cap > max_nvcc_code { - println!( - "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}." 
- ); - compute_cap = max_nvcc_code; - } - - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - compute_cap = compute_cap_str - .parse::() - .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?; - println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP"); - } - println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); - Ok(compute_cap) -} diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs index 0bee73aa..c00b601b 100644 --- a/candle-examples/examples/custom-ops/cuda_kernels.rs +++ b/candle-examples/examples/custom-ops/cuda_kernels.rs @@ -1,2 +1 @@ -#[rustfmt::skip] -pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx")); +pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx")); diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs index f2f534dc..30e413c1 100644 --- a/candle-examples/examples/custom-ops/main.rs +++ b/candle-examples/examples/custom-ops/main.rs @@ -6,7 +6,8 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; -#[allow(unused)] +#[rustfmt::skip] +#[cfg(feature = "cuda")] mod cuda_kernels; use clap::Parser; From 0eb90ed7831d451e2e420ecd158151b44dc5b2ba Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 7 Jan 2024 20:21:49 +0100 Subject: [PATCH 07/46] Simpler repro for the neon optimization issue + bugfix (#1544) * Simpler repro for the neon optimization issue. * Bugfix for q4k. * Improve the fix, share the dot-prod bit. * Clippy fixes. * Fix for q6k. * Also fix for q2k. * Use the new shared dotprod. * Add more testing. --- candle-core/src/quantized/neon.rs | 208 ++++++++------------------- candle-core/tests/quantized_tests.rs | 57 +++++--- 2 files changed, 97 insertions(+), 168 deletions(-) diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 3cb56229..c4d5d6f4 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -12,6 +12,14 @@ use core::arch::arm::*; #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; +#[inline(always)] +unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { + // TODO: dotprod + let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) +} + #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; @@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - // TODO: Support dotprod when it's available outside of nightly. 
- let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); - let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h)); - let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - + let pl0 = vdotq_s32(v0_0ls, v1_0l); + let ph0 = vdotq_s32(v0_0hs, v1_0h); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), @@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - // TODO dotprod once this is the intrinsics are. - let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); - let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); - let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + let p0 = vdotq_s32(x0_0, y0_0); + let p1 = vdotq_s32(x0_1, y0_1); sumv0 = vmlaq_n_f32( sumv0, @@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res for i in (0..QK_K).step_by(16) { let xs = vld1q_s8(xs.add(i)); let ys = vld1q_s8(ys.add(i)); - let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys)); - let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys)); - - let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up)); + let xy = vdotq_s32(xs, ys); sum_i = vaddq_s32(sum_i, xy) } sumf += vaddvq_s32(sum_i) as f32 * scale @@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); - // TODO: dotprod - - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)), - ); + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1; + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)), - ); + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1; + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); let q8bytes = vld1q_s8_x4(q8); @@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); - // TODO: dotprod case. 
- let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)), - ); + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1; + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)), - ); + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1; + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); } sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); @@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); - // TODO: dotprod - - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)), - ); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32; + let p0 = vdotq_s32(q5bytes_0, q8bytes.0); + let p1 = vdotq_s32(q5bytes_1, q8bytes.1); + sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; scales = scales.add(1); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)), - ); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32; + let p2 = vdotq_s32(q5bytes_2, q8bytes.2); + let p3 = vdotq_s32(q5bytes_3, q8bytes.3); + sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; scales = scales.add(1); } sumf += d * sumi as f32 - dmin * sumi_mins as f32; @@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res for j in 0..QK_K / 64 { let q4bits = vld1q_u8_x2(q4); q4 = q4.add(32); - // TODO: dotprod let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); let q4bytes = int8x16x2_t( vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), ); - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)), - ); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32; + let p0 = vdotq_s32(q4bytes.0, q8bytes.0); + let p1 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as 
i32; let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); @@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)), - ); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32; + let p2 = vdotq_s32(q4bytes.0, q8bytes.0); + let p3 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; } sumf += d * (sumi1 + sumi2) as f32; } @@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - // TODO: dotprod - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)), - vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)), - vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)), - vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)), - vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)), - ); - isum += vaddvq_s16(p0) as i32 * *scale as i32 - + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 - + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 - + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); let q3h_0 = vbicq_u8(m2, qhbits.0); @@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - // TODO: dotprod - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)), - vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)), - vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)), - vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)), - vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)), - ); - isum += vaddvq_s16(p0) as i32 * *scale as i32 - + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 - + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 - + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); if j == 0 { @@ -649,7 
+561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res let mut is = 0usize; // TODO: dotprod - for _j in 0..QK_K / 128 { let q2bits = vld1q_u8_x2(q2); q2 = q2.add(32); @@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale( q2bytes: int8x16x2_t, q8bytes: int8x16x2_t, ) -> i32 { - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)), - ); - vaddvq_s16(p1) as i32 * aux[is + index] as i32 - + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32 + let p1 = vdotq_s32(q2bytes.0, q8bytes.0); + let p2 = vdotq_s32(q2bytes.1, q8bytes.1); + vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 } diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 716cca8d..e7a2ea7f 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,4 +1,5 @@ use candle_core::{ + bail, quantized::{self, GgmlDType}, test_utils::to_vec2_round, Device, Module, Result, Tensor, @@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) { } } -/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 +/// Creates a vector similar to the ones used in GGML unit tests: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 fn create_ggml_like_vector(offset: f32) -> Vec { (0..GGML_TEST_SIZE) .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) @@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 { sum / a.len() as f32 } -/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 +/// Similar to the GGML quantization unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 fn ggml_quantization_error_test(max_error: f32) -> Result<()> { let src = create_ggml_like_vector(0.0); let mut dst = vec![0.0; GGML_TEST_SIZE]; let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; let error = calculate_rmse(src.as_slice(), dst.as_slice()); if error > max_error { - candle_core::bail!( + bail!( "Quantization error {} exceeds max error {}", error, max_error @@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5K => 0.000740, GgmlDType::Q6K => 0.000952, GgmlDType::Q4_0 => 0.001143, - GgmlDType::Q4_1 => 0.007784, + GgmlDType::Q4_1 => 0.008, GgmlDType::Q5_0 => 0.001353, - GgmlDType::Q5_1 => 0.001363, + GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, // Not from the ggml repo. 
GgmlDType::Q8K => 0.00065, - _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",), + _ => bail!("No GGML results for quantization type {dtype:?}",), }; Ok(err) } -/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 +/// Similar to the GGML matmul unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 fn ggml_matmul_error_test() -> Result<()> { let a = create_ggml_like_vector(0.0); let b = create_ggml_like_vector(1.0); + ggml_matmul_error_test_::(a.as_slice(), b.as_slice(), 1.0)?; + // Another example that is more likely to trigger the overflow reported in #1526 + let a = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::>(); + let b = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::>(); + ggml_matmul_error_test_::(a.as_slice(), b.as_slice(), 2.0)?; + Ok(()) +} + +fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Result<()> { let length = a.len(); let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; - T::from_float(&a, &mut a_quant)?; - T::VecDotType::from_float(&b, &mut b_quant)?; + T::from_float(a, &mut a_quant)?; + T::VecDotType::from_float(b, &mut b_quant)?; let result = T::vec_dot(length, &a_quant, &b_quant)?; let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?; - let reference_result = vec_dot_reference(&a, &b); + let reference_result = vec_dot_reference(a, b); if (result - result_unopt).abs() / length as f32 > 1e-6 { - candle_core::bail!( + bail!( "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}" ) } let error = (result - reference_result).abs() / length as f32; - let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; + let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { - candle_core::bail!( - "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}", - ); + bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); } // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML // => we use a slightly higher error threshold const ERROR_LENIENCY: f32 = 0.00001; if error - ERROR_LENIENCY > ggml_error { - candle_core::bail!( + bail!( "Dot product error {} exceeds ggml reference error {}", error, ggml_error @@ -543,6 +558,16 @@ fn ggml_matmul_error_test() -> Result<()> { Ok(()) } +#[test] +fn quantized_mm() -> Result<()> { + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + Ok(()) +} + /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result. fn get_random_tensors( m: usize, From 12b2a337f30f023af157b9ae560b53c3c5bd416c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 8 Jan 2024 09:20:48 +0100 Subject: [PATCH 08/46] Handle start-offset when loading a tensor from a pickle file. 
(#1546) --- candle-core/src/pickle.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 25640d1a..276b30e3 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -703,6 +703,7 @@ impl PthTensors { } pub fn get(&self, name: &str) -> Result> { + use std::io::Read; let tensor_info = match self.tensor_infos.get(name) { None => return Ok(None), Some(tensor_info) => tensor_info, @@ -712,14 +713,21 @@ impl PthTensors { let mut zip = zip::ZipArchive::new(zip_reader)?; let mut reader = zip.by_name(&tensor_info.path)?; - // Reading the data is a bit tricky as it can be strided, use an offset, etc. - // For now only support the basic case. - if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() { + // Reading the data is a bit tricky as it can be strided, for now only support the basic + // case. + if !tensor_info.layout.is_contiguous() { crate::bail!( "cannot retrieve non-contiguous tensors {:?}", tensor_info.layout ) } + let start_offset = tensor_info.layout.start_offset(); + if start_offset > 0 { + std::io::copy( + &mut reader.by_ref().take(start_offset as u64), + &mut std::io::sink(), + )?; + } let tensor = Tensor::from_reader( tensor_info.layout.shape().clone(), tensor_info.dtype, From 87efb5d8eb6a6c3f17acf326aadcb11ad6900306 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 9 Jan 2024 19:04:31 +0100 Subject: [PATCH 09/46] Updated feature separated benchmarks --- candle-core/Cargo.toml | 5 --- candle-core/benches/bench_main.rs | 2 +- candle-core/benches/benchmarks/mod.rs | 1 + .../benches/{ => benchmarks}/random.rs | 32 +++++++------------ 4 files changed, 14 insertions(+), 26 deletions(-) rename candle-core/benches/{ => benchmarks}/random.rs (53%) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 3fae7f07..afdb67cd 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,8 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false - -[[bench]] -name = "random" -harness = false - diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 4425f2fb..8913df4f 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches); +criterion_main!(benchmarks::matmul::benches, benchmarks::random::benches); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 1344770d..6bb37a70 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod matmul; +pub(crate) mod random; use candle_core::{Device, Result}; diff --git a/candle-core/benches/random.rs b/candle-core/benches/benchmarks/random.rs similarity index 53% rename from candle-core/benches/random.rs rename to candle-core/benches/benchmarks/random.rs index 781d8b39..e4a4a390 100644 --- a/candle-core/benches/random.rs +++ b/candle-core/benches/benchmarks/random.rs @@ -1,9 +1,10 @@ +use crate::benchmarks::{bench_name, device, BenchDevice}; use candle_core::{DType, Device, Tensor}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; fn rand_uniform(a: &Tensor) { - a.rand_like(0.0, 1.0).unwrap(); + 
a.rand_like(-1.0, 123.0).unwrap(); } fn rand_normal(a: &Tensor) { @@ -16,14 +17,13 @@ fn criterion_benchmark(c: &mut Criterion) { let rows = 2048; let cols = 2048; - let device = Device::new_metal(0).unwrap(); - let device2 = device.clone(); + let d = device().unwrap(); let dtype = DType::F32; - let tensor = Tensor::zeros((b, rows, cols), dtype, &device).unwrap(); + let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap(); - let flops = b * rows * cols; + let flops = b * rows * cols * dtype.size_in_bytes(); - let mut group = c.benchmark_group("metal_random_uniform"); + let mut group = c.benchmark_group(bench_name("random_uniform")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |benches| { benches.iter_custom(|iters| { @@ -31,19 +31,16 @@ fn criterion_benchmark(c: &mut Criterion) { for _i in 0..iters { rand_uniform(black_box(&tensor)); } - if let Device::Metal(device) = &device { - device.wait_until_completed().unwrap(); - } else { - panic!("Expected metal device"); - } + d.sync().unwrap(); start.elapsed() }) }); group.finish(); - let tensor = Tensor::zeros((b, rows, cols), dtype, &device2).unwrap(); + let d = device().unwrap(); + let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap(); - let mut group = c.benchmark_group("metal_random_normal"); + let mut group = c.benchmark_group(bench_name("random_normal")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |benches| { benches.iter_custom(|iters| { @@ -51,11 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) { for _i in 0..iters { rand_normal(black_box(&tensor)); } - if let Device::Metal(device) = &device2 { - device.wait_until_completed().unwrap(); - } else { - panic!("Expected metal device"); - } + d.sync().unwrap(); start.elapsed() }) }); @@ -63,4 +56,3 @@ fn criterion_benchmark(c: &mut Criterion) { } criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); From 53e4755015a07d967801935caaa71298a9ab20d3 Mon Sep 17 00:00:00 2001 From: darker Date: Wed, 10 Jan 2024 14:57:20 +0100 Subject: [PATCH 10/46] feat: add dependabot to the project (#1553) * feat: add dependabot to the project * feat: add let's accept patches/fix from other libs * Revert "feat: add let's accept patches/fix from other libs" This reverts commit d31a956f8108afb1b6ee6f35611feea399d63bdf. 
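As committed, the config below only watches the root Cargo workspace for
outdated crates. As a hypothetical sketch (not part of this patch), the same
file could also keep the GitHub Actions workflows current by appending a
second `updates` entry:

    - package-ecosystem: "github-actions"
      directory: "/"
      schedule:
        interval: "weekly"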
--- .github/dependabot.yml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..08d14cfc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + open-pull-requests-limit-per-dependency: 2 From edf3fcd1c471f054eeaef74ea66eeac5bf04c54e Mon Sep 17 00:00:00 2001 From: darker Date: Wed, 10 Jan 2024 15:12:46 +0100 Subject: [PATCH 11/46] fix: deprecated option field (open-pull-requests-limit-per-dependency) (#1554) --- .github/dependabot.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 08d14cfc..05bcdac6 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,4 +5,3 @@ updates: schedule: interval: "weekly" open-pull-requests-limit: 5 - open-pull-requests-limit-per-dependency: 2 From 2cc124799944032cafb4f43d110f3148721c1443 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 16:26:53 +0100 Subject: [PATCH 12/46] Update tokenizers requirement from 0.13.4 to 0.15.0 (#1555) Updates the requirements on [tokenizers](https://github.com/huggingface/tokenizers) to permit the latest version. - [Release notes](https://github.com/huggingface/tokenizers/releases) - [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md) - [Commits](https://github.com/huggingface/tokenizers/commits) --- updated-dependencies: - dependency-name: tokenizers dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 3d66a02f..19de9593 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,7 @@ serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" thiserror = "1" -tokenizers = { version = "0.13.4", default-features = false } +tokenizers = { version = "0.15.0", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" From 6e98cf2a925fd276efb1963df1f62381989d41aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 16:27:05 +0100 Subject: [PATCH 13/46] Update cudarc requirement from 0.9.14 to 0.10.0 (#1559) Updates the requirements on [cudarc](https://github.com/coreylowman/cudarc) to permit the latest version. - [Release notes](https://github.com/coreylowman/cudarc/releases) - [Commits](https://github.com/coreylowman/cudarc/compare/v0.9.14...v0.9.15) --- updated-dependencies: - dependency-name: cudarc dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 19de9593..7eda732c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ candle-onnx = { path = "./candle-onnx" } candle-transformers = { path = "./candle-transformers" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "0.9.14", features = ["f16"] } +cudarc = { version = "0.10.0", features = ["f16"] } gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } From 1f1179913adc4a0fdba36616bf576cfd2cc00deb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 16:27:20 +0100 Subject: [PATCH 14/46] Update gloo requirement from 0.8 to 0.11 (#1558) Updates the requirements on [gloo](https://github.com/rustwasm/gloo) to permit the latest version. - [Release notes](https://github.com/rustwasm/gloo/releases) - [Changelog](https://github.com/rustwasm/gloo/blob/master/CHANGELOG.md) - [Commits](https://github.com/rustwasm/gloo/commits) --- updated-dependencies: - dependency-name: gloo dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- candle-wasm-examples/bert/Cargo.toml | 2 +- candle-wasm-examples/llama2-c/Cargo.toml | 2 +- candle-wasm-examples/t5/Cargo.toml | 2 +- candle-wasm-examples/whisper/Cargo.toml | 2 +- candle-wasm-examples/yolo/Cargo.toml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml index 259a6102..51358e45 100644 --- a/candle-wasm-examples/bert/Cargo.toml +++ b/candle-wasm-examples/bert/Cargo.toml @@ -27,7 +27,7 @@ safetensors = { workspace = true } # Wasm specific crates. console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" serde-wasm-bindgen = "0.6.0" diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index ac89a558..d46cdafa 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -26,7 +26,7 @@ serde_json = { workspace = true } # Wasm specific crates. console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.37" diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml index 36cd9386..5f60d917 100644 --- a/candle-wasm-examples/t5/Cargo.toml +++ b/candle-wasm-examples/t5/Cargo.toml @@ -27,7 +27,7 @@ safetensors = { workspace = true } # Wasm specific crates. console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" serde-wasm-bindgen = "0.6.0" diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 6c6857e4..92e206b2 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -26,7 +26,7 @@ safetensors = { workspace = true } # Wasm specific crates. 
getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.37" diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index 0e5a91a8..ac76f9a7 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -26,7 +26,7 @@ safetensors = { workspace = true } # Wasm specific crates. console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.37" From a897fda74e372ff0e08c86a5468124b51f5941a7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 16:27:59 +0100 Subject: [PATCH 15/46] Update memmap2 requirement from 0.7.1 to 0.9.3 (#1556) Updates the requirements on [memmap2](https://github.com/RazrFalcon/memmap2-rs) to permit the latest version. - [Changelog](https://github.com/RazrFalcon/memmap2-rs/blob/master/CHANGELOG.md) - [Commits](https://github.com/RazrFalcon/memmap2-rs/compare/v0.7.1...v0.7.1) --- updated-dependencies: - dependency-name: memmap2 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7eda732c..2225c42e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ imageproc = { version = "0.23.0", default-features = false } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } libc = { version = "0.2.147" } log = "0.4" -memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] } +memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } num_cpus = "1.15.0" num-traits = "0.2.15" parquet = { version = "45.0.0" } From ae06cb74bb132913b4777cd119b915e665c013bb Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Wed, 10 Jan 2024 12:27:17 -0500 Subject: [PATCH 16/46] Add relu kernel for metal (#1488) * Add relu kernel for metal * Copy error messages proposed in #1491 * Revert non relu changes * Fix name changes * Fix the last of us (: * Fix copy and paste mistakes * Fix typo * Revert order changes * Revert order change * Add deleted functions back * Run rustfmt --- candle-core/src/metal_backend.rs | 4 ++++ candle-metal-kernels/src/lib.rs | 4 ++-- candle-metal-kernels/src/unary.metal | 8 ++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c1c4aa4b..5d72bd68 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -675,6 +675,7 @@ impl BackendStorage for MetalStorage { ("uround", DType::F32) => contiguous::round::FLOAT, ("urecip", DType::F32) => contiguous::recip::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("urelu", DType::F32) => contiguous::relu::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) => contiguous::sin::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF, @@ -691,6 +692,7 @@ impl BackendStorage for MetalStorage { ("uround", DType::F16) => contiguous::round::HALF, ("urecip", DType::F16) => contiguous::recip::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF, + ("urelu", DType::F16) => contiguous::relu::HALF, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") } @@ -721,6 +723,7 @@ 
impl BackendStorage for MetalStorage { ("uabs", DType::F32) => strided::abs::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT, + ("urelu", DType::F32) => strided::relu::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, ("ucos", DType::F16) => strided::cos::HALF, ("usin", DType::F16) => strided::sin::HALF, @@ -735,6 +738,7 @@ impl BackendStorage for MetalStorage { ("uabs", DType::F16) => strided::abs::HALF, ("uceil", DType::F16) => strided::ceil::HALF, ("ufloor", DType::F16) => strided::floor::HALF, + ("urelu", DType::F16) => strided::relu::HALF, ("uround", DType::F16) => strided::round::HALF, (name, dtype) => { crate::bail!("Metal strided unary {name} {dtype:?} not implemented") diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5d34f61a..c872dc60 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -174,8 +174,8 @@ macro_rules! ops{ pub mod unary { ops!( - cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh, - recip + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, + tanh, recip ); } pub mod binary { diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 7fbb613d..f95f6ba9 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -58,6 +58,12 @@ template METAL_FUNC T gelu(T x) { T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); return static_cast(0.5) * x * (static_cast(1.0) + T(tanh(beta))); } +template METAL_FUNC T relu(T in){ + if (in < 0) { + return 0; + } + return in; +} #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -110,6 +116,7 @@ UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) +UNARY_OP(relu) UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) @@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) +BFLOAT_UNARY_OP(relu) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif From d3bdd788cfdcf49b6ea539b77647b82a0b979db0 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:50:30 +0100 Subject: [PATCH 17/46] Use __HAVE_BFLOAT__ to check for bfloat support instead of metal version check (#1540) --- candle-metal-kernels/src/affine.metal | 2 +- candle-metal-kernels/src/binary.metal | 2 +- candle-metal-kernels/src/cast.metal | 2 +- candle-metal-kernels/src/indexing.metal | 2 +- candle-metal-kernels/src/reduce.metal | 2 +- candle-metal-kernels/src/unary.metal | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 4166d811..3d8e7f0d 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -117,7 +117,7 @@ ELU(elu_f32, float) ELU(elu_f16, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) AFFINE(affine_bf16, bfloat); POWF(powf_bf16, bfloat); ELU(elu_bf16, bfloat); diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cdc8fef8..eb560f16 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y) INT64_BINARY_OP_OUT(gt, x > y) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_BINARY_OP(x + y, add) 
BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index e9ab17b1..5aacac4a 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -58,7 +58,7 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 63357428..32f3f410 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -173,7 +173,7 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float) SCATTER_ADD_OP(sa_u32_f16, uint, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 83a56f0a..93dac662 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index f95f6ba9..dcf803d8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -127,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided) UNARY(id, int64_t, copy_i64, copy_i64_strided) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) BFLOAT_UNARY_OP(sqr) From 63944714f267bd3824c548ffcaaaef5e29c4066e Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 10 Jan 2024 22:36:27 +0200 Subject: [PATCH 18/46] Use candle_nn::embedding instead of local copies in a few models. 
(#1562) --- candle-transformers/src/models/bert.rs | 7 +------ candle-transformers/src/models/bigcode.rs | 7 +------ candle-transformers/src/models/falcon.rs | 7 +------ candle-transformers/src/models/llama.rs | 9 ++------- candle-transformers/src/models/whisper/model.rs | 7 +------ 5 files changed, 6 insertions(+), 31 deletions(-) diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 51c524f5..810f2803 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -1,6 +1,6 @@ use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; use serde::Deserialize; pub const DTYPE: DType = DType::F32; @@ -112,11 +112,6 @@ impl Config { } } -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - struct Dropout { #[allow(dead_code)] pr: f64, diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index c4a2d1db..e69f08c8 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result { let weight = vb.get((size2, size1), "weight")?; @@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result Result { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 6ede136a..ef5a92fc 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; const MAX_SEQ_LEN: usize = 5000; @@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { Ok(LayerNorm::new(weight, bias, eps)) } -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py #[derive(Debug)] pub struct Config { diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 7e8c8920..f003866a 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,6 +1,6 @@ use super::with_tracing::{linear_no_bias as linear, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; use serde::Deserialize; use 
std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -136,11 +136,6 @@ impl Cache { } } -fn embedding(cfg: &Config, vb: VarBuilder) -> Result { - let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; - Ok(Embedding::new(embeddings, cfg.hidden_size)) -} - struct RmsNorm { inner: candle_nn::RmsNorm, span: tracing::Span, @@ -409,7 +404,7 @@ impl Llama { } pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result { - let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 25454ba6..ea2a59b9 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -1,12 +1,7 @@ use super::Config; use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} +use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; fn conv1d( in_channels: usize, From 2480c5dbddec7cd086746df595be85fdf1407146 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Thu, 11 Jan 2024 08:07:40 +0200 Subject: [PATCH 19/46] Add RepVGG model. (#1561) * Add RepVGG model. * Add RepVGG README * Extract var to top level * Replace hashmap with a match * Add a variant for the model kind + avoid some unnecessary config cloning. --------- Co-authored-by: Laurent --- candle-examples/examples/repvgg/README.md | 20 ++ candle-examples/examples/repvgg/main.rs | 111 ++++++++ candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/repvgg.rs | 306 ++++++++++++++++++++++ 4 files changed, 438 insertions(+) create mode 100644 candle-examples/examples/repvgg/README.md create mode 100644 candle-examples/examples/repvgg/main.rs create mode 100644 candle-transformers/src/models/repvgg.rs diff --git a/candle-examples/examples/repvgg/README.md b/candle-examples/examples/repvgg/README.md new file mode 100644 index 00000000..2cb807c1 --- /dev/null +++ b/candle-examples/examples/repvgg/README.md @@ -0,0 +1,20 @@ +# candle-repvgg + +A candle implementation of inference using a pre-trained [repvgg](https://arxiv.org/abs/2101.03697). +This uses a classification head trained on the ImageNet dataset and returns the +probabilities for the top-5 classes. 
+ +## Running an example + +``` +$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg + +loaded image Tensor[dims 3, 224, 224; f32] +model built +mountain bike, all-terrain bike, off-roader: 61.70% +bicycle-built-for-two, tandem bicycle, tandem: 33.14% +unicycle, monocycle : 4.88% +crash helmet : 0.15% +moped : 0.04% + +``` diff --git a/candle-examples/examples/repvgg/main.rs b/candle-examples/examples/repvgg/main.rs new file mode 100644 index 00000000..0864c559 --- /dev/null +++ b/candle-examples/examples/repvgg/main.rs @@ -0,0 +1,111 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::repvgg; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + A0, + A1, + A2, + B0, + B1, + B2, + B3, + B1G4, + B2G4, + B3G4, +} + +impl Which { + fn model_filename(&self) -> String { + let name = match self { + Self::A0 => "a0", + Self::A1 => "a1", + Self::A2 => "a2", + Self::B0 => "b0", + Self::B1 => "b1", + Self::B2 => "b2", + Self::B3 => "b3", + Self::B1G4 => "b1g4", + Self::B2G4 => "b2g4", + Self::B3G4 => "b3g4", + }; + format!("timm/repvgg_{}.rvgg_in1k", name) + } + + fn config(&self) -> repvgg::Config { + match self { + Self::A0 => repvgg::Config::a0(), + Self::A1 => repvgg::Config::a1(), + Self::A2 => repvgg::Config::a2(), + Self::B0 => repvgg::Config::b0(), + Self::B1 => repvgg::Config::b1(), + Self::B2 => repvgg::Config::b2(), + Self::B3 => repvgg::Config::b3(), + Self::B1G4 => repvgg::Config::b1g4(), + Self::B2G4 => repvgg::Config::b2g4(), + Self::B3G4 => repvgg::Config::b3g4(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(value_enum, long, default_value_t=Which::A0)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image224(args.image)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let model_name = args.which.model_filename(); + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = repvgg::repvgg(&args.which.config(), 1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. 
* pr + ); + } + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 94a3bd5b..a60b5a06 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -26,6 +26,7 @@ pub mod quantized_mixformer; pub mod quantized_mpt; pub mod quantized_stable_lm; pub mod quantized_t5; +pub mod repvgg; pub mod resnet; pub mod segment_anything; pub mod stable_diffusion; diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs new file mode 100644 index 00000000..34016e5b --- /dev/null +++ b/candle-transformers/src/models/repvgg.rs @@ -0,0 +1,306 @@ +//! RepVGG inference implementation +//! +//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021 +//! https://arxiv.org/abs/2101.03697 + +use candle::{Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder, +}; + +const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512]; + +#[derive(Clone)] +pub struct Config { + a: f32, + b: f32, + groups: usize, + stages: [usize; 4], +} + +impl Config { + pub fn a0() -> Self { + Self { + a: 0.75, + b: 2.5, + groups: 1, + stages: [2, 4, 14, 1], + } + } + + pub fn a1() -> Self { + Self { + a: 1.0, + b: 2.5, + groups: 1, + stages: [2, 4, 14, 1], + } + } + + pub fn a2() -> Self { + Self { + a: 1.5, + b: 2.75, + groups: 1, + stages: [2, 4, 14, 1], + } + } + + pub fn b0() -> Self { + Self { + a: 1.0, + b: 2.5, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b1() -> Self { + Self { + a: 2.0, + b: 4.0, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b2() -> Self { + Self { + a: 2.5, + b: 5.0, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b3() -> Self { + Self { + a: 3.0, + b: 5.0, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b1g4() -> Self { + Self { + a: 2.0, + b: 4.0, + groups: 4, + stages: [4, 6, 16, 1], + } + } + + pub fn b2g4() -> Self { + Self { + a: 2.5, + b: 5.0, + groups: 4, + stages: [4, 6, 16, 1], + } + } + + pub fn b3g4() -> Self { + Self { + a: 3.0, + b: 5.0, + groups: 4, + stages: [4, 6, 16, 1], + } + } +} + +// fuses a convolutional kernel and a batchnorm layer into a convolutional layer +// based on the _fuse_bn_tensor method in timm +// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602 +fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> { + let (gamma, beta) = bn.weight_and_bias().unwrap(); + let mu = bn.running_mean(); + let sigma = (bn.running_var() + bn.eps())?.sqrt(); + let gps = (gamma / sigma)?; + let bias = (beta - mu * &gps)?; + let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?; + + Ok((weights, bias)) +} + +// A RepVGG layer has a different training time and inference time architecture. +// The latter is a simple and efficient equivalent transformation of the former +// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions +// along with identity branches and batchnorm layers are fused into a single 3x3 convolution. 
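For reference, the `fuse_conv_bn` helper above implements the standard conv/batchnorm folding identity: with batchnorm parameters gamma, beta, running mean mu, running variance var, and sigma = sqrt(var + eps), a convolution with weight W followed by batchnorm equals a single convolution with weight W * (gamma / sigma) and bias beta - mu * (gamma / sigma). A minimal scalar check of that identity (a standalone sketch, not candle API; the values are arbitrary):

```
// Scalar sanity check of the conv + batchnorm folding performed by fuse_conv_bn.
fn main() {
    let (gamma, beta, mu, var, eps) = (0.9f32, 0.1, 0.4, 0.25, 1e-5);
    let (w, x) = (2.0f32, 3.0);

    // Reference path: convolution output, then batchnorm.
    let sigma = (var + eps).sqrt();
    let reference = gamma * (w * x - mu) / sigma + beta;

    // Folded path: one affine transform with a rescaled weight and a bias.
    let w_fused = w * gamma / sigma;
    let b_fused = beta - mu * gamma / sigma;
    let fused = w_fused * x + b_fused;

    assert!((reference - fused).abs() < 1e-5);
}
```

In the tensor version, `gamma / sigma` is broadcast over the output-channel dimension of the 4-d convolution weight, which is what the `reshape(((), 1, 1, 1))` in `fuse_conv_bn` achieves.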
+fn repvgg_layer( + has_identity: bool, + dim: usize, + stride: usize, + in_channels: usize, + out_channels: usize, + groups: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + stride, + groups, + padding: 1, + ..Default::default() + }; + + // read and reparameterize the 1x1 conv and bn into w1 and b1 + // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543 + + let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?; + let conv1x1 = conv2d_no_bias( + in_channels, + out_channels, + 1, + conv2d_cfg, + vb.pp("conv_1x1.conv"), + )?; + + let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?; + + // resize to 3x3 + w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?; + w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?; + + // read and reparameterize the 3x3 conv and bn into w3 and b3 + let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?; + let conv3x3 = conv2d_no_bias( + in_channels, + out_channels, + 3, + conv2d_cfg, + vb.pp("conv_kxk.conv"), + )?; + + let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?; + + let mut w = (w1 + w3)?; + let mut b = (b1 + b3)?; + + // read and reparameterize the identity bn into wi and bi + if has_identity { + let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?; + + // create a 3x3 convolution equivalent to the identity branch + let mut weights: Vec = vec![0.0; conv3x3.weight().elem_count()]; + + // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620 + let in_dim = in_channels / groups; + for i in 0..in_channels { + weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0; + } + + let weights = &Tensor::from_vec(weights, w.shape(), w.device())?; + let (wi, bi) = fuse_conv_bn(weights, identity_bn)?; + + w = (w + wi)?; + b = (b + bi)?; + } + + // create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches + let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg); + + Ok(Func::new(move |xs| { + let xs = xs.apply(&reparam_conv)?.relu()?; + Ok(xs) + })) +} + +// Get the number of output channels per stage taking into account the multipliers +fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize { + let channels = CHANNELS_PER_STAGE[stage] as f32; + + match stage { + 0 => std::cmp::min(64, (channels * a) as usize), + 4 => (channels * b) as usize, + _ => (channels * a) as usize, + } +} + +// Each stage is made of layers. The first layer always downsamples with stride 2. +// All but the first layer have a residual connection. +// The G4 variants have a groupwise convolution instead of a dense one on odd layers +// counted across stage boundaries, so we keep track of which layer we are in the +// full model. +fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result> { + let nlayers = cfg.stages[idx - 1]; + let mut layers = Vec::with_capacity(nlayers); + let prev_layers: usize = cfg.stages[..idx - 1].iter().sum(); + let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1); + let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx); + + for layer_idx in 0..nlayers { + let (has_identity, stride, in_channels) = if layer_idx == 0 { + (false, 2, out_channels_prev) + } else { + (true, 1, out_channels) + }; + + let groups = if (prev_layers + layer_idx) % 2 == 1 { + cfg.groups + } else { + 1 + }; + + layers.push(repvgg_layer( + has_identity, + out_channels, + stride, + in_channels, + out_channels, + groups, + vb.pp(layer_idx), + )?) 
+ } + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + for layer in layers.iter() { + xs = xs.apply(layer)? + } + Ok(xs) + })) +} + +// Build a RepVGG model for a given configuration. +fn repvgg_model(config: &Config, nclasses: Option, vb: VarBuilder) -> Result> { + let cls = match nclasses { + None => None, + Some(nclasses) => { + let outputs = output_channels_per_stage(config.a, config.b, 4); + let linear = linear(outputs, nclasses, vb.pp("head.fc"))?; + Some(linear) + } + }; + + let stem_dim = output_channels_per_stage(config.a, config.b, 0); + let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?; + let vb = vb.pp("stages"); + let stage1 = repvgg_stage(config, 1, vb.pp(0))?; + let stage2 = repvgg_stage(config, 2, vb.pp(1))?; + let stage3 = repvgg_stage(config, 3, vb.pp(2))?; + let stage4 = repvgg_stage(config, 4, vb.pp(3))?; + + Ok(Func::new(move |xs| { + let xs = xs + .apply(&stem)? + .apply(&stage1)? + .apply(&stage2)? + .apply(&stage3)? + .apply(&stage4)? + .mean(D::Minus1)? + .mean(D::Minus1)?; + match &cls { + None => Ok(xs), + Some(cls) => xs.apply(cls), + } + })) +} + +pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result> { + repvgg_model(cfg, Some(nclasses), vb) +} + +pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result> { + repvgg_model(cfg, None, vb) +} From 0fc95c9f0c426db0f32f7e853035fd3e8415c311 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 11 Jan 2024 11:21:01 +0100 Subject: [PATCH 20/46] Add a dequantize command to tensor-tools. (#1565) * Add a dequantize command to tensor-tools. * Clippy fixes. --- candle-core/examples/tensor-tools.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index d06b30d1..337021aa 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -102,7 +102,7 @@ enum Command { }, Quantize { - /// The input file, in gguf format. + /// The input file(s), in safetensors format. in_file: Vec, /// The output file, in gguf format. @@ -117,6 +117,15 @@ enum Command { #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)] mode: QuantizationMode, }, + + Dequantize { + /// The input file, in gguf format. + in_file: std::path::PathBuf, + + /// The output file, in safetensors format. 
+ #[arg(long)] + out_file: std::path::PathBuf, + }, } #[derive(Parser, Debug, Clone)] @@ -285,6 +294,19 @@ fn run_quantize_safetensors( Ok(()) } +fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> { + let mut in_file = std::fs::File::open(in_file)?; + let content = gguf_file::Content::read(&mut in_file)?; + let mut tensors = std::collections::HashMap::new(); + for (tensor_name, _) in content.tensor_infos.iter() { + let tensor = content.tensor(&mut in_file, tensor_name)?; + let tensor = tensor.dequantize(&Device::Cpu)?; + tensors.insert(tensor_name.to_string(), tensor); + } + candle_core::safetensors::save(&tensors, out_file)?; + Ok(()) +} + fn run_quantize( in_files: &[std::path::PathBuf], out_file: std::path::PathBuf, @@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> { quantization, mode, } => run_quantize(&in_file, out_file, quantization, mode)?, + Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?, } Ok(()) } From 9f0c99f0c1020678a682480e5936757510b10cee Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 11 Jan 2024 15:35:38 +0100 Subject: [PATCH 21/46] Seperate benchmarks by enabled features (#1538) * Use cfg to seperate benchmark results based on features * Remove allow pragma * Avoid some unnecessary returns. * Improve benchmarks layout * Derive bench_name from actual device * Run CPU benchmarks even when GPU feature is enabled --------- Co-authored-by: Laurent --- candle-core/Cargo.toml | 2 +- candle-core/benches/bench_main.rs | 4 ++ .../benches/{ => benchmarks}/matmul.rs | 26 ++++---- candle-core/benches/benchmarks/mod.rs | 63 +++++++++++++++++++ 4 files changed, 82 insertions(+), 13 deletions(-) create mode 100644 candle-core/benches/bench_main.rs rename candle-core/benches/{ => benchmarks}/matmul.rs (56%) create mode 100644 candle-core/benches/benchmarks/mod.rs diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 97857a6b..d9fc7526 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -46,6 +46,6 @@ accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] -name = "matmul" +name = "bench_main" harness = false diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs new file mode 100644 index 00000000..4425f2fb --- /dev/null +++ b/candle-core/benches/bench_main.rs @@ -0,0 +1,4 @@ +mod benchmarks; + +use criterion::criterion_main; +criterion_main!(benchmarks::matmul::benches); diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/benchmarks/matmul.rs similarity index 56% rename from candle-core/benches/matmul.rs rename to candle-core/benches/benchmarks/matmul.rs index 83679771..9d67e642 100644 --- a/candle-core/benches/matmul.rs +++ b/candle-core/benches/benchmarks/matmul.rs @@ -1,25 +1,25 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; use candle_core::{DType, Device, Tensor}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; fn run(a: &Tensor, b: &Tensor) { a.matmul(&b.t().unwrap()).unwrap(); } -fn criterion_benchmark(c: &mut Criterion) { +fn run_bench(c: &mut Criterion, device: &Device) { let b = 1; let m = 1; let n = 2048; let k = 2048; - let device = Device::new_metal(0).unwrap(); let dtype = DType::F32; - let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); - let rhs = Tensor::zeros((b, n, k), 
dtype, &device).unwrap(); + let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap(); + let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap(); let flops = b * m * n * k; - let mut group = c.benchmark_group("matmul_metal"); + let mut group = c.benchmark_group(device.bench_name("matmul")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |b| { b.iter_custom(|iters| { @@ -27,16 +27,18 @@ fn criterion_benchmark(c: &mut Criterion) { for _i in 0..iters { run(black_box(&lhs), black_box(&rhs)); } - if let Device::Metal(device) = &device { - device.wait_until_completed().unwrap(); - } else { - panic!("Expected metal device"); - } + device.sync().unwrap(); start.elapsed() }) }); group.finish(); } +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench(c, &device); + } +} + criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs new file mode 100644 index 00000000..295bbabd --- /dev/null +++ b/candle-core/benches/benchmarks/mod.rs @@ -0,0 +1,63 @@ +pub(crate) mod matmul; + +use candle_core::{Device, Result}; + +pub(crate) trait BenchDevice { + fn sync(&self) -> Result<()>; + + fn bench_name>(&self, name: S) -> String; +} + +impl BenchDevice for Device { + fn sync(&self) -> Result<()> { + match self { + Device::Cpu => Ok(()), + Device::Cuda(device) => { + #[cfg(feature = "cuda")] + return Ok(device.synchronize()?); + #[cfg(not(feature = "cuda"))] + panic!("Cuda device without cuda feature enabled: {:?}", device) + } + Device::Metal(device) => { + #[cfg(feature = "metal")] + return Ok(device.wait_until_completed()?); + #[cfg(not(feature = "metal"))] + panic!("Metal device without metal feature enabled: {:?}", device) + } + } + } + + fn bench_name>(&self, name: S) -> String { + match self { + Device::Cpu => { + let cpu_type = if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + }; + format!("{}_{}", cpu_type, name.into()) + } + Device::Cuda(_) => format!("cuda_{}", name.into()), + Device::Metal(_) => format!("metal_{}", name.into()), + } + } +} + +struct BenchDeviceHandler { + devices: Vec, +} + +impl BenchDeviceHandler { + pub fn new() -> Result { + let mut devices = Vec::new(); + if cfg!(feature = "metal") { + devices.push(Device::new_metal(0)?); + } else if cfg!(feature = "cuda") { + devices.push(Device::new_cuda(0)?); + } + devices.push(Device::Cpu); + Ok(Self { devices }) + } +} From 402349d120716b49459dafdc79b906b54a5579ea Mon Sep 17 00:00:00 2001 From: Kyle McCarthy Date: Thu, 11 Jan 2024 08:49:13 -0600 Subject: [PATCH 22/46] feat(bf16): add cast support + tests for cast + bin ops (#1524) --- candle-core/Cargo.toml | 1 - candle-core/src/metal_backend.rs | 54 ++++++++- candle-metal-kernels/Cargo.toml | 9 +- candle-metal-kernels/src/cast.metal | 40 +++++- candle-metal-kernels/src/indexing.metal | 3 + candle-metal-kernels/src/tests.rs | 154 ++++++++++++++++++++++-- 6 files changed, 243 insertions(+), 18 deletions(-) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index d9fc7526..92a04917 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false - diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 5d72bd68..aa2898ff 100644 --- 
a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -590,14 +590,26 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::I64) => "cast_u8_i64", + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", - (DType::F16, DType::F32) => "cast_f16_f32", - (DType::I64, DType::F32) => "cast_i64_f32", (DType::F32, DType::BF16) => "cast_f32_bf16", + + (DType::I64, DType::F32) => "cast_i64_f32", + + (DType::F16, DType::BF16) => "cast_f16_bf16", + (DType::F16, DType::F32) => "cast_f16_f32", + + (DType::BF16, DType::U8) => "cast_bf16_u8", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") } @@ -1131,8 +1143,12 @@ impl BackendStorage for MetalStorage { let device = self.device(); let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", + (DType::U32, DType::BF16) => "is_u32_bf16", + (left, right) => { crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") } @@ -1322,6 +1338,7 @@ impl MetalStorage { ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), @@ -1332,6 +1349,18 @@ impl MetalStorage { ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + + ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), + ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), + ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), + ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1342,6 +1371,7 @@ impl MetalStorage { ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + ("add", DType::U32) => (contiguous::add::U32, self.dtype), ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), @@ -1352,6 +1382,7 @@ impl MetalStorage { ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), 
("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + ("add", DType::U8) => (contiguous::add::U8, self.dtype), ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), @@ -1362,6 +1393,7 @@ impl MetalStorage { ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") } @@ -1395,6 +1427,7 @@ impl MetalStorage { ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), @@ -1407,6 +1440,20 @@ impl MetalStorage { ("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + + ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype), + ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype), + ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype), + ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype), + ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype), + ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype), + ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1419,6 +1466,7 @@ impl MetalStorage { ("lt", DType::I64) => (strided::lt::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8), + ("badd", DType::U32) => (strided::add::U32, self.dtype), ("bsub", DType::U32) => (strided::sub::U32, self.dtype), ("bmul", DType::U32) => (strided::mul::U32, self.dtype), @@ -1431,6 +1479,7 @@ impl MetalStorage { ("lt", DType::U32) => (strided::lt::U32, DType::U8), ("ge", DType::U32) => (strided::ge::U32, DType::U8), ("gt", DType::U32) => (strided::gt::U32, DType::U8), + ("badd", DType::U8) => (strided::add::U8, self.dtype), ("bsub", DType::U8) => (strided::sub::U8, self.dtype), ("bmul", DType::U8) => (strided::mul::U8, self.dtype), @@ -1443,6 +1492,7 @@ impl MetalStorage { ("lt", DType::U8) => (strided::lt::U8, DType::U8), ("ge", DType::U8) => (strided::ge::U8, DType::U8), ("gt", DType::U8) => (strided::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal strided binary {name} {dtype:?} not implemented") } diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 441d2e88..187cb4fd 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"] categories = ["science"] license = "MIT OR Apache-2.0" + [dependencies] -metal = { version = "0.27.0", features = ["mps"]} +metal = { version = "0.27.0", features = ["mps"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" 
[dev-dependencies] -half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +half = { version = "2.3.1", features = [ + "num-traits", + "use-intrinsics", + "rand_distr", +] } rand = "0.8.5" diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 5aacac4a..e08931cf 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -28,7 +28,7 @@ kernel void FN_NAME( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[tid]); \ + output[tid] = static_cast(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ + output[tid] = static_cast(input[get_strided_index(tid, num_dims, dims, strides)]); \ +} \ + +#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(static_cast(input[tid])); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(static_cast(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) @@ -59,6 +86,15 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif #if defined(__HAVE_BFLOAT__) +#if __METAL_VERSION__ >= 310 +CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) + +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) + +CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) +CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 32f3f410..2a57bdbb 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -174,6 +174,9 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half) #if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) + INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c955abca..87f8ac45 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Device, MTLResourceOptions}; +use metal::{Buffer, Device, MTLResourceOptions}; fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; @@ -248,6 +248,34 @@ fn binary_add_f32() { assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); } +#[test] +fn binary_ops_bf16() { + let lhs: Vec = [1.1f32, 2.2, 
3.3].into_iter().map(bf16::from_f32).collect(); + let rhs: Vec = [4.2f32, 5.5f32, 6.91f32] + .into_iter() + .map(bf16::from_f32) + .collect(); + + macro_rules! binary_op { + ($opname:ident, $opexpr:expr) => {{ + let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT); + let expected: Vec = lhs + .iter() + .zip(rhs.iter()) + .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y)) + .collect(); + assert_eq!(results, expected); + }}; + } + + binary_op!(add, |x, y| x + y); + binary_op!(sub, |x, y| x - y); + binary_op!(mul, |x, y| x * y); + binary_op!(div, |x, y| x / y); + binary_op!(min, |x: bf16, y| x.min(y)); + binary_op!(max, |x: bf16, y| x.max(y)); +} + fn cast(v: &[T], name: &'static str) -> Vec { let device = device(); let fence = device.new_fence(); @@ -296,6 +324,89 @@ fn cast_u32_f32() { assert_eq!(results, vec![1.0f32; 10_000]); } +#[test] +fn it_cast_bf16_u32() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_u32"); + let expected: Vec = (1..=3).map(|v| v as u32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f32() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_f32"); + let expected: Vec = (1..=3).map(|v| v as f32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u8_bf16() { + let input: Vec = (1..=3).map(|v| v as u8).collect(); + + let output: Vec = cast(&input, "cast_u8_bf16"); + let expected: Vec = input + .iter() + .map(|v| bf16::from_f32(*v as f32)) + .collect::>(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u32_bf16() { + let input: Vec = (1..=3).map(|v| v as u32).collect(); + + let output: Vec = cast(&input, "cast_u32_bf16"); + let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f32_bf16() { + let input: Vec = (1..=3).map(|v| v as f32).collect(); + + let output: Vec = cast(&input, "cast_f32_bf16"); + let expected: Vec = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_u8() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_u8"); + let expected: Vec = input.iter().map(|v| v.to_f32() as u8).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f16() { + let input: Vec = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_bf16_f16"); + let expected: Vec = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f16_bf16() { + let input: Vec = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); + + let output: Vec = cast(&input, "cast_f16_bf16"); + let expected: Vec = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { let device = device(); let fence = device.new_fence(); @@ -396,14 +507,14 @@ fn index_select() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; let ids = [0u32, 1, 0]; let dim = 0; - let result = 
run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] @@ -419,20 +530,46 @@ fn index_select_f16() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] ); } +#[test] +fn index_select_is_u32_bf16() { + let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_is_u8_bf16() { + let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u8, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + #[test] fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -444,6 +581,7 @@ fn run_index_select( shape: &[usize], ids: &[I], dim: usize, + name: &'static str, ) -> Vec { let device = Device::system_default().expect("no device found"); @@ -457,12 +595,6 @@ fn run_index_select( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let name = match core::mem::size_of::() { - 4 => "is_u32_f32", - 2 => "is_u32_f16", - _ => unimplemented!(), - }; - let fence = device.new_fence(); let kernels = Kernels::new(fence); call_index_select( From 1327419776c244867a101b3ff1dc0b9247ed0650 Mon Sep 17 00:00:00 2001 From: Baye Dieng Date: Thu, 11 Jan 2024 17:14:12 +0000 Subject: [PATCH 23/46] close ifdef --- candle-metal-kernels/src/cast.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index e08931cf..ac40bc16 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -89,7 +89,6 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #if __METAL_VERSION__ >= 310 CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) - CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) @@ -98,3 +97,4 @@ CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) #endif +#endif \ No newline at end of file From 85e568027731e58b72fb2798c525a5d8aff65eb8 Mon Sep 17 00:00:00 2001 From: Baye Dieng Date: Thu, 11 Jan 2024 21:02:03 +0000 Subject: [PATCH 24/46] remove metal version check --- 
candle-metal-kernels/src/cast.metal | 2 -- 1 file changed, 2 deletions(-) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index ac40bc16..9aead139 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -86,7 +86,6 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif #if defined(__HAVE_BFLOAT__) -#if __METAL_VERSION__ >= 310 CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) @@ -96,5 +95,4 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) -#endif #endif \ No newline at end of file From 41915184bb3e530cc8184fdd8841c66df9285684 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 11 Jan 2024 23:15:11 +0100 Subject: [PATCH 25/46] Bugfix for dequantizing q5k layers. (#1569) --- candle-core/src/quantized/k_quants.rs | 8 ++++---- candle-core/tests/quantized_tests.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index d16289e6..6210ac1e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K { let d2 = d * sc as f32; let m2 = min * m as f32; for (ql, qh) in ql.iter().zip(qh) { - let to_add = if qh & u1 != 0 { 16 } else { 1 }; - y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1; + let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1; ys_index += 1; } for (ql, qh) in ql.iter().zip(qh) { - let to_add = if qh & u2 != 0 { 16 } else { 1 }; - y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2; + let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2; ys_index += 1; } is += 2; diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index e7a2ea7f..d31e77a7 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -407,7 +407,7 @@ fn quantize_q5k() -> Result<()> { let dst = round_vector(&dst); assert_eq!( [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], - [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499] + [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] ); let (src_big, mut dst_big) = get_test_vector(128.0, 1024); From e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 12 Jan 2024 07:26:42 +0100 Subject: [PATCH 26/46] fmt --- candle-metal-kernels/src/tests.rs | 38 +++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b15505f7..775ee0fa 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -975,7 +975,6 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: .unwrap(); } - command_buffer.commit(); command_buffer.wait_until_completed(); @@ -984,7 +983,6 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: #[test] fn random() { - fn calc_mean(data: &[f32]) -> f32 { let sum = data.iter().sum::() as f32; let count = 
data.len(); @@ -997,10 +995,14 @@ fn random() { let count = data.len(); assert!(count > 0); - let variance = data.iter().map(|value| { - let diff = mean - (*value as f32); - diff * diff - }).sum::() / count as f32; + let variance = data + .iter() + .map(|value| { + let diff = mean - (*value as f32); + diff * diff + }) + .sum::() + / count as f32; variance.sqrt() } @@ -1017,11 +1019,29 @@ fn random() { macro_rules! validate_random { ($type:ty) => { - let results: Vec = run_random::<$type>(concat!("rand_uniform_", stringify!($type)), seed, length, min, max).into_iter().map(f32::from).collect(); + let results: Vec = run_random::<$type>( + concat!("rand_uniform_", stringify!($type)), + seed, + length, + min, + max, + ) + .into_iter() + .map(f32::from) + .collect(); results.iter().for_each(|v| assert!(*v >= min && *v <= max)); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); - let results: Vec = run_random::<$type>(concat!("rand_normal_", stringify!($type)), seed, length, mean, stddev).into_iter().map(f32::from).collect(); + let results: Vec = run_random::<$type>( + concat!("rand_normal_", stringify!($type)), + seed, + length, + mean, + stddev, + ) + .into_iter() + .map(f32::from) + .collect(); assert!((calc_mean(&results) - mean).abs() < mean / 10.0); assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); }; @@ -1030,4 +1050,4 @@ fn random() { validate_random!(f32); validate_random!(f16); validate_random!(bf16); -} \ No newline at end of file +} From 6242276c0970db6e5805feed4c2ef3b0bf2ba413 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 12 Jan 2024 09:19:30 +0100 Subject: [PATCH 27/46] Pin the revision used for phi-v2 + make it the default. (#1572) * Pin the revision used for phi-v2 + make it the default. * Tweak the custom-ops build. --- candle-examples/build.rs | 6 ------ candle-examples/examples/phi/main.rs | 7 +++---- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/candle-examples/build.rs b/candle-examples/build.rs index ba40aeb4..33497714 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -27,11 +27,5 @@ fn main() -> Result<()> { bindings.write(kdir.rust_target).unwrap() } } - #[cfg(not(feature = "cuda"))] - { - for kdir in KERNEL_DIRS.iter() { - let _file = std::fs::File::create(kdir.rust_target)?; - } - } Ok(()) } diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index c529867b..c5c7de28 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -169,7 +169,7 @@ struct Args { #[arg(long)] model_id: Option, - #[arg(long, default_value = "1.5")] + #[arg(long, default_value = "2")] model: WhichModel, #[arg(long)] @@ -247,9 +247,8 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { - "main".to_string() - } + WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), + WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(), } } } From 8e06bfb4fd33f1229a03abee20cc1c07198408b5 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 12 Jan 2024 09:59:29 +0100 Subject: [PATCH 28/46] Mention VGG in the readme. 
(#1573) --- README.md | 5 ++++- candle-examples/examples/repvgg/README.md | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 93cbccc4..c4f27548 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,9 @@ We also provide a some command line based examples using state of the art models - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained using self-supervision (can be used for imagenet classification, depth evaluation, segmentation). +- [VGG](./candle-examples/examples/vgg/), + [RepVGG](./candle-examples/examples/repvgg): computer vision models. +- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to generate captions for an image. - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation @@ -204,7 +207,7 @@ If you have an addition to this list, please submit a pull request. - Image to text. - BLIP. - Computer Vision Models. - - DINOv2, ConvMixer, EfficientNet, ResNet, ViT. + - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG. - yolo-v3, yolo-v8. - Segment-Anything Model (SAM). - File formats: load models from safetensors, npz, ggml, or PyTorch files. diff --git a/candle-examples/examples/repvgg/README.md b/candle-examples/examples/repvgg/README.md index 2cb807c1..d24bcd6d 100644 --- a/candle-examples/examples/repvgg/README.md +++ b/candle-examples/examples/repvgg/README.md @@ -1,7 +1,9 @@ # candle-repvgg -A candle implementation of inference using a pre-trained [repvgg](https://arxiv.org/abs/2101.03697). -This uses a classification head trained on the ImageNet dataset and returns the +[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697). + +This candle implementation uses a pre-trained RepVGG network for inference. The +classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes. ## Running an example From e90bcdcc7c51dd85037055b59f22568100d801f0 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 12 Jan 2024 11:18:11 +0100 Subject: [PATCH 29/46] Metal: f16 and bf16 where_cond + benchmark (#1545) * Use cfg to seperate benchmark results based on features * Add metal where_cond for f16 and bf16. Add benchmark * Remove allow pragma * Avoid some unnecessary returns. 
* Improve benchmarks layout * Updated feature separated benchmarks --------- Co-authored-by: Laurent --- candle-core/benches/bench_main.rs | 2 +- candle-core/benches/benchmarks/mod.rs | 1 + candle-core/benches/benchmarks/where_cond.rs | 64 ++++++++++++++++++ candle-core/src/metal_backend.rs | 1 + .../examples/custom-ops/cuda_kernels.rs | 1 - candle-metal-kernels/src/ternary.metal | 66 ++++++++++++------- 6 files changed, 110 insertions(+), 25 deletions(-) create mode 100644 candle-core/benches/benchmarks/where_cond.rs diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 4425f2fb..92c33a86 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches); +criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 295bbabd..4e73ebb6 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod matmul; +pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs new file mode 100644 index 00000000..c517dcf5 --- /dev/null +++ b/candle-core/benches/benchmarks/where_cond.rs @@ -0,0 +1,64 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor, c: &Tensor) { + a.where_cond(b, c).unwrap(); +} + +const fn create_cond_arr() -> [u8; N] { + let mut arr = [0u8; N]; + let mut i = 0; + while i < N { + arr[i] = (i % 2) as u8; + i += 1; + } + arr +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; +const SIZE: usize = B * M * K; + +const DATA: [u8; SIZE] = create_cond_arr::(); + +fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap(); + let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap(); + let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap(); + + let elements = B * M * K; + // E.g. 
2 f32 tensors + 1 u8 tensor + let flops = (2 * elements * dtype.size_in_bytes()) + elements; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(&tensor), + black_box(&on_true), + black_box(&on_false), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32"); + run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16"); + run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index aa2898ff..38f909c8 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -806,6 +806,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs index c00b601b..e69de29b 100644 --- a/candle-examples/examples/custom-ops/cuda_kernels.rs +++ b/candle-examples/examples/custom-ops/cuda_kernels.rs @@ -1 +0,0 @@ -pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx")); diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 40b4bcf4..7b3b8ca9 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index( return strided_i; } +template +METAL_FUNC void where_cond( + constant size_t &numel, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t *strides_t, + constant size_t *strides_f, + device const ID *ids, + device const T *t, + device const T *f, + device T *out, + uint i [[ thread_position_in_grid ]] +) { + if (i >= numel){ + return; + } + uint strided_i = get_strided_index(i, num_dims, dims, strides); + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); + out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; +} -#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ -kernel void FN_NAME( \ - constant size_t &numel, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t *strides_t, \ - constant size_t *strides_f, \ - device const ID_TYPENAME *ids, \ - device const TYPENAME *t, \ - device const TYPENAME *f, \ - device TYPENAME *out ,\ - uint i [[ thread_position_in_grid ]] \ -) { \ - if (i >= numel){ \ - return; \ - } \ - uint strided_i = get_strided_index(i, num_dims, dims, strides); \ - uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ - uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ - out[i] = ids[strided_i] ? 
t[strided_i_t] : f[strided_i_f]; \ -} \ +#define WHERE_OP(T, ID, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID *ids, \ + device const T *t, \ + device const T *f, \ + device T *out, \ + uint i [[ thread_position_in_grid ]] \ +) { \ + where_cond(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \ +} \ // WHERE_OP(float, int64_t, where_i64_f32) // WHERE_OP(double, int64_t, where_i64_f64) @@ -54,10 +70,14 @@ kernel void FN_NAME( \ // WHERE_OP(int64_t, uint32_t, where_u32_i64) WHERE_OP(float, uint8_t, where_u8_f32) -// WHERE_OP(double, uint8_t, where_u8_f64) +WHERE_OP(half, uint8_t, where_u8_f16) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) #if __METAL_VERSION__ >= 220 WHERE_OP(int64_t, uint8_t, where_u8_i64) #endif + +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, uint8_t, where_u8_bf16) +#endif \ No newline at end of file From a3d92ab226ffc33743f4388a814d7dfe7fbe2809 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 12 Jan 2024 11:19:49 +0100 Subject: [PATCH 30/46] Metal: Activate bfloat affine and add benchmark (#1543) * Use cfg to seperate benchmark results based on features * Add bfloat affine and benchmarks * Fix flops calculation * Remove allow pragma * Avoid some unnecessary returns. * Improve benchmarks layout --------- Co-authored-by: Laurent Co-authored-by: Nicolas Patry --- candle-core/benches/bench_main.rs | 2 +- candle-core/benches/benchmarks/affine.rs | 43 ++++++++++++++++++++++++ candle-core/benches/benchmarks/mod.rs | 1 + candle-core/src/metal_backend.rs | 2 ++ candle-metal-kernels/src/affine.metal | 14 ++++---- 5 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 candle-core/benches/benchmarks/affine.rs diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 92c33a86..4e508a39 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches); +criterion_main!(benchmarks::matmul::benches, benchmarks::affine::benches, benchmarks::where_cond::benches); \ No newline at end of file diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs new file mode 100644 index 00000000..eded9f57 --- /dev/null +++ b/candle-core/benches/benchmarks/affine.rs @@ -0,0 +1,43 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor) { + a.affine(12.34, 56.78).unwrap(); +} + +fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + + let flops = b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut 
Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_affine_benchmark(c, &device, DType::F32, "affine_f32"); + run_affine_benchmark(c, &device, DType::F16, "affine_f16"); + run_affine_benchmark(c, &device, DType::BF16, "affine_bf16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 4e73ebb6..7dacff5e 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod affine; pub(crate) mod matmul; pub(crate) mod where_cond; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 38f909c8..5269a899 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -353,6 +353,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", + DType::BF16 => "affine_bf16", dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine( @@ -371,6 +372,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32_strided", DType::F16 => "affine_f16_strided", + DType::BF16 => "affine_bf16_strided", dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine_strided( diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 3d8e7f0d..a4484998 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index( using namespace metal; -#define AFFINE(FN_NAME, TYPENAME) \ +#define AFFINE(FN_NAME, T) \ kernel void FN_NAME( \ constant size_t &dim, \ constant float &mul, \ constant float &add, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ + device const T *input, \ + device T *output, \ uint id [[ thread_position_in_grid ]] \ ) { \ if (id >= dim) { \ return; \ } \ - output[id] = TYPENAME(float(input[id]) * mul + add); \ + output[id] = T(fma(float(input[id]), mul, add)); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ @@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \ constant size_t *strides, \ constant float &mul, \ constant float &add, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ + device const T *input, \ + device T *output, \ uint id [[ thread_position_in_grid ]] \ ) { \ if (id >= dim) { \ return; \ } \ - output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ + output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \ } #define POWF(FN_NAME, TYPENAME) \ From bafe95b660048999a3bb000b3509d04fb1bb1789 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 12 Jan 2024 14:23:17 +0100 Subject: [PATCH 31/46] Fix format. 
(#1576) --- candle-core/benches/bench_main.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 4e508a39..24dca7d4 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,8 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches, benchmarks::affine::benches, benchmarks::where_cond::benches); \ No newline at end of file +criterion_main!( + benchmarks::matmul::benches, + benchmarks::affine::benches, + benchmarks::where_cond::benches +); From a46864bd5650c4707753f3d95d7b4ff6b0905995 Mon Sep 17 00:00:00 2001 From: SebastianRueClausen <51479502+SebastianRueClausen@users.noreply.github.com> Date: Fri, 12 Jan 2024 17:47:07 +0100 Subject: [PATCH 32/46] Fix "Minimal Mamba" link in README. (#1577) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c4f27548..14172742 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ We also provide a some command line based examples using state of the art models - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. -- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal +- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal implementation of the Mamba state space model. - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28. From 539ead927a12a485637f7f04f8212cfdabe00fa4 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 13 Jan 2024 17:38:27 +0100 Subject: [PATCH 33/46] Update the Phi model to use the updated architecture. (#1580) * Update the Phi model to use the updated architecture. * Add more of the phi model. * Repeat KV + caching. * Apply the rotary embeddings. * Add support for the new phi model in the phi example. * Fix a couple glitches. * Fix a couple more glitches. 
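
One detail in the new phi.rs below is easy to miss: phi-2 applies rotary embeddings to only part of each head. A worked example of the split (the concrete numbers are assumed from the upstream phi-2 config for illustration, not taken from this diff):

    fn rotary_split_example() {
        // Mirrors the computation in RotaryEmbedding::new below.
        let hidden_size = 2560usize; // assumed phi-2 config value
        let num_attention_heads = 32usize; // assumed phi-2 config value
        let partial_rotary_factor = 0.4f64; // assumed phi-2 config value
        let head_dim = hidden_size / num_attention_heads; // 80
        let rot_dim = (partial_rotary_factor * head_dim as f64) as usize; // 32
        assert_eq!((head_dim, rot_dim), (80, 32));
        // Only the first 32 of the 80 features per head get rotated; the other
        // 48 pass through unchanged (the xs_pass slice in apply_rotary_emb).
    }
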
--- candle-examples/examples/phi/main.rs | 46 +++- candle-nn/src/activation.rs | 1 + candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/phi.rs | 365 ++++++++++++++++++++++++++ 4 files changed, 402 insertions(+), 11 deletions(-) create mode 100644 candle-transformers/src/models/phi.rs diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index c5c7de28..ea99c706 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -8,6 +8,7 @@ use anyhow::{Error as E, Result}; use clap::{Parser, ValueEnum}; use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer}; +use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi}; use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer; use candle::{DType, Device, Tensor}; @@ -18,6 +19,7 @@ use tokenizers::Tokenizer; enum Model { MixFormer(MixFormer), + Phi(Phi), Quantized(QMixFormer), } @@ -84,6 +86,7 @@ impl TextGeneration { let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = match &mut self.model { Model::MixFormer(m) => m.forward(&input)?, + Model::Phi(m) => m.forward(&input)?, Model::Quantized(m) => m.forward(&input)?, }; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; @@ -117,7 +120,7 @@ impl TextGeneration { } } -#[derive(Clone, Copy, Debug, ValueEnum)] +#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)] enum WhichModel { #[value(name = "1")] V1, @@ -125,6 +128,9 @@ enum WhichModel { V1_5, #[value(name = "2")] V2, + // TODO: Make this the default once it has been battle tested. + #[value(name = "2-new")] + V2New, PuffinPhiV2, PhiHermes, } @@ -230,7 +236,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), - WhichModel::V2 => "microsoft/phi-2".to_string(), + WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -248,7 +254,9 @@ fn main() -> Result<()> { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), - WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(), + WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + "main".to_string() + } } } } @@ -257,7 +265,9 @@ fn main() -> Result<()> { let tokenizer_filename = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => match args.model { - WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?, + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => { + repo.get("tokenizer.json")? + } WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { repo.get("tokenizer-puffin-phi-v2.json")? 
} @@ -270,14 +280,14 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], - WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?], + WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } } else { match args.model { WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], - WhichModel::V2 => candle_examples::hub_load_safetensors( + WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors( &repo, "model.safetensors.index.json", )?, @@ -291,25 +301,35 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let config = match args.model { + let config = || match args.model { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), - WhichModel::V2 => Config::v2(), + WhichModel::V2 | WhichModel::V2New => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; - let (model, device) = if args.quantized { + let (model, device) = if args.model == WhichModel::V2New { + let device = candle_examples::device(args.cpu)?; + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: PhiConfig = serde_json::from_str(&config)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let phi = Phi::new(&config, vb)?; + (Model::Phi(phi), device) + } else if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; + let config = config(); let model = match args.model { - WhichModel::V2 => QMixFormer::new_v2(&config, vb)?, + WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; + let config = config(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let model = match args.model { - WhichModel::V2 => MixFormer::new_v2(&config, vb)?, + WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?, _ => MixFormer::new(&config, vb)?, }; (Model::MixFormer(model), device) @@ -392,6 +412,10 @@ fn mmlu>( m.clear_kv_cache(); m.forward(&input)? } + Model::Phi(m) => { + m.clear_kv_cache(); + m.forward(&input)? + } Model::Quantized(m) => { m.clear_kv_cache(); m.forward(&input)? 
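
The serde alias added to candle-nn below is what lets the new model read its activation directly from the upstream config.json, which spells it "gelu_new". A minimal round-trip sketch (assuming serde_json is available and that Activation keeps its Deserialize/PartialEq derives):

    fn gelu_new_alias() -> serde_json::Result<()> {
        // "gelu_new" in config.json should map onto the existing NewGelu variant.
        let act: candle_nn::Activation = serde_json::from_str("\"gelu_new\"")?;
        assert_eq!(act, candle_nn::Activation::NewGelu);
        Ok(())
    }
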
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index 80b750ed..e00463f0 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -6,6 +6,7 @@ use serde::Deserialize; pub enum Activation { #[default] Gelu, + #[serde(alias = "gelu_new")] NewGelu, Relu, Relu2, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index a60b5a06..9af6df69 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -17,6 +17,7 @@ pub mod mixformer; pub mod mixtral; pub mod mpt; pub mod persimmon; +pub mod phi; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs new file mode 100644 index 00000000..a635f3ce --- /dev/null +++ b/candle-transformers/src/models/phi.rs @@ -0,0 +1,365 @@ +use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear}; +/// Phi model. +/// https://huggingface.co/microsoft/phi-2 +/// There is an alternative implementation of the phi model in mixformers.rs. +/// This corresponds to the model update made with the following commit: +/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; + +// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: Option, + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: usize, + pub(crate) layer_norm_eps: f64, + pub(crate) tie_word_embeddings: bool, + pub(crate) rope_theta: f32, + pub(crate) partial_rotary_factor: f64, + pub(crate) qk_layernorm: bool, +} + +impl Config { + fn num_key_value_heads(&self) -> usize { + self.num_key_value_heads.unwrap_or(self.num_attention_heads) + } + + fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(cfg: &Config, dev: &Device) -> Result { + let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)? + .to_dtype(DType::F32)? 
+ .reshape((cfg.max_position_embeddings, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result { + let (_b_size, seqlen, _, _headdim) = xs.dims4()?; + let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?; + let rotary_dim = rotary_dim * 2; + let xs_rot = xs.i((.., .., .., ..rotary_dim))?; + let xs_pass = xs.i((.., .., .., rotary_dim..))?; + let xs12 = xs_rot.chunk(2, D::Minus1)?; + let (xs1, xs2) = (&xs12[0], &xs12[1]); + let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; + let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; + let xs_rot = Tensor::cat( + &[ + (xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?, + (xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?, + ], + D::Minus1, + )?; + Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + fc1: Linear, + fc2: Linear, + act: Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?; + let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; + Ok(Self { + fc1, + fc2, + act: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2) + } +} + +#[derive(Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + dense: Linear, + kv_cache: Option<(Tensor, Tensor)>, + q_layernorm: Option, + k_layernorm: Option, + rotary_emb: RotaryEmbedding, + softmax_scale: f64, + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + span: tracing::Span, +} + +fn get_mask(size: usize, device: &Device) -> Result { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +impl Attention { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads(); + let head_dim = cfg.head_dim(); + let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?; + // Alternative rope scalings are not supported. 
+ let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?; + let (q_layernorm, k_layernorm) = if cfg.qk_layernorm { + let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?; + let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?; + (Some(q_layernorm), Some(k_layernorm)) + } else { + (None, None) + }; + let softmax_scale = 1f64 / (head_dim as f64).sqrt(); + Ok(Self { + q_proj, + k_proj, + v_proj, + dense, + kv_cache: None, + q_layernorm, + k_layernorm, + rotary_emb, + softmax_scale, + num_heads, + num_kv_heads, + head_dim, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result { + let n_rep = self.num_heads / self.num_kv_heads; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + let (b_size, seq_len, _n_embd) = xs.dims3()?; + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = match &self.q_layernorm { + None => query_states, + Some(ln) => query_states.apply(ln)?, + }; + let key_states = match &self.k_layernorm { + None => key_states, + Some(ln) => key_states.apply(ln)?, + }; + + let query_states = query_states + .reshape((b_size, seq_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + // Rotary embeddings. + let seqlen_offset = match &self.kv_cache { + None => 0, + Some((prev_k, _)) => prev_k.dim(1)?, + }; + let query_states = self + .rotary_emb + .apply_rotary_emb(&query_states, seqlen_offset)?; + let key_states = self + .rotary_emb + .apply_rotary_emb(&key_states, seqlen_offset)?; + + // KV cache. + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let k = Tensor::cat(&[prev_k, &key_states], 2)?; + let v = Tensor::cat(&[prev_v, &value_states], 2)?; + (k, v) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + // Repeat kv. + let key_states = self.repeat_kv(key_states)?.contiguous()?; + let value_states = self.repeat_kv(value_states)?.contiguous()?; + + let attn_weights = (query_states + .to_dtype(DType::F32)? + .contiguous()? + .matmul(&key_states.to_dtype(DType::F32)?.t()?)? + * self.softmax_scale)?; + let attn_weights = match mask { + None => attn_weights, + Some(mask) => masked_fill( + &attn_weights, + &mask.broadcast_left((b_size, self.num_heads))?, + f32::NEG_INFINITY, + )?, + }; + let attn_weights = + candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?; + let attn_output = attn_weights.matmul(&value_states)?; + let attn_output = attn_output + .transpose(1, 2)? 
+ .reshape((b_size, seq_len, ()))?; + attn_output.apply(&self.dense) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: LayerNorm, + span: tracing::Span, +} + +impl DecoderLayer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = layer_norm( + cfg.hidden_size, + cfg.layer_norm_eps, + vb.pp("input_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); + let residual = xs; + let xs = xs.apply(&self.input_layernorm)?; + let attn_outputs = self.self_attn.forward(&xs, mask)?; + let feed_forward_hidden_states = self.mlp.forward(&xs)?; + attn_outputs + feed_forward_hidden_states + residual + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Clone)] +pub struct Model { + embed_tokens: Embedding, + layers: Vec, + final_layernorm: LayerNorm, + lm_head: Linear, + span: tracing::Span, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + let embed_tokens = + Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let final_layernorm = layer_norm( + cfg.hidden_size, + cfg.layer_norm_eps, + vb_m.pp("final_layernorm"), + )?; + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_m = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?; + layers.push(layer) + } + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + final_layernorm, + lm_head, + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&mut self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let (_b_size, seq_len) = xs.dims2()?; + let mut xs = xs.apply(&self.embed_tokens)?; + let mask = if seq_len <= 1 { + None + } else { + Some(get_mask(seq_len, xs.device())?) + }; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, mask.as_ref())? + } + xs.apply(&self.final_layernorm)? + .narrow(1, seq_len - 1, 1)? + .apply(&self.lm_head)? + .squeeze(1) + } + + pub fn clear_kv_cache(&mut self) { + self.layers.iter_mut().for_each(|b| b.clear_kv_cache()) + } +} From 88618255cb3c20b511a2f0e6db35d84081ce3c4a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 13 Jan 2024 19:44:41 +0100 Subject: [PATCH 34/46] Fix the rotary embeddings for the new phi implementation. (#1582) * Fix the rotary embeddings for the new phi implementation. * Match the activation. * KV cache fix. * Use the config activation function. --- candle-transformers/src/models/phi.rs | 34 +++++++++++++-------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index a635f3ce..8bf357e7 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -38,6 +38,7 @@ impl Config { #[derive(Debug, Clone)] struct RotaryEmbedding { + dim: usize, sin: Tensor, cos: Tensor, } @@ -55,29 +56,24 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? 
.reshape((cfg.max_position_embeddings, 1))?; let freqs = t.matmul(&inv_freq)?; + let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, + dim, + sin: emb.sin()?, + cos: emb.cos()?, }) } fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result { - let (_b_size, seqlen, _, _headdim) = xs.dims4()?; - let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?; - let rotary_dim = rotary_dim * 2; - let xs_rot = xs.i((.., .., .., ..rotary_dim))?; - let xs_pass = xs.i((.., .., .., rotary_dim..))?; + let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?; + let xs_rot = xs.i((.., .., .., ..self.dim))?; + let xs_pass = xs.i((.., .., .., self.dim..))?; let xs12 = xs_rot.chunk(2, D::Minus1)?; let (xs1, xs2) = (&xs12[0], &xs12[1]); - let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; - let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?; - let xs_rot = Tensor::cat( - &[ - (xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?, - (xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?, - ], - D::Minus1, - )?; + let c = self.cos.narrow(0, seqlen_offset, seq_len)?; + let s = self.sin.narrow(0, seqlen_offset, seq_len)?; + let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?; + let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?; Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1) } } @@ -97,6 +93,8 @@ impl MLP { Ok(Self { fc1, fc2, + // This does not match the mixformers implementation where Gelu is used rather than + // GeluNew. act: cfg.hidden_act, }) } @@ -216,7 +214,7 @@ impl Attention { // Rotary embeddings. let seqlen_offset = match &self.kv_cache { None => 0, - Some((prev_k, _)) => prev_k.dim(1)?, + Some((prev_k, _)) => prev_k.dim(2)?, }; let query_states = self .rotary_emb @@ -351,7 +349,7 @@ impl Model { Some(get_mask(seq_len, xs.device())?) }; for layer in self.layers.iter_mut() { - xs = layer.forward(&xs, mask.as_ref())? + xs = layer.forward(&xs, mask.as_ref())?; } xs.apply(&self.final_layernorm)? .narrow(1, seq_len - 1, 1)? From e6d86b081980196745e5f0b0eda8ce5334c0ff67 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 13 Jan 2024 20:24:06 +0100 Subject: [PATCH 35/46] Add the pow operator. (#1583) * Add the pow operator. * Support the pow operation in onnx. --- candle-core/src/tensor.rs | 12 +++++++++++- candle-core/tests/tensor_tests.rs | 16 ++++++++++++++-- candle-onnx/src/eval.rs | 6 ++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp(&self, sum_dims: D) -> Result { + pub fn log_sum_exp(&self, sum_dims: D) -> Result { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! 
bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e83fb55b..33bab1b6 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { } #[test] -fn logsumexp() -> Result<()> { +fn log_sum_exp() -> Result<()> { let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; - let output = input.logsumexp(D::Minus1)?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; assert_close(&output, &expected, 0.00001)?; Ok(()) } + +#[test] +fn pow() -> Result<()> { + let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let rhs = (&lhs - 2.)?; + let res = lhs.pow(&rhs)?; + assert_eq!( + test_utils::to_vec2_round(&res, 4)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + ); + Ok(()) +} diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 684776c2..c0ad8668 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -254,6 +254,12 @@ pub fn simple_eval( let output = input0.broadcast_div(input1)?; values.insert(node.output[0].clone(), output); } + "Pow" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; + let output = input0.broadcast_pow(input1)?; + values.insert(node.output[0].clone(), output); + } "Equal" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; From bdd8107fda4cd02f6a37330ad6c395f70abbdcbc Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 14 Jan 2024 20:09:49 +0100 Subject: [PATCH 36/46] Expose the ndarray trait. (#1586) --- candle-core/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 6c4fea91..f2aed1b6 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -72,7 +72,7 @@ pub mod utils; mod variable; pub use cpu_backend::CpuStorage; -pub use device::{Device, DeviceLocation}; +pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, FloatDType, IntDType, WithDType}; pub use error::{Error, Result}; pub use indexer::IndexOp; From 86b7c01b306ba4f5e25172682c0e3034d9aa0cfb Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 15 Jan 2024 09:44:51 +0100 Subject: [PATCH 37/46] Update gemm to the latest version. (#1587) --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 2225c42e..0aef12f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ candle-transformers = { path = "./candle-transformers" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.10.0", features = ["f16"] } -gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } +gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] } From 79478ff5a1eab89f6e638ad7e7abd587b0f5b167 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sun, 14 Jan 2024 18:10:54 +0100 Subject: [PATCH 38/46] Seed should be updated by random kernel result. 
--- candle-core/src/metal_backend.rs | 35 ++++++++++++++++++++------ candle-metal-kernels/src/lib.rs | 12 ++++++--- candle-metal-kernels/src/random.metal | 36 +++++++++++++++++++-------- candle-metal-kernels/src/tests.rs | 20 ++++++++++----- 4 files changed, 76 insertions(+), 27 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 8a75bd7c..673e6e11 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,9 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; +use cudarc::driver::DeviceRepr; use metal; -use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use metal::{ + Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger, +}; use std::collections::HashMap; +use std::ffi::c_void; use std::path::Path; use std::sync::{Arc, Mutex, RwLock, TryLockError}; @@ -107,7 +111,7 @@ pub struct MetalDevice { /// (strong_count = 1). buffers: AllocatedBuffers, /// Seed for random number generation. - seed: Arc>, + seed: Arc>, } impl std::fmt::Debug for MetalDevice { @@ -234,7 +238,7 @@ impl MetalDevice { // The slice might not live long enough for metal // To actually fill the GPU buffer. // Putting this wait forces the GPU buffer to be filled - // with the actual data allowing the CPU storage todo + // with the actual data allowing the CPU storage to do // deallocate properly. self.wait_until_completed()?; Ok(real) @@ -1542,7 +1546,12 @@ impl BackendDevice for MetalDevice { Ok(val) => val.parse()?, _ => 20, }; - let seed = Arc::new(Mutex::new(299792458)); + let s = device.new_buffer_with_data( + 299792458 as *const u32 as *const c_void, + 4, + MTLResourceOptions::StorageModeManaged, + )?; + let seed = Arc::new(Mutex::new(s)); Ok(Self { device, fence, @@ -1624,10 +1633,10 @@ impl BackendDevice for MetalDevice { &command_buffer, &self.kernels, name, - *self.seed.lock().unwrap(), min as f32, max as f32, shape.elem_count(), + &*self.seed.lock().unwrap(), &buffer, ) .map_err(MetalError::from)?; @@ -1655,10 +1664,10 @@ impl BackendDevice for MetalDevice { &command_buffer, &self.kernels, name, - *self.seed.lock().unwrap(), mean as f32, stddev as f32, shape.elem_count(), + &*self.seed.lock().unwrap(), &buffer, ) .map_err(MetalError::from)?; @@ -1667,8 +1676,20 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, seed: u64) -> Result<()> { + if seed > u32::MAX as u64 { + MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())? 
+ } + let seed = seed as u32; + let mut s = self.seed.try_lock().map_err(MetalError::from)?; - *s = seed; + *s.set_purgeable_state(MTLPurgeableState::Empty); + + *s = self.device.new_buffer_with_data( + &seed as *const u32 as *const c_void, + 8, + MTLResourceOptions::StorageModeManaged, + )?; + Ok(()) } } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c427a690..6a10c333 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1587,10 +1587,10 @@ pub fn call_random_uniform( command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - seed: u64, min: f32, max: f32, length: usize, + seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { if min >= max { @@ -1607,8 +1607,10 @@ pub fn call_random_uniform( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, min, max, buffer)); + set_params!(encoder, (length, min, max, seed, buffer)); + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); @@ -1623,10 +1625,10 @@ pub fn call_random_normal( command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - seed: u64, mean: f32, stddev: f32, length: usize, + seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; @@ -1638,8 +1640,10 @@ pub fn call_random_normal( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, seed, mean, stddev, buffer)); + set_params!(encoder, (length, mean, stddev, seed, buffer)); + encoder.use_resource(seed, metal::MTLResourceUsage::Read); + encoder.use_resource(seed, metal::MTLResourceUsage::Write); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.update_fence(&kernels.fence); diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal index 5369e8e2..5eae2715 100644 --- a/candle-metal-kernels/src/random.metal +++ b/candle-metal-kernels/src/random.metal @@ -1,4 +1,7 @@ #include +#include +#include + using namespace metal; // Constants @@ -107,72 +110,85 @@ struct HybridTaus { } }; +METAL_FUNC float absdiff(float x, float y) { + return abs(x - y); +} + template METAL_FUNC void rand_uniform( constant size_t &size, - constant ulong &seed, constant float &min, constant float &max, + device atomic_uint *seed, device T *out, uint tid [[thread_position_in_grid]] ) { if (tid >= size) { return; } - float diff = max - min; - HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + + float diff = absdiff(min, max); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); out[tid] = static_cast(rng.rand() * diff + min); out[size - tid] = static_cast(rng.rand() * diff + min); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + } } // Create Gaussian normal distribution using Box-Muller transform: // https://en.wikipedia.org/wiki/Box–Muller_transform template METAL_FUNC void normal( constant size_t &size, - constant ulong &seed, constant float &mean, constant float &stddev, + device atomic_uint *seed, device T *out, uint tid [[thread_position_in_grid]] ) 
{ if (tid >= size) { return; } - HybridTaus rng = HybridTaus::init({seed, tid, 1, 1}); + HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1}); float u1 = rng.rand(); float u2 = rng.rand(); float cosval; - float sinval = sincos(u1 * TWO_PI, cosval); + float sinval = sincos(TWO_PI * u2, cosval); float mag = stddev * sqrt(-2.0 * log(u1)); float z0 = mag * cosval + mean; float z1 = mag * sinval + mean; out[tid] = static_cast(z0); out[size - tid] = static_cast(z1); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + } } #define UNIFORM_OP(NAME, T) \ kernel void rand_uniform_##NAME( \ constant size_t &size, \ - constant ulong &seed, \ constant float &min, \ constant float &max, \ + device atomic_uint *seed, \ device T *out, \ uint tid [[thread_position_in_grid]] \ ) { \ - rand_uniform(size, seed, min, max, out, tid); \ + rand_uniform(size, min, max, seed, out, tid); \ } \ #define NORMAL_OP(NAME, T) \ kernel void rand_normal_##NAME( \ constant size_t &size, \ - constant ulong &seed, \ constant float &mean, \ constant float &stddev, \ + device atomic_uint *seed, \ device T *out, \ uint tid [[thread_position_in_grid]] \ ) { \ - normal(size, seed, mean, stddev, out, tid); \ + normal(size, mean, stddev, seed, out, tid); \ } \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 775ee0fa..2831a386 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -938,14 +938,21 @@ fn gemm() { ); } -fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: f32) -> Vec { +fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec { let device = device(); let fence = device.new_fence(); let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; - let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::()) as NSUInteger, options); + + let seed = device.new_buffer_with_data( + &seed as *const u32 as *const core::ffi::c_void, + std::mem::size_of::() as NSUInteger, + options, + ); if name.starts_with("rand_uniform") { call_random_uniform( @@ -953,10 +960,10 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: command_buffer, &kernels, name, - seed, a, b, length, + &seed, &output, ) .unwrap(); @@ -966,15 +973,14 @@ fn run_random(name: &'static str, seed: u64, length: usize, a: f32, b: command_buffer, &kernels, name, - seed, a, b, length, + &seed, &output, ) .unwrap(); } - command_buffer.commit(); command_buffer.wait_until_completed(); @@ -1029,7 +1035,9 @@ fn random() { .into_iter() .map(f32::from) .collect(); - results.iter().for_each(|v| assert!(*v >= min && *v <= max)); + results.iter().for_each(|v| { + assert!(*v >= min && *v <= max); + }); assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); let results: Vec = run_random::<$type>( From ea36f3b11feb7408207413bce1611bdc33d449f4 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 15 Jan 2024 12:30:27 +0100 Subject: [PATCH 39/46] Use the new phi model by default. 
(#1589) --- candle-examples/examples/phi/main.rs | 55 +++++++++++++++------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index ea99c706..69eed84f 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -128,9 +128,8 @@ enum WhichModel { V1_5, #[value(name = "2")] V2, - // TODO: Make this the default once it has been battle tested. - #[value(name = "2-new")] - V2New, + #[value(name = "2-old")] + V2Old, PuffinPhiV2, PhiHermes, } @@ -236,7 +235,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), - WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(), + WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -251,10 +250,10 @@ fn main() -> Result<()> { "main".to_string() } else { match args.model { - WhichModel::V1 => "refs/pr/2".to_string(), - WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), - WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + WhichModel::V1 => "refs/pr/8".to_string(), + WhichModel::V1_5 => "refs/pr/73".to_string(), + WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), + WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "main".to_string() } } @@ -265,7 +264,7 @@ fn main() -> Result<()> { let tokenizer_filename = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => match args.model { - WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => { + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => { repo.get("tokenizer.json")? } WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { @@ -280,14 +279,14 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], - WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?], + WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } } else { match args.model { WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], - WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors( + WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors( &repo, "model.safetensors.index.json", )?, @@ -304,35 +303,39 @@ fn main() -> Result<()> { let config = || match args.model { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), - WhichModel::V2 | WhichModel::V2New => Config::v2(), + WhichModel::V2 | WhichModel::V2Old => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; - let (model, device) = if args.model == WhichModel::V2New { - let device = candle_examples::device(args.cpu)?; - let config_filename = repo.get("config.json")?; - let config = std::fs::read_to_string(config_filename)?; - let config: PhiConfig = serde_json::from_str(&config)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? 
}; - let phi = Phi::new(&config, vb)?; - (Model::Phi(phi), device) - } else if args.quantized { + let (model, device) = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; let config = config(); let model = match args.model { - WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?, + WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; - let config = config(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let model = match args.model { - WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?, - _ => MixFormer::new(&config, vb)?, + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: PhiConfig = serde_json::from_str(&config)?; + let phi = Phi::new(&config, vb)?; + Model::Phi(phi) + } + WhichModel::V2Old => { + let config = config(); + Model::MixFormer(MixFormer::new_v2(&config, vb)?) + } + WhichModel::PhiHermes | WhichModel::PuffinPhiV2 => { + let config = config(); + Model::MixFormer(MixFormer::new(&config, vb)?) + } }; - (Model::MixFormer(model), device) + (model, device) }; println!("loaded the model in {:?}", start.elapsed()); From 1257fc6719caec595c14db0254b0b22280a37575 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 22:34:40 +0100 Subject: [PATCH 40/46] Update safetensors requirement from 0.3.1 to 0.4.1 (#1591) Updates the requirements on [safetensors](https://github.com/huggingface/safetensors) to permit the latest version. - [Release notes](https://github.com/huggingface/safetensors/releases) - [Changelog](https://github.com/huggingface/safetensors/blob/main/RELEASE.md) - [Commits](https://github.com/huggingface/safetensors/compare/v0.3.1...v0.3.3) --- updated-dependencies: - dependency-name: safetensors dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0aef12f3..0d68f425 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" rusttype = { version = "0.9", default-features = false } -safetensors = "0.3.1" +safetensors = "0.4.1" serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" From 7e3349d7c3993b395ca53837dc0fa8fc63a13704 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 22:35:01 +0100 Subject: [PATCH 41/46] Update parquet requirement from 45.0.0 to 50.0.0 (#1592) Updates the requirements on [parquet](https://github.com/apache/arrow-rs) to permit the latest version. - [Changelog](https://github.com/apache/arrow-rs/blob/master/CHANGELOG-old.md) - [Commits](https://github.com/apache/arrow-rs/compare/45.0.0...45.0.0) --- updated-dependencies: - dependency-name: parquet dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0d68f425..6c73a79c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ log = "0.4" memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } num_cpus = "1.15.0" num-traits = "0.2.15" -parquet = { version = "45.0.0" } +parquet = { version = "50.0.0" } rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" From 5270224f407502b82fe90bc2622894ce3871b002 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Tue, 16 Jan 2024 07:34:16 +0200 Subject: [PATCH 42/46] Add MobileOne model. (#1595) * Add MobileOne model. * Clippy fixes * Remove a comment. --------- Co-authored-by: laurent --- candle-examples/examples/mobileone/README.md | 22 ++ candle-examples/examples/mobileone/main.rs | 96 ++++++ candle-transformers/src/models/mobileone.rs | 333 +++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 4 files changed, 452 insertions(+) create mode 100644 candle-examples/examples/mobileone/README.md create mode 100644 candle-examples/examples/mobileone/main.rs create mode 100644 candle-transformers/src/models/mobileone.rs diff --git a/candle-examples/examples/mobileone/README.md b/candle-examples/examples/mobileone/README.md new file mode 100644 index 00000000..b5e88b6f --- /dev/null +++ b/candle-examples/examples/mobileone/README.md @@ -0,0 +1,22 @@ +# candle-mobileone + +[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040). + +This candle implementation uses a pre-trained MobileOne network for inference. The +classification head has been trained on the ImageNet dataset and returns the +probabilities for the top-5 classes. + +## Running an example + +``` +$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2 + +loaded image Tensor[dims 3, 224, 224; f32] +model built +mountain bike, all-terrain bike, off-roader: 79.33% +bicycle-built-for-two, tandem bicycle, tandem: 15.32% +crash helmet : 2.58% +unicycle, monocycle : 1.70% +alp : 0.21% + +``` diff --git a/candle-examples/examples/mobileone/main.rs b/candle-examples/examples/mobileone/main.rs new file mode 100644 index 00000000..4cd55001 --- /dev/null +++ b/candle-examples/examples/mobileone/main.rs @@ -0,0 +1,96 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::mobileone; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + S0, + S1, + S2, + S3, + S4, +} + +impl Which { + fn model_filename(&self) -> String { + let name = match self { + Self::S0 => "s0", + Self::S1 => "s1", + Self::S2 => "s2", + Self::S3 => "s3", + Self::S4 => "s4", + }; + format!("timm/mobileone_{}.apple_in1k", name) + } + + fn config(&self) -> mobileone::Config { + match self { + Self::S0 => mobileone::Config::s0(), + Self::S1 => mobileone::Config::s1(), + Self::S2 => mobileone::Config::s2(), + Self::S3 => mobileone::Config::s3(), + Self::S4 => mobileone::Config::s4(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. 
+ #[arg(long)] + cpu: bool, + + #[arg(value_enum, long, default_value_t=Which::S0)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image224(args.image)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let model_name = args.which.model_filename(); + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = mobileone::mobileone(&args.which.config(), 1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. * pr + ); + } + Ok(()) +} diff --git a/candle-transformers/src/models/mobileone.rs b/candle-transformers/src/models/mobileone.rs new file mode 100644 index 00000000..674da40b --- /dev/null +++ b/candle-transformers/src/models/mobileone.rs @@ -0,0 +1,333 @@ +//! MobileOne inference implementation based on timm and candle-repvgg +//! +//! See "MobileOne: An Improved One millisecond Mobile Backbone" +//! https://arxiv.org/abs/2206.04040 + +use candle::{DType, Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig, + Func, VarBuilder, +}; + +struct StageConfig { + blocks: usize, + channels: usize, +} + +// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form +// by concatenating the 5th stage (starts with stride 1) to the previous one. +const STAGES: [StageConfig; 5] = [ + StageConfig { + blocks: 1, + channels: 64, + }, + StageConfig { + blocks: 2, + channels: 64, + }, + StageConfig { + blocks: 8, + channels: 128, + }, + StageConfig { + blocks: 10, + channels: 256, + }, + StageConfig { + blocks: 1, + channels: 512, + }, +]; + +#[derive(Clone)] +pub struct Config { + /// overparameterization factor + k: usize, + /// per-stage channel number multipliers + alphas: [f32; 5], +} + +impl Config { + pub fn s0() -> Self { + Self { + k: 4, + alphas: [0.75, 0.75, 1.0, 1.0, 2.0], + } + } + pub fn s1() -> Self { + Self { + k: 1, + alphas: [1.5, 1.5, 1.5, 2.0, 2.5], + } + } + pub fn s2() -> Self { + Self { + k: 1, + alphas: [1.5, 1.5, 2.0, 2.5, 4.0], + } + } + pub fn s3() -> Self { + Self { + k: 1, + alphas: [2.0, 2.0, 2.5, 3.0, 4.0], + } + } + pub fn s4() -> Self { + Self { + k: 1, + alphas: [3.0, 3.0, 3.5, 3.5, 4.0], + } + } +} + +// SE blocks are used in the last stages of the s4 variant. 
+fn squeeze_and_excitation( + in_channels: usize, + squeeze_channels: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?; + let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?; + + Ok(Func::new(move |xs| { + let residual = xs; + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?; + + residual.broadcast_mul(&xs) + })) +} + +// fuses a convolutional kernel and a batchnorm layer into a convolutional layer +// based on the _fuse_bn_tensor method in timm +// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602 +fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> { + let (gamma, beta) = bn.weight_and_bias().unwrap(); + let mu = bn.running_mean(); + let sigma = (bn.running_var() + bn.eps())?.sqrt(); + let gps = (gamma / sigma)?; + let bias = (beta - mu * &gps)?; + let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?; + + Ok((weights, bias)) +} + +// A mobileone block has a different training time and inference time architecture. +// The latter is a simple and efficient equivalent transformation of the former +// realized by a structural reparameterization technique, where convolutions +// along with identity branches and batchnorm layers are fused into a single convolution. +#[allow(clippy::too_many_arguments)] +fn mobileone_block( + has_identity: bool, + k: usize, + dim: usize, + stride: usize, + padding: usize, + groups: usize, + kernel: usize, + in_channels: usize, + out_channels: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + stride, + padding, + groups, + ..Default::default() + }; + + let mut w = Tensor::zeros( + (out_channels, in_channels / groups, kernel, kernel), + DType::F32, + vb.device(), + )?; + let mut b = Tensor::zeros(dim, DType::F32, vb.device())?; + + // k is the training-time overparameterization factor, larger than 1 only in the s0 variant + for i in 0..k { + let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?; + let conv_kxk = conv2d_no_bias( + in_channels, + out_channels, + kernel, + conv2d_cfg, + vb.pp(format!("conv_kxk.{i}.conv")), + )?; + let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?; + w = (w + wk)?; + b = (b + bk)?; + } + + if kernel > 1 { + let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?; + let conv_scale = conv2d_no_bias( + in_channels, + out_channels, + 1, + conv2d_cfg, + vb.pp("conv_scale.conv"), + )?; + + let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?; + // resize to 3x3 + ws = ws.pad_with_zeros(D::Minus1, 1, 1)?; + ws = ws.pad_with_zeros(D::Minus2, 1, 1)?; + + w = (w + ws)?; + b = (b + bs)?; + } + + // Use SE blocks if present (last layers of the s4 variant) + let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn")); + + // read and reparameterize the identity bn into wi and bi + if has_identity { + let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?; + + let mut weights: Vec = vec![0.0; w.elem_count()]; + + let id = in_channels / groups; + // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809 + for i in 0..in_channels { + if kernel > 1 { + weights[i * kernel * kernel + 4] = 1.0; + } else { + weights[i * (id + 1)] = 1.0; + } + } + + let weights = &Tensor::from_vec(weights, 
w.shape(), w.device())?; + let (wi, bi) = fuse_conv_bn(weights, identity_bn)?; + + w = (w + wi)?; + b = (b + bi)?; + } + + let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg); + + Ok(Func::new(move |xs| { + let mut xs = xs.apply(&reparam_conv)?; + if let Ok(f) = &se { + xs = xs.apply(f)?; + } + xs = xs.relu()?; + Ok(xs) + })) +} + +// Get the number of output channels per stage taking into account the multipliers +fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize { + let channels = STAGES[stage].channels as f32; + let alpha = cfg.alphas[stage]; + + match stage { + 0 => std::cmp::min(64, (channels * alpha) as usize), + _ => (channels * alpha) as usize, + } +} + +// Each stage is made of blocks. The first layer always downsamples with stride 2. +// All but the first block have a residual connection. +fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result> { + let nblocks = STAGES[idx].blocks; + let mut blocks = Vec::with_capacity(nblocks); + + let mut in_channels = output_channels_per_stage(cfg, idx - 1); + + for block_idx in 0..nblocks { + let out_channels = output_channels_per_stage(cfg, idx); + let (has_identity, stride) = if block_idx == 0 { + (false, 2) + } else { + (true, 1) + }; + + // depthwise convolution layer + blocks.push(mobileone_block( + has_identity, + cfg.k, + in_channels, + stride, + 1, + in_channels, + 3, + in_channels, + in_channels, + vb.pp(block_idx * 2), + )?); + + // pointwise convolution layer + blocks.push(mobileone_block( + has_identity, + cfg.k, + out_channels, + 1, // stride + 0, // padding + 1, // groups + 1, // kernel + in_channels, + out_channels, + vb.pp(block_idx * 2 + 1), + )?); + + in_channels = out_channels; + } + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + for block in blocks.iter() { + xs = xs.apply(block)? + } + Ok(xs) + })) +} + +// Build a mobileone model for a given configuration. +fn mobileone_model( + config: &Config, + nclasses: Option, + vb: VarBuilder, +) -> Result> { + let cls = match nclasses { + None => None, + Some(nclasses) => { + let outputs = output_channels_per_stage(config, 4); + let linear = linear(outputs, nclasses, vb.pp("head.fc"))?; + Some(linear) + } + }; + + let stem_dim = output_channels_per_stage(config, 0); + let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?; + let vb = vb.pp("stages"); + let stage1 = mobileone_stage(config, 1, vb.pp(0))?; + let stage2 = mobileone_stage(config, 2, vb.pp(1))?; + let stage3 = mobileone_stage(config, 3, vb.pp(2))?; + let stage4 = mobileone_stage(config, 4, vb.pp(3))?; + + Ok(Func::new(move |xs| { + let xs = xs + .apply(&stem)? + .apply(&stage1)? + .apply(&stage2)? + .apply(&stage3)? + .apply(&stage4)? + .mean(D::Minus2)? 
+            .mean(D::Minus1)?;
+        match &cls {
+            None => Ok(xs),
+            Some(cls) => xs.apply(cls),
+        }
+    }))
+}
+
+pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    mobileone_model(cfg, Some(nclasses), vb)
+}
+
+pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
+    mobileone_model(cfg, None, vb)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 9af6df69..a94fd07a 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -15,6 +15,7 @@ pub mod marian;
 pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
+pub mod mobileone;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;

From 86a8e58897f012445de2f35318b19a89ebfaa327 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Tue, 16 Jan 2024 19:11:31 +0100
Subject: [PATCH 43/46] Update metal random kernel and set_seed method

* set_seed via buffer content pointer copy + did_modify_range

* ensure random.metal kernel does not write outside of buffer range when tid==0

---
 candle-core/src/metal_backend.rs      | 33 +++++++++++----------------
 candle-metal-kernels/src/random.metal | 18 ++++++++-------
 2 files changed, 23 insertions(+), 28 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 673e6e11..aa97c04a 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,8 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
-use cudarc::driver::DeviceRepr;
 use metal;
-use metal::{
-    Buffer, CommandBuffer, CommandQueue, MTLPurgeableState, MTLResourceOptions, NSUInteger,
-};
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
@@ -1546,12 +1543,11 @@ impl BackendDevice for MetalDevice {
             Ok(val) => val.parse()?,
             _ => 20,
         };
-        let s = device.new_buffer_with_data(
-            299792458 as *const u32 as *const c_void,
+        let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
+            [299792458].as_ptr() as *const c_void,
             4,
             MTLResourceOptions::StorageModeManaged,
-        )?;
-        let seed = Arc::new(Mutex::new(s));
+        )));
         Ok(Self {
             device,
             fence,
@@ -1676,19 +1672,16 @@ impl BackendDevice for MetalDevice {
     }

     fn set_seed(&self, seed: u64) -> Result<()> {
-        if seed > u32::MAX as u64 {
-            MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())?
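+        // The seed lives in the small StorageModeManaged buffer created at device
+        // init above (initialized to 299792458); the new code below overwrites its
+        // contents in place and flags the dirty range so the GPU copy stays in sync,
+        // instead of reallocating a buffer on every call.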
+        let seed: u32 = seed.try_into().map_err(|_| {
+            MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
+        })?;
+
+        let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
+        let contents = seed_buffer.contents();
+        unsafe {
+            std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
        }
-        let seed = seed as u32;
-
-        let mut s = self.seed.try_lock().map_err(MetalError::from)?;
-        *s.set_purgeable_state(MTLPurgeableState::Empty);
-
-        *s = self.device.new_buffer_with_data(
-            &seed as *const u32 as *const c_void,
-            8,
-            MTLResourceOptions::StorageModeManaged,
-        )?;
+        seed_buffer.did_modify_range(metal::NSRange::new(0, 4));

         Ok(())
     }

diff --git a/candle-metal-kernels/src/random.metal b/candle-metal-kernels/src/random.metal
index 5eae2715..a7e48393 100644
--- a/candle-metal-kernels/src/random.metal
+++ b/candle-metal-kernels/src/random.metal
@@ -14,6 +14,7 @@ static constexpr constant int3 S1 = {13, 19, 12};
 static constexpr constant int3 S2 = {2, 25, 4};
 static constexpr constant int3 S3 = {3, 11, 17};

+// Used to prevent bad seeds.
 static constexpr constant uint64_t PHI[16] = {
     0x9E3779B97F4A7C15,
     0xF39CC0605CEDC834,
@@ -110,10 +111,6 @@ struct HybridTaus {
     }
 };

-METAL_FUNC float absdiff(float x, float y) {
-    return abs(x - y);
-}
-
 template<typename T> METAL_FUNC void rand_uniform(
     constant size_t &size,
     constant float &min,
@@ -126,14 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
         return;
     }

-    float diff = absdiff(min, max);
+    float diff = abs(min - max);
     HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
     out[tid] = static_cast<T>(rng.rand() * diff + min);
-    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
-
     if (tid == 0) {
         atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+        // Return early if tid == 0, otherwise we will write to out[size].
+        return;
     }
+    // Use symmetry to fill the other half of the array.
+    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
 }

 // Create Gaussian normal distribution using Box-Muller transform:
@@ -160,11 +159,14 @@ template<typename T> METAL_FUNC void normal(
     float z1 = mag * sinval + mean;

     out[tid] = static_cast<T>(z0);
-    out[size - tid] = static_cast<T>(z1);

     if (tid == 0) {
         atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
+        // Return early if tid == 0, otherwise we will write to out[size].
+        return;
     }
+    // Use symmetry to fill the other half of the array.
+    out[size - tid] = static_cast<T>(z1);
 }

 #define UNIFORM_OP(NAME, T) \

From 403680f17ddc086295fbaee316cbed22d97a519b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 17 Jan 2024 10:27:58 +0100
Subject: [PATCH 44/46] Quantized GGUF style (#1523)

* Metal quantized modifications proposal.

- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catches it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).

* Cleanup.

* Fix the rebase.

* Removing the fences speeds everything up and *is* correct this time...

* Cleanup the fence.

* After rebase.

* Bad code removal.

* Rebase after phi2 merge + fix replit default to CPU.

* Making the CI happy.

* More happy tests.

---------

Co-authored-by: Nicolas Patry
---
 candle-core/examples/tensor-tools.rs          |  122 +-
 candle-core/src/metal_backend.rs              |  112 +-
 candle-core/src/quantized/ggml_file.rs        |   84 +-
 candle-core/src/quantized/gguf_file.rs        |   28 +-
 candle-core/src/quantized/metal.rs            |  153 +
 candle-core/src/quantized/mod.rs              |  302 +-
 candle-core/tests/quantized_tests.rs          |  571 +-
 candle-examples/examples/blip/main.rs         |    4 +-
 candle-examples/examples/llama2-c/main.rs     |    8 +-
 candle-examples/examples/mistral/main.rs      |    7 +-
 candle-examples/examples/phi/main.rs          |   16 +-
 candle-examples/examples/quantized-t5/main.rs |    3 +-
 candle-examples/examples/quantized/main.rs    |   16 +-
 candle-examples/examples/replit-code/main.rs  |   13 +-
 candle-examples/examples/stable-lm/main.rs    |    5 +-
 candle-examples/examples/whisper/main.rs      |    6 +-
 candle-metal-kernels/src/lib.rs               |  228 +-
 candle-metal-kernels/src/quantized.metal      | 5107 +++++++++++++++++
 candle-metal-kernels/src/tests.rs             |   33 +-
 candle-metal-kernels/src/unary.metal          |    2 +-
 candle-nn/examples/cpu_benchmarks.rs          |    5 +-
 candle-pyo3/py_src/candle/utils/__init__.pyi  |    8 +-
 candle-pyo3/src/lib.rs                        |   51 +-
 .../src/models/quantized_llama.rs             |   41 +-
 .../src/models/quantized_mixformer.rs         |    4 +-
 .../src/quantized_var_builder.rs              |   12 +-
 candle-wasm-examples/blip/src/bin/m.rs        |    2 +-
 candle-wasm-examples/phi/src/bin/m.rs         |    6 +-
 .../t5/src/bin/m-quantized.rs                 |    9 +-
 candle-wasm-examples/whisper/src/worker.rs    |    1 +
 candle-wasm-tests/tests/quantized_tests.rs    |    2 +-
 31 files changed, 6446 insertions(+), 515 deletions(-)
 create mode 100644 candle-core/src/quantized/metal.rs
 create mode 100644 candle-metal-kernels/src/quantized.metal

diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 337021aa..eb6ceb1c 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -1,5 +1,5 @@
-use candle_core::quantized::{gguf_file, k_quants, QTensor};
-use candle_core::{Device, Result, Tensor};
+use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
+use candle_core::{Device, Result};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;

@@ -11,12 +11,7 @@ enum QuantizationMode {
 }

 impl QuantizationMode {
-    fn quantize(
-        &self,
-        name: &str,
-        tensor: QTensor,
-        default: fn(&Tensor) -> Result<Tensor>,
-    ) -> Result<QTensor> {
+    fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
         match self {
             Self::Llama => {
                 // Same behavior as the llama.cpp quantization.
@@ -24,9 +19,9 @@ impl QuantizationMode {
                 if should_quantize {
                     let tensor = tensor.dequantize(&Device::Cpu)?;
                     if name == "output.weight" {
-                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
+                        QTensor::quantize(&tensor, GgmlDType::Q6K)
                     } else {
-                        default(&tensor)
+                        QTensor::quantize(&tensor, dtype)
                     }
                 } else {
                     Ok(tensor)
@@ -60,6 +55,27 @@ enum Quantization {
     F32,
 }

+impl Quantization {
+    fn dtype(&self) -> GgmlDType {
+        match self {
+            Quantization::Q4_0 => GgmlDType::Q4_0,
+            Quantization::Q4_1 => GgmlDType::Q4_1,
+            Quantization::Q5_0 => GgmlDType::Q5_0,
+            Quantization::Q5_1 => GgmlDType::Q5_1,
+            Quantization::Q8_0 => GgmlDType::Q8_0,
+            Quantization::Q8_1 => GgmlDType::Q8_1,
+            Quantization::Q2k => GgmlDType::Q2K,
+            Quantization::Q3k => GgmlDType::Q3K,
+            Quantization::Q4k => GgmlDType::Q4K,
+            Quantization::Q5k => GgmlDType::Q5K,
+            Quantization::Q6k => GgmlDType::Q6K,
+            Quantization::Q8k => GgmlDType::Q8K,
+            Quantization::F16 => GgmlDType::F16,
+            Quantization::F32 => GgmlDType::F32,
+        }
+    }
+}
+
 #[derive(ValueEnum, Debug, Clone)]
 enum Format {
     Safetensors,
@@ -134,7 +150,12 @@ struct Args {
     command: Command,
 }

-fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
+fn run_ls(
+    file: &std::path::PathBuf,
+    format: Option<Format>,
+    verbose: bool,
+    device: &Device,
+) -> Result<()> {
     let format = match format {
         Some(format) => format,
         None => match Format::infer(file) {
@@ -200,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -241,37 +262,8 @@ fn run_quantize_safetensors(
     }
     println!("tensors: {}", tensors.len());

-    let quantize_fn = match q {
-        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
-        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
-        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
-        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
-        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
-        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
-        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
-        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
-        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
-        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
-        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
-        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
-        Quantization::F16 => QTensor::quantize::<half::f16>,
-        Quantization::F32 => QTensor::quantize::<f32>,
-    };
-    let block_size = match q {
-        Quantization::Q4_0 => k_quants::QK4_0,
-        Quantization::Q4_1 => k_quants::QK4_1,
-        Quantization::Q5_0 => k_quants::QK5_0,
-        Quantization::Q5_1 => k_quants::QK5_1,
-        Quantization::Q8_0 => k_quants::QK8_0,
-        Quantization::Q8_1 => k_quants::QK8_1,
-        Quantization::Q2k
-        | Quantization::Q3k
-        | Quantization::Q4k
-        | Quantization::Q5k
-        | Quantization::Q6k
-        | Quantization::Q8k => k_quants::QK_K,
-        Quantization::F16 | Quantization::F32 => 1,
-    };
+    let dtype = q.dtype();
+    let block_size = dtype.block_size();

     let qtensors = tensors
         .into_par_iter()
@@ -279,9 +271,9 @@ fn run_quantize_safetensors(
             let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
             println!("  quantizing {name} {tensor:?} {should_quantize}");
             let tensor = if should_quantize {
-                quantize_fn(&tensor)?
+                QTensor::quantize(&tensor, dtype)?
             } else {
-                QTensor::quantize::<f32>(&tensor)?
+                QTensor::quantize(&tensor, GgmlDType::F32)?
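+                // Only 2D tensors with a block-aligned second dim are quantized (see
+                // should_quantize above); everything else is stored as plain f32.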
             };
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -294,13 +286,17 @@ fn run_quantize_safetensors(
     Ok(())
 }

-fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+fn run_dequantize(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    device: &Device,
+) -> Result<()> {
     let mut in_file = std::fs::File::open(in_file)?;
     let content = gguf_file::Content::read(&mut in_file)?;
     let mut tensors = std::collections::HashMap::new();
     for (tensor_name, _) in content.tensor_infos.iter() {
-        let tensor = content.tensor(&mut in_file, tensor_name)?;
-        let tensor = tensor.dequantize(&Device::Cpu)?;
+        let tensor = content.tensor(&mut in_file, tensor_name, device)?;
+        let tensor = tensor.dequantize(device)?;
         tensors.insert(tensor_name.to_string(), tensor);
     }
     candle_core::safetensors::save(&tensors, out_file)?;
@@ -312,6 +308,7 @@ fn run_quantize(
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
+    device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
         candle_core::bail!("no specified input files")
@@ -337,31 +334,15 @@ fn run_quantize(
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());

-    let quantize_fn = match q {
-        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
-        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
-        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
-        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
-        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
-        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
-        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
-        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
-        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
-        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
-        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
-        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
-        Quantization::F16 => QTensor::quantize::<half::f16>,
-        Quantization::F32 => QTensor::quantize::<f32>,
-    };
-
+    let dtype = q.dtype();
     let qtensors = content
         .tensor_infos
         .par_iter()
         .map(|(name, _)| {
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
-            let tensor = content.tensor(&mut in_file, name)?;
-            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
+            let tensor = content.tensor(&mut in_file, name, device)?;
+            let tensor = qmode.quantize(name, tensor, dtype)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -381,6 +362,7 @@ fn run_quantize(

 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
+    let device = Device::Cpu;
     match args.command {
         Command::Ls {
             files,
@@ -392,7 +374,7 @@ fn main() -> anyhow::Result<()> {
             if multiple_files {
                 println!("--- {file:?} ---");
             }
-            run_ls(file, format.clone(), verbose)?
+            run_ls(file, format.clone(), verbose, &device)?
             }
         }
         Command::Quantize {
             in_file,
             out_file,
             quantization,
             mode,
-        } => run_quantize(&in_file, out_file, quantization, mode)?,
-        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
+        } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
+        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
     }
     Ok(())
 }
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 5269a899..dc790ac9 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -84,13 +84,8 @@ pub struct MetalDevice {
     command_buffer_index: Arc<RwLock<usize>>,
     /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
     compute_per_buffer: usize,
-    /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
-    /// execution order to be linear.
-    /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
-    /// compute graph.
-    fence: metal::Fence,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`], both fences need to match
+    /// Heavily used by [`candle_metal_kernels`]
     kernels: Arc<Kernels>,
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
@@ -221,10 +216,8 @@ impl MetalDevice {
         let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
-        blit.wait_for_fence(&self.fence);
         blit.set_label("with_data_blit");
         blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        blit.update_fence(&self.fence);
         blit.end_encoding();

         // This is necessary, for mmaped safetensors
@@ -238,6 +231,27 @@ impl MetalDevice {
         Ok(real)
     }

+    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
+        let buffer = self.allocate_buffer(
+            size_in_bytes as NSUInteger,
+            MTLResourceOptions::StorageModePrivate,
+            "allocate_zeros",
+        )?;
+        let command_buffer = self.command_buffer()?;
+        command_buffer.set_label("zeros");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.fill_buffer(
+            &buffer,
+            metal::NSRange {
+                location: 0,
+                length: buffer.length(),
+            },
+            0,
+        );
+        blit.end_encoding();
+        Ok(buffer)
+    }
+
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
@@ -308,35 +322,14 @@ impl BackendStorage for MetalStorage {
     }

     fn to_cpu_storage(&self) -> Result<CpuStorage> {
-        let length = self.buffer.length() as usize;
-        let size = self.dtype.size_in_bytes();
-        if length % size != 0 {
-            crate::bail!(
-                "The Metal buffer length is not aligned with dtype {:?}",
-                self.dtype
-            );
-        }
-        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
-        {
-            let command_buffer = self.device.command_buffer()?;
-            command_buffer.set_label("to_cpu");
-            let blit = command_buffer.new_blit_command_encoder();
-            blit.set_label("blit_to_cpu");
-            blit.wait_for_fence(&self.device.fence);
-            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-            blit.update_fence(&self.device.fence);
-            blit.end_encoding();
-        }
-        self.device.wait_until_completed()?;
-
         match self.dtype {
-            DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
-            DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
-            DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
-            DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
-            DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
-            DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
-            DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
+            DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
+            DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
+            DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
+            DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
+            DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
+            DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
+            DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
         }
     }

@@ -1264,7 +1257,7 @@ impl BackendStorage for MetalStorage {
             let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
             let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
             let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
+            blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
             blit.end_encoding();
         } else {
             let src_shape = src_l.shape();
@@ -1521,6 +1514,28 @@ impl MetalStorage {
         command_buffer.set_label("binary");
         Ok(Self::new(buffer, device.clone(), dtype))
     }
+
+    pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
+        let length = self.buffer.length() as usize;
+        let size = self.dtype.size_in_bytes();
+        if length % size != 0 {
+            crate::bail!(
+                "The Metal buffer length is not aligned with dtype {:?}",
+                self.dtype
+            );
+        }
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+        {
+            let command_buffer = self.device.command_buffer()?;
+            command_buffer.set_label("to_cpu");
+            let blit = command_buffer.new_blit_command_encoder();
+            blit.set_label("blit_to_cpu");
+            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+            blit.end_encoding();
+        }
+        self.device.wait_until_completed()?;
+        Ok(read_to_vec(&buffer, length / size))
+    }
 }

 impl BackendDevice for MetalDevice {
@@ -1533,16 +1548,14 @@ impl BackendDevice for MetalDevice {
         command_buffer.enqueue();
         let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
-        let fence = device.new_fence();
-        let kernels = Arc::new(Kernels::new(fence.clone()));
+        let kernels = Arc::new(Kernels::new());
         let buffers = Arc::new(RwLock::new(HashMap::new()));
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
             Ok(val) => val.parse()?,
-            _ => 20,
+            _ => 10,
         };
         Ok(Self {
             device,
-            fence,
             command_queue,
             command_buffer,
             command_buffer_index,
@@ -1567,21 +1580,8 @@ impl BackendDevice for MetalDevice {
     }

     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
-        let command_buffer = self.command_buffer()?;
-        command_buffer.set_label("zeros");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.wait_for_fence(&self.fence);
-        blit.fill_buffer(
-            &buffer,
-            metal::NSRange {
-                location: 0,
-                length: buffer.length(),
-            },
-            0,
-        );
-        blit.update_fence(&self.fence);
-        blit.end_encoding();
+        let size = shape.elem_count() * dtype.size_in_bytes();
+        let buffer = self.allocate_zeros(size)?;
         Ok(MetalStorage::new(buffer, self.clone(), dtype))
     }

diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 1dd3d9c0..38238580 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,9 @@
 //! Support for the GGML file format.

-use super::{k_quants, GgmlDType};
-use crate::Result;
+#[cfg(feature = "metal")]
+use super::metal::load_quantized_metal;
+use super::{k_quants, GgmlDType, QStorage};
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;

@@ -121,11 +123,22 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    super::QTensor::new(data.to_vec(), dims)
+    let data: QStorage = match device {
+        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
+        #[cfg(feature = "metal")]
+        Device::Metal(metal) => load_quantized_metal(metal, data)?,
+        #[cfg(not(feature = "metal"))]
+        Device::Metal(_metal) => {
+            crate::bail!("Metal backend requires `metal` feature")
+        }
+        device => unimplemented!("Implement quantized tensor for device {device:?}"),
+    };
+    super::QTensor::new(data, dims)
 }

 /// Creates a [Tensor] from a raw GGML tensor.
@@ -133,29 +146,50 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let blck_size = ggml_dtype.blck_size();
-    if tensor_elems % blck_size != 0 {
+    let block_size = ggml_dtype.block_size();
+    if tensor_elems % block_size != 0 {
         crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
         )
     }
-    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
+    let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();

     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::Q4_0 => {
+            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4_1 => {
+            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_0 => {
+            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_1 => {
+            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q8_0 => {
+            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q2K => {
+            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q3K => {
+            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4K => {
+            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5K => {
+            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q6K => {
+            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
@@ -163,6 +197,7 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
+    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -183,11 +218,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     }
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
@@ -201,7 +236,10 @@ pub struct Content {
 }

 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        device: &Device,
+    ) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -211,7 +249,7 @@ impl Content {
         let mut tensors = HashMap::new();

         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic)?;
+            let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
         Ok(Self {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 587ffc0f..b729d4a0 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md

 use super::{GgmlDType, QTensor};
-use crate::Result;
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;

@@ -59,19 +59,25 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let blck_size = self.ggml_dtype.blck_size();
-        if tensor_elems % blck_size != 0 {
+        let block_size = self.ggml_dtype.block_size();
+        if tensor_elems % block_size != 0 {
             crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
             )
         }
-        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
+        let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }

@@ -460,12 +466,13 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
             None => crate::bail!("cannot find tensor info for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset)
+        tensor_info.read(reader, self.tensor_data_offset, device)
     }
 }

@@ -517,10 +524,9 @@ pub fn write(
             "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
             )
         }
-        let data_ptr = tensor.as_ptr();
-        let size_in_bytes = tensor.storage_size_in_bytes();
-        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
-        w.write_all(data)?;
+        let data = tensor.data()?;
+        let size_in_bytes = data.len();
+        w.write_all(&data)?;
         let padding = 31 - (31 + size_in_bytes) % 32;
         w.write_all(&vec![0u8; padding])?;
     }
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
new file mode 100644
index 00000000..fe57ce14
--- /dev/null
+++ b/candle-core/src/quantized/metal.rs
@@ -0,0 +1,153 @@
+use super::{GgmlDType, QStorage};
+use crate::{DType, MetalDevice, MetalStorage, Result};
+use metal::Buffer;
+use std::sync::Arc;
+
+pub struct QMetalStorage {
+    dtype: GgmlDType,
+    device: MetalDevice,
+    buffer: Arc<Buffer>,
+}
+
+impl QMetalStorage {
+    pub fn dtype(&self) -> GgmlDType {
+        self.dtype
+    }
+
+    pub fn buffer(&self) -> &Buffer {
+        &self.buffer
+    }
+
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
+        Self {
+            device,
+            buffer,
+            dtype,
+        }
+    }
+
+    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+        let command_buffer = self.device.command_buffer()?;
+        command_buffer.set_label("to_cpu");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.set_label("blit_to_cpu");
+        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+        blit.end_encoding();
+        self.device.wait_until_completed()?;
+        let mut out = vec![0.0; elem_count];
+        match self.dtype {
+            GgmlDType::F32 => {
+                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                f32::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::F16 => {
+                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                half::f16::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_0 => {
+                let vec: Vec<crate::quantized::BlockQ4_0> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_1 => {
+                let vec: Vec<crate::quantized::BlockQ4_1> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_0 => {
+                let vec: Vec<crate::quantized::BlockQ5_0> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_1 => {
+                let vec: Vec<crate::quantized::BlockQ5_1> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_0 => {
+                let vec: Vec<crate::quantized::BlockQ8_0> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_1 => {
+                let vec: Vec<crate::quantized::BlockQ8_1> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q2K => {
+                let vec: Vec<crate::quantized::BlockQ2K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q3K => {
+                let vec: Vec<crate::quantized::BlockQ3K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4K => {
+                let vec: Vec<crate::quantized::BlockQ4K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5K => {
+                let vec: Vec<crate::quantized::BlockQ5K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q6K => {
+                let vec: Vec<crate::quantized::BlockQ6K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8K => {
+                let vec: Vec<crate::quantized::BlockQ8K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
+            }
+        }
+
+        let buffer = self.device.new_buffer_with_data(&out)?;
+        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
+    }
+
+    pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
+        // Quantization only happens on CPU for now.
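+        // The round trip below: read the Metal buffer back into an f32 Vec, build
+        // zeroed CPU blocks for the target dtype, quantize on the CPU, then upload
+        // the quantized bytes into a fresh Metal buffer. For example,
+        // `QTensor::quantize(&tensor, GgmlDType::Q4K)?` on a Metal tensor ends up
+        // here via `QStorage::quantize`.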
+        let src = src.to_cpu::<f32>()?;
+        let elem_count = src.len();
+        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
+        qcpu_storage.quantize(&src)?;
+        let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
+        self.buffer = buffer;
+        Ok(())
+    }
+}
+
+pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
+    device: &MetalDevice,
+    data: &[T],
+) -> Result<QStorage> {
+    let buffer = device.new_buffer_with_data(data)?;
+    let device = device.clone();
+    Ok(QStorage::Metal(QMetalStorage {
+        dtype: T::DTYPE,
+        device,
+        buffer,
+    }))
+}
+
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+    let ptr = buffer.contents() as *const T;
+    assert!(!ptr.is_null());
+    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+    slice.to_vec()
+}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 043733ae..1dc5fe8f 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,23 +1,125 @@
-use crate::{Device, Result, Shape, Tensor};
+#[cfg(feature = "metal")]
+use crate::{backend::BackendStorage, DType};
+use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use k_quants::*;
+use std::borrow::Cow;

 #[cfg(target_feature = "avx")]
 pub mod avx;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
+#[cfg(feature = "metal")]
+pub mod metal;
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
 pub mod utils;
+use half::f16;

 pub use k_quants::GgmlType;

 pub struct QTensor {
-    data: Box<dyn QuantizedType>,
+    storage: QStorage,
     shape: Shape,
 }

+impl Device {
+    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
+        match self {
+            Device::Cpu => {
+                let storage = dtype.cpu_zeros(elem_count);
+                Ok(QStorage::Cpu(storage))
+            }
+            #[cfg(feature = "metal")]
+            Device::Metal(metal) => {
+                let size = elem_count * dtype.type_size() / dtype.block_size();
+                let buffer = metal.allocate_zeros(size)?;
+                Ok(QStorage::Metal(metal::QMetalStorage::new(
+                    buffer,
+                    metal.clone(),
+                    dtype,
+                )))
+            }
+            #[cfg(not(feature = "metal"))]
+            Device::Metal(_metal) => {
+                crate::bail!("Metal feature not activated");
+            }
+            Device::Cuda(_cuda) => {
+                crate::bail!("Cuda ggml quantization not supported");
+            }
+        }
+    }
+}
+
+pub enum QStorage {
+    Cpu(Box<dyn QuantizedType>),
+    #[cfg(feature = "metal")]
+    Metal(metal::QMetalStorage),
+}
+
+impl QStorage {
+    fn block_size(&self) -> usize {
+        match self {
+            QStorage::Cpu(storage) => storage.block_size(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.dtype().block_size(),
+        }
+    }
+
+    fn dtype(&self) -> GgmlDType {
+        match self {
+            QStorage::Cpu(storage) => storage.dtype(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.dtype(),
+        }
+    }
+
+    fn size_in_bytes(&self) -> usize {
+        match self {
+            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.buffer().length() as usize,
+        }
+    }
+
+    fn quantize(&mut self, src: &Storage) -> Result<()> {
+        match (self, src) {
+            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
+                storage.from_float(src.as_slice::<f32>()?)?;
+            }
+            #[cfg(feature = "metal")]
+            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
+            _ => crate::bail!("Invalid dequantize storage locations do not match"),
+        }
+        Ok(())
+    }
+
+    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
+        match self {
+            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
+        }
+    }
+
+    fn data(&self) -> Result<Cow<[u8]>> {
+        match self {
+            QStorage::Cpu(storage) => {
+                let data_ptr = storage.as_ptr();
+                let size_in_bytes = storage.storage_size_in_bytes();
+                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+                Ok(Cow::from(data))
+            }
+            #[cfg(feature = "metal")]
+            QStorage::Metal(_storage) => {
+                crate::bail!("not implemented");
+            }
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum GgmlDType {
     F32,
@@ -77,6 +179,25 @@ impl GgmlDType {
         }
     }

+    /// The block dtype
+    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
+        match self {
+            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
+            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
+            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
+            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
+            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
+            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
+            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
+            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
+            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
+            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
+            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
+            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
+            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
+            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
+        }
+    }
     /// The type size for blocks in bytes.
     pub fn type_size(&self) -> usize {
         use k_quants::*;
@@ -100,7 +221,7 @@ impl GgmlDType {
     }

     /// The block size, i.e. the number of elements stored in each block.
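     /// For example Q4_0 packs 32 elements per block, the K-quants (Q2K..Q8K) pack
     /// QK_K = 256 elements per block, and F32/F16 are unblocked (block size 1).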
-    pub fn blck_size(&self) -> usize {
+    pub fn block_size(&self) -> usize {
         match self {
             Self::F32 => 1,
             Self::F16 => 1,
@@ -119,9 +240,13 @@ impl GgmlDType {
 pub trait QuantizedType: Send + Sync {
     fn dtype(&self) -> GgmlDType;
     fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
-    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
     fn storage_size_in_bytes(&self) -> usize;
     fn as_ptr(&self) -> *const u8;
+    fn block_size(&self) -> usize;
+    #[allow(clippy::wrong_self_convention)]
+    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
+    fn size(&self) -> usize;
 }

 impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -129,12 +254,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
         k_quants::matmul(mkn, lhs, self.as_slice(), dst)
     }

+    fn size(&self) -> usize {
+        self.len() * core::mem::size_of::<T>()
+    }
+
+    fn from_float(&mut self, xs: &[f32]) -> Result<()> {
+        T::from_float(xs, self)
+    }
+
     fn dtype(&self) -> GgmlDType {
         T::DTYPE
     }

-    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
-        T::to_float(self.as_slice(), ys)
+    fn block_size(&self) -> usize {
+        T::BLCK_SIZE
+    }
+
+    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
+        let mut ys = vec![0.0f32; elem_count];
+        T::to_float(self.as_slice(), &mut ys)?;
+        Ok(CpuStorage::F32(ys))
     }

     fn storage_size_in_bytes(&self) -> usize {
@@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor {
     }
 }

-fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
+fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
     let dims = shape.dims();
     if dims.is_empty() {
         crate::bail!("scalar tensor cannot be quantized {shape:?}")
     }
-    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
+    if dims[dims.len() - 1] % block_size != 0 {
         crate::bail!(
             "quantized tensor must have their last dim divisible by block size {shape:?} {}",
-            T::BLCK_SIZE
+            block_size
         )
     }
     Ok(())
 }

 impl QTensor {
-    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
-        data: Vec<T>,
-        shape: S,
-    ) -> Result<Self> {
+    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
         let shape = shape.into();
-        check_shape::<T>(&shape)?;
-        Ok(Self {
-            data: Box::new(data),
-            shape,
-        })
+        check_shape(&shape, storage.block_size())?;
+        Ok(Self { storage, shape })
     }

-    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
+    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
         let shape = src.shape();
-        check_shape::<T>(shape)?;
-        let src = src
-            .to_dtype(crate::DType::F32)?
-            .flatten_all()?
-            .to_vec1::<f32>()?;
-        if src.len() % T::BLCK_SIZE != 0 {
+        let block_size = dtype.block_size();
+        check_shape(shape, block_size)?;
+        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
+        let elem_count = shape.elem_count();
+        if elem_count % block_size != 0 {
             crate::bail!(
                 "tensor size ({shape:?}) is not divisible by block size {}",
-                T::BLCK_SIZE
+                block_size
             )
         }
-        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
-        T::from_float(&src, &mut data)?;
+        let mut storage = src.device().qzeros(elem_count, dtype)?;
+        storage.quantize(&src.storage())?;
         Ok(Self {
-            data: Box::new(data),
+            storage,
             shape: shape.clone(),
         })
     }

     pub fn dtype(&self) -> GgmlDType {
-        self.data.dtype()
+        self.storage.dtype()
     }

     pub fn rank(&self) -> usize {
@@ -213,21 +345,19 @@ impl QTensor {
     }

     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
-        let mut f32_data = vec![0f32; self.shape.elem_count()];
-        self.data.to_float(&mut f32_data)?;
-        Tensor::from_vec(f32_data, &self.shape, device)
-    }
-
-    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
-        self.data.matmul_t(mkn, lhs, dst)
+        let storage = self.storage.dequantize(self.shape.elem_count())?;
+        let none = crate::op::BackpropOp::none();
+        let is_variable = false;
+        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
+            .to_device(device)
     }

     pub fn storage_size_in_bytes(&self) -> usize {
-        self.data.storage_size_in_bytes()
+        self.storage.size_in_bytes()
     }

-    pub fn as_ptr(&self) -> *const u8 {
-        self.data.as_ptr()
+    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
+        self.storage.data()
     }
 }

@@ -294,17 +424,93 @@ impl crate::CustomOp1 for QTensor {
         }
         dst_shape.push(n);
         let dst_shape = Shape::from(dst_shape);
-        let storage = storage.as_slice::<f32>()?;
-        let storage =
-            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        #[allow(clippy::infallible_destructuring_match)]
+        let self_storage = match &self.storage {
+            QStorage::Cpu(storage) => storage,
+            #[cfg(feature = "metal")]
+            _ => crate::bail!("Invalid storage"),
+        };
+        let slice = storage.as_slice::<f32>()?;
+        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
         let mut dst_storage = vec![0f32; dst_shape.elem_count()];
-        self.matmul_t(
-            (dst_shape.elem_count() / n, k, n),
-            storage,
-            &mut dst_storage,
-        )?;
+        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, Shape)> {
+        use crate::MetalError;
+
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
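+        // Shape bookkeeping: the activation is (b, m, k) or (m, k), while self.shape
+        // is (n, k) since the quantized weight is stored transposed, so the produced
+        // output shape is (.., m, n).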
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let (n, k) = self.shape.dims2()?;
+        let mut dst_shape = src_shape.dims().to_vec();
+
+        let (b, m) = match dst_shape.len() {
+            3 => (dst_shape[0], dst_shape[1]),
+            2 => (1, dst_shape[0]),
+            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
+        };
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        let device = storage.device().clone();
+        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
+        let (buffer, dtype) = match &self.storage {
+            QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
+            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
+        };
+        let command_buffer = device.command_buffer()?;
+        candle_metal_kernels::call_quantized_matmul_t(
+            device.device(),
+            &command_buffer,
+            device.kernels(),
+            dtype.into(),
+            (b, m, n, k),
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            buffer,
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+        Ok((dst_storage, dst_shape))
+    }
+}
+
+#[cfg(feature = "metal")]
+impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
+    fn from(value: GgmlDType) -> Self {
+        match value {
+            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
+            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
+            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
+            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
+            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
+            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
+            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
+            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
+            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
+            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
+            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
+            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
+            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
+            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
+        }
+    }
 }

 impl crate::Module for QMatMul {
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index d31e77a7..a7811ca5 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,6 +1,7 @@
 use candle_core::{
     bail,
     quantized::{self, GgmlDType},
+    test_device,
     test_utils::to_vec2_round,
     Device, Module, Result, Tensor,
 };
@@ -14,16 +15,48 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
 const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
 const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;

-#[test]
-fn quantized_matmul() -> Result<()> {
-    let cpu = &Device::Cpu;
+fn test_matmul(
+    device: &Device,
+    (b, m, n, k): (usize, usize, usize, usize),
+    dtype: GgmlDType,
+) -> Result<()> {
+    let lhs = (0..(m * k))
+        .map(|v| v as f32 / (m * k) as f32)
+        .collect::<Vec<_>>();
+    let rhs = (0..(k * n))
+        .map(|v| v as f32 / (n * k) as f32)
+        .collect::<Vec<_>>();
+
+    let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
+    let mm = lhs.matmul(&rhs)?;
+    let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
+    let res = matmul.forward(&lhs)?;
+
+    let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
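+        // Mean relative error: |mm - res| / |mm|, summed over the b * m * n outputs
+        // and averaged below; the 0.02 bound mirrors GGML_MAX_DOT_PRODUCT_ERROR above.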
+        .sum_all()?
+        .to_scalar()?;
+    let error = error / (b * m * n) as f32;
+    assert!(
+        error <= 0.02,
+        "Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
+    );
+
+    Ok(())
+}
+
+fn quantized_matmul(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let (m, k, n) = (3, 64, 4);
     let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
     let mut dst = vec![42.; 3 * 4];
     let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
     k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
     k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
     assert_eq!(
@@ -33,6 +66,7 @@ fn quantized_matmul() -> Result<()> {
             341876.0, 994283.0, 1655709.0, 2301518.0
         ]
     );
+    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
     let mm = tensor_lhs.matmul(&tensor_rhs)?;
     assert_eq!(
         mm.to_vec2::<f32>()?,
@@ -43,35 +77,49 @@ fn quantized_matmul() -> Result<()> {
         ]
     );

-    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
-    assert_eq!(
-        to_vec2_round(&res, 0)?,
-        &[
-            [85120.0, 214562.0, 345455.0, 474748.0],
-            [213475.0, 604465.0, 1000686.0, 1388317.0],
-            [341876.0, 994283.0, 1655709.0, 2301518.0]
-        ]
-    );
+    match device {
+        Device::Metal(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [84946.0, 214126.0, 344757.0, 473798.0],
+                [213458.0, 604350.0, 1000469.0, 1387990.0],
+                [341970.0, 994574.0, 1656181.0, 2302182.0]
+            ]
+        ),
+        _ => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [85120.0, 214562.0, 345455.0, 474748.0],
+                [213475.0, 604465.0, 1000686.0, 1388317.0],
+                [341876.0, 994283.0, 1655709.0, 2301518.0]
+            ]
+        ),
+    }
+
+    test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;

     Ok(())
 }

-#[test]
-fn quantized_matmul_neg() -> Result<()> {
-    let cpu = &Device::Cpu;
+fn quantized_matmul_neg(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
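+    // The test_device! invocations further down generate cpu/cuda/metal variants of
+    // these functions; the cuda variants bail out early here until quantized cuda
+    // kernels are available.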
+ if device.is_cuda() { + return Ok(()); + } let (m, k, n) = (3, 64, 4); let lhs = (0..(m * k)) .map(|v| v as f32 - (m * k) as f32 / 2.0) .collect::>(); - let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; + let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?; let mut dst = vec![42.; 3 * 4]; let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; let rhs = (0..k * n) .map(|v| v as f32 - (k * n) as f32 / 3.0) .collect::>(); - let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; assert_eq!( @@ -91,32 +139,56 @@ fn quantized_matmul_neg() -> Result<()> { ] ); - let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; + let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; - assert_eq!( - to_vec2_round(&res, 0)?, - &[ - [243524.0, -19596.0, -285051.0, -549815.0], - [23777.0, 21651.0, 19398.0, 18367.0], - [-196472.0, 63012.0, 324585.0, 587902.0] - ] - ); + match device { + Device::Metal(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243666.0, -19714.0, -285433.0, -550453.0], + [23782.0, 21654.0, 19400.0, 18369.0], + [-196102.0, 63022.0, 324233.0, 587191.0] + ] + ), + _ => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243524.0, -19596.0, -285051.0, -549815.0], + [23777.0, 21651.0, 19398.0, 18367.0], + [-196472.0, 63012.0, 324585.0, 587902.0] + ] + ), + } Ok(()) } -#[test] -fn quantize_q4_0() -> Result<()> { - use k_quants::BlockQ4_0; +test_device!( + quantized_matmul, + quantized_matmul_cpu, + quantized_matmul_cuda, + quantized_matmul_metal +); +test_device!( + quantized_matmul_neg, + quantized_matmul_neg_cpu, + quantized_matmul_neg_cuda, + quantized_matmul_neg_metal +); +fn quantize_q4_0(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ4_0::zeros(); 4]; - BlockQ4_0::from_float(&src, &mut quant)?; - BlockQ4_0::to_float(&quant, dst.as_mut_slice())?; + + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; + let dst = quant.dequantize(device)?; assert_eq!( - dst, + dst.to_vec1::()?, &[ -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625, 11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25, @@ -132,21 +204,21 @@ fn quantize_q4_0() -> Result<()> { 127.0, 127.0 ] ); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q4_1() -> Result<()> { - use k_quants::BlockQ4_1; - +fn quantize_q4_1(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ4_1::zeros(); 4]; - BlockQ4_1::from_float(&src, &mut quant)?; - BlockQ4_1::to_float(&quant, dst.as_mut_slice())?; + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; + let dst = quant.dequantize(device)?; assert_eq!( - round_vector(&dst), + round_vector(&dst.to_vec1::()?), &[ 0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332, 12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73, @@ -162,21 +234,21 @@ fn quantize_q4_1() -> Result<()> { 118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996 ] ); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q5_0() -> Result<()> { - use k_quants::BlockQ5_0; - +fn quantize_q5_0(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. + if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ5_0::zeros(); 4]; - BlockQ5_0::from_float(&src, &mut quant)?; - BlockQ5_0::to_float(&quant, dst.as_mut_slice())?; + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; + let dst = quant.dequantize(device)?; assert_eq!( - round_vector(&dst), + round_vector(&dst.to_vec1::()?), &[ -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625, 11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313, @@ -192,21 +264,21 @@ fn quantize_q5_0() -> Result<()> { 119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0 ] ); - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -#[test] -fn quantize_q5_1() -> Result<()> { - use k_quants::BlockQ5_1; - +fn quantize_q5_1(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+ if device.is_cuda() { + return Ok(()); + } let src = (0..32 * 4).map(|v| v as f32).collect::>(); - let mut dst = vec![0f32; 32 * 4]; - let mut quant = vec![BlockQ5_1::zeros(); 4]; - BlockQ5_1::from_float(&src, &mut quant)?; - BlockQ5_1::to_float(&quant, dst.as_mut_slice())?; + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; + let dst = quant.dequantize(device)?; assert_eq!( - dst, + round_vector(&dst.to_vec1::()?), &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, @@ -220,13 +292,11 @@ fn quantize_q5_1() -> Result<()> { 124.0, 125.0, 126.0, 127.0 ] ); - - ggml_quantization_error_test::(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; Ok(()) } -/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps -fn get_test_vector(bound: f32, size: usize) -> (Vec, Vec) { +fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result { assert!( size % crate::quantized::k_quants::QK_K == 0, "size must be a multiple of {}", @@ -236,10 +306,8 @@ fn get_test_vector(bound: f32, size: usize) -> (Vec, Vec) { let src = (0..size) .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.)) .collect::>(); - - let dst = vec![0f32; size]; assert_eq!([src[0], src[size / 2]], [-bound, 0.0]); - (src, dst) + Tensor::from_vec(src, (size,), device) } /// Round a vector @@ -288,11 +356,12 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 { /// Similar to the GGML quantization unit test: /// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 -fn ggml_quantization_error_test(max_error: f32) -> Result<()> { +fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> { let src = create_ggml_like_vector(0.0); - let mut dst = vec![0.0; GGML_TEST_SIZE]; - let _quant = quantize_roundtrip::(src.as_slice(), dst.as_mut_slice())?; - let error = calculate_rmse(src.as_slice(), dst.as_slice()); + let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let error = calculate_rmse(&src.to_vec1::()?, &dst.to_vec1::()?); if error > max_error { bail!( "Quantization error {} exceeds max error {}", @@ -303,19 +372,19 @@ fn ggml_quantization_error_test(max_error: f32) -> Result<()> { Ok(()) } -fn quantize_roundtrip(src: &[f32], dst: &mut [f32]) -> Result> { - let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE]; - T::from_float(src, &mut quant)?; - T::to_float(&quant, dst)?; - Ok(quant) -} +fn quantize_q2k(device: &Device) -> Result<()> { + // TODO Enable this later when we enable cuda. 
+fn quantize_q2k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q2K;
 
-#[test]
-fn quantize_q2k() -> Result<()> {
-    use k_quants::BlockQ2K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
 
-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
 
     // Test some specific values
@@ -329,20 +398,30 @@ fn quantize_q2k() -> Result<()> {
         [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
     );
 
-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
 
-    ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
     Ok(())
 }
 
-#[test]
-fn quantize_q3k() -> Result<()> {
-    use k_quants::BlockQ3K;
+fn quantize_q3k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q3K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
 
-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
 
     // Test some specific values
@@ -356,20 +435,30 @@ fn quantize_q3k() -> Result<()> {
         [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
     );
 
-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
 
-    ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
     Ok(())
 }
 
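compare_with_error is called throughout these tests but its definition sits outside the visible context. A hypothetical stand-in matching its call sites (dequantized values, reference values, and a per-element tolerance):

    // Hypothetical helper: fail on the first dequantized element that strays
    // further than `tolerance` from the source value.
    fn compare_with_error(got: &[f32], expected: &[f32], tolerance: f32) {
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            let diff = (g - e).abs();
            assert!(
                diff < tolerance,
                "value mismatch at index {i}: got {g}, expected {e} (diff {diff})"
            );
        }
    }

Note how the tolerances scale with the format: roughly 0.1 for 2-bit Q2K down to 0.008 for 6-bit Q6K on the narrow ramp, and proportionally larger on the 128.0-bound one.
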
-#[test]
-fn quantize_q4k() -> Result<()> {
-    use k_quants::BlockQ4K;
+fn quantize_q4k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q4K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
 
-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
 
     // Test some specific values
@@ -383,21 +472,31 @@ fn quantize_q4k() -> Result<()> {
         [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
     );
 
-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
 
-    ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }
 
-#[test]
-fn quantize_q5k() -> Result<()> {
-    use k_quants::BlockQ5K;
+fn quantize_q5k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q5K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
 
-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
-    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
 
     // Test some specific values
     assert_eq!(
@@ -410,21 +509,30 @@ fn quantize_q5k() -> Result<()> {
         [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
     );
 
-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
 
-    ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
-
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }
 
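The test_device! invocations further down fan each of these functions out into per-device #[test] wrappers. Roughly, the macro (defined in candle's test utilities; an illustrative sketch, not its exact definition) expands along these lines:

    use candle_core::{Device, Result};

    macro_rules! test_device {
        ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => {
            #[test]
            fn $test_cpu() -> Result<()> {
                $fn_name(&Device::Cpu)
            }
            // The cuda/metal variants follow the same shape behind
            // #[cfg(feature = "cuda")] / #[cfg(feature = "metal")] gates.
        };
    }
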
-#[test]
-fn quantize_q6k() -> Result<()> {
-    use k_quants::BlockQ6K;
+fn quantize_q6k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q6K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
 
-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
     compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
 
     // Test some specific values
@@ -438,22 +546,31 @@ fn quantize_q6k() -> Result<()> {
         [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
     );
 
-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
 
-    ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
-
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }
 
-#[test]
-fn quantize_q8k() -> Result<()> {
-    use k_quants::BlockQ8K;
+fn quantize_q8k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
+    let dtype = GgmlDType::Q8K;
+    let src = get_test_vector2(0.5, 1024, device)?;
+    let quant = quantized::QTensor::quantize(&src, dtype)?;
+    let dst = quant.dequantize(device)?;
 
-    let (src, mut dst) = get_test_vector(0.5, 1024);
-    let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
-    compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
+    let src = src.to_vec1::<f32>()?;
+    let dst = dst.to_vec1::<f32>()?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
 
     // Test some specific values
     assert_eq!(
@@ -466,15 +583,79 @@ fn quantize_q8k() -> Result<()> {
         [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
     );
 
-    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
-    let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+    let src_big = get_test_vector2(128.0, 1024, device)?;
+    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+    let dst_big = quant_big.dequantize(device)?;
+
+    let src_big = src_big.to_vec1::<f32>()?;
+    let dst_big = dst_big.to_vec1::<f32>()?;
     compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
 
-    ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
-
+    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
     Ok(())
 }
 
+test_device!(
+    quantize_q4_0,
+    quantize_q4_0_cpu,
+    quantize_q4_0_cuda,
+    quantize_q4_0_metal
+);
+test_device!(
+    quantize_q4_1,
+    quantize_q4_1_cpu,
+    quantize_q4_1_cuda,
+    quantize_q4_1_metal
+);
+test_device!(
+    quantize_q5_0,
+    quantize_q5_0_cpu,
+    quantize_q5_0_cuda,
+    quantize_q5_0_metal
+);
+test_device!(
+    quantize_q5_1,
+    quantize_q5_1_cpu,
+    quantize_q5_1_cuda,
+    quantize_q5_1_metal
+);
+test_device!(
+    quantize_q2k,
+    quantize_q2k_cpu,
+    quantize_q2k_cuda,
+    quantize_q2k_metal
+);
+test_device!(
+    quantize_q3k,
+    quantize_q3k_cpu,
+    quantize_q3k_cuda,
+    quantize_q3k_metal
+);
+test_device!(
+    quantize_q4k,
+    quantize_q4k_cpu,
+    quantize_q4k_cuda,
+    quantize_q4k_metal
+);
+test_device!(
+    quantize_q5k,
+    quantize_q5k_cpu,
+    quantize_q5k_cuda,
+    quantize_q5k_metal
+);
+test_device!(
+    quantize_q6k,
+    quantize_q6k_cpu,
+    quantize_q6k_cuda,
+    quantize_q6k_metal
+);
+test_device!(
+    quantize_q8k,
+    quantize_q8k_cpu,
+    quantize_q8k_cuda,
+    quantize_q8k_metal
+);
+
 /// Very simple dot product implementation
 fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
     a.iter().zip(b).map(|(a, b)| a * b).sum()
@@ -591,6 +772,112 @@ fn get_random_tensors(
     Ok((lhs, rhs, mm))
 }
 
+#[macro_export]
+macro_rules! quantized_matmul {
+    // TODO: Switch to generating the two last arguments automatically once concat_idents is
+    // stable. https://github.com/rust-lang/rust/issues/29599
+    ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
+        fn $fn_name(device: &Device) -> Result<()> {
+            if device.is_cuda() {
+                // TODO Enable Cuda GGML sometime maybe.
+                return Ok(());
+            }
+            test_matmul(device, (1, 3, 4, 256), $dtype)?;
+            Ok(())
+        }
+
+        test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
+    };
+}
+
+quantized_matmul!(
+    quantized_matmul_q4_0_bis,
+    quantized_matmul_q4_0_cpu,
+    quantized_matmul_q4_0_cuda,
+    quantized_matmul_q4_0_metal,
+    GgmlDType::Q4_0
+);
+quantized_matmul!(
+    quantized_matmul_q4_1_bis,
+    quantized_matmul_q4_1_cpu,
+    quantized_matmul_q4_1_cuda,
+    quantized_matmul_q4_1_metal,
+    GgmlDType::Q4_1
+);
+quantized_matmul!(
+    quantized_matmul_q5_0_bis,
+    quantized_matmul_q5_0_cpu,
+    quantized_matmul_q5_0_cuda,
+    quantized_matmul_q5_0_metal,
+    GgmlDType::Q5_0
+);
+quantized_matmul!(
+    quantized_matmul_q5_1_bis,
+    quantized_matmul_q5_1_cpu,
+    quantized_matmul_q5_1_cuda,
+    quantized_matmul_q5_1_metal,
+    GgmlDType::Q5_1
+);
+quantized_matmul!(
+    quantized_matmul_q8_0_bis,
+    quantized_matmul_q8_0_cpu,
+    quantized_matmul_q8_0_cuda,
+    quantized_matmul_q8_0_metal,
+    GgmlDType::Q8_0
+);
+// Not implemented in Ggml
+// quantized_matmul!(
+//     quantized_matmul_q8_1_bis,
+//     quantized_matmul_q8_1_cpu,
+//     quantized_matmul_q8_1_cuda,
+//     quantized_matmul_q8_1_metal,
+//     GgmlDType::Q8_1
+// );
+// TODO This is bugged (also bugged in GGML)
+quantized_matmul!(
+    quantized_matmul_q2k_bis,
+    quantized_matmul_q2k_cpu,
+    quantized_matmul_q2k_cuda,
+    quantized_matmul_q2k_metal,
+    GgmlDType::Q2K
+);
+quantized_matmul!(
+    quantized_matmul_q3k_bis,
+    quantized_matmul_q3k_cpu,
+    quantized_matmul_q3k_cuda,
+    quantized_matmul_q3k_metal,
+    GgmlDType::Q3K
+);
+quantized_matmul!(
+    quantized_matmul_q4k_bis,
+    quantized_matmul_q4k_cpu,
+    quantized_matmul_q4k_cuda,
+    quantized_matmul_q4k_metal,
+    GgmlDType::Q4K
+);
+quantized_matmul!(
+    quantized_matmul_q5k_bis,
+    quantized_matmul_q5k_cpu,
+    quantized_matmul_q5k_cuda,
+    quantized_matmul_q5k_metal,
+    GgmlDType::Q5K
+);
+quantized_matmul!(
+    quantized_matmul_q6k_bis,
+    quantized_matmul_q6k_cpu,
+    quantized_matmul_q6k_cuda,
+    quantized_matmul_q6k_metal,
+    GgmlDType::Q6K
+);
+// Not implemented on metal
+// quantized_matmul!(
+//     quantized_matmul_q8k_bis,
+//     quantized_matmul_q8k_cpu,
+//     quantized_matmul_q8k_cuda,
+//     quantized_matmul_q8k_metal,
+//     GgmlDType::Q8K
+// );
+
 #[test]
 fn quantized_matmul_q2k() -> Result<()> {
     use k_quants::BlockQ2K;
@@ -603,7 +890,7 @@ fn quantized_matmul_q2k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
-    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
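The matmul tests on either side of this point drive the quantized path through QMatMul. A minimal sketch of that flow with the post-patch API (again assuming the public candle_core crate name; shapes are arbitrary as long as the inner dimension is a multiple of the block size):

    use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
    use candle_core::{DType, Device, Result, Tensor};

    fn qmatmul_demo() -> Result<Tensor> {
        let device = Device::Cpu;
        let lhs = Tensor::zeros((3, 256), DType::F32, &device)?;
        // QMatMul multiplies by the transposed weights, hence (n, k) = (512, 256).
        let rhs = Tensor::zeros((512, 256), DType::F32, &device)?;
        let rhs = QTensor::quantize(&rhs, GgmlDType::Q4_0)?;
        let rhs = QMatMul::from_qtensor(rhs)?;
        rhs.forward(&lhs) // -> a (3, 512) result
    }
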
@@ -629,7 +916,7 @@ fn quantized_matmul_q3k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
-    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
@@ -655,7 +942,7 @@ fn quantized_matmul_q4k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
-    let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
@@ -681,7 +968,7 @@ fn quantized_matmul_q5k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
-    let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
@@ -708,7 +995,7 @@ fn quantized_matmul_q6k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
-    let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
 
@@ -733,7 +1020,7 @@ fn quantized_matmul_q8k() -> Result<()> {
     let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
     assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
 
-    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
     let rhs = quantized::QMatMul::from_qtensor(rhs)?;
     let mm = rhs.forward(&lhs)?;
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index a1051a8e..15e36476 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
 
     let config = blip::Config::image_captioning_large();
 
+    let device = candle_examples::device(args.cpu)?;
     let (image_embeds, device, mut model) = if args.quantized {
         let device = Device::Cpu;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af..9d42dcc8 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
.shape() @@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { (config.seq_len, config.head_size() / 2), "rot.freq_cis_real", )? - .dequantize(&candle::Device::Cpu)?; + .dequantize(&device)?; let freq_cis_imag = vb .get( (config.seq_len, config.head_size() / 2), "rot.freq_cis_imag", )? - .dequantize(&candle::Device::Cpu)?; + .dequantize(&device)?; let fake_vb = candle_nn::VarBuilder::from_tensors( [ @@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .into_iter() .collect(), candle::DType::F32, - &candle::Device::Cpu, + &device, ); let cache = model::Cache::new(true, &config, fake_vb)?; let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 5ed5e5cb..bad86098 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -244,13 +244,14 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config = Config::config_7b_v0_1(args.use_flash_attn); + let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized { let filename = &filenames[0]; - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; let model = QMistral::new(&config, vb)?; - (Model::Quantized(model), Device::Cpu) + (Model::Quantized(model), device) } else { - let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 } else { diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 69eed84f..39f4fd58 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -307,18 +307,21 @@ fn main() -> Result<()> { WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; - let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; + let device = candle_examples::device(args.cpu)?; + let model = if args.quantized { let config = config(); + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &filenames[0], + &device, + )?; let model = match args.model { WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; - (Model::Quantized(model), Device::Cpu) + Model::Quantized(model) } else { - let device = candle_examples::device(args.cpu)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; - let model = match args.model { + match args.model { WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { let config_filename = repo.get("config.json")?; let config = std::fs::read_to_string(config_filename)?; @@ -334,8 +337,7 @@ fn main() -> Result<()> { let config = config(); Model::MixFormer(MixFormer::new(&config, vb)?) 
            }
-        };
-        (model, device)
+        }
     };
 
     println!("loaded the model in {:?}", start.elapsed());
 
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 0ea2e0bd..ed3f1030 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -132,7 +132,8 @@ impl T5ModelBuilder {
     }
 
     pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+        let device = Device::Cpu;
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
         Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }
 
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index bfc6de53..34c44233 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
 use candle_transformers::generation::LogitsProcessor;
 use candle_examples::token_output_stream::TokenOutputStream;
 
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
+    let device = candle_examples::device(false)?;
 
     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
@@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file)?
+            ModelWeights::from_gguf(model, &mut file, &device)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+            let model = ggml_file::Content::read(&mut file, &device)
+                .map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
 
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
         logits_processor.sample(&logits)?
@@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
     let start_post_prompt = std::time::Instant::now();
     let mut sampled = 0;
     for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1.
{ diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs index 0f72b862..b7f767b9 100644 --- a/candle-examples/examples/replit-code/main.rs +++ b/candle-examples/examples/replit-code/main.rs @@ -236,16 +236,15 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; let config = Config::replit_code_v1_5_3b(); - let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; - let model = Model::Q(Q::new(&config, vb.pp("transformer"))?); - (model, Device::Cpu) + let model = if args.quantized { + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?; + Model::Q(Q::new(&config, vb.pp("transformer"))?) } else { - let device = candle_examples::device(args.cpu)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; - let model = Model::M(M::new(&config, vb.pp("transformer"))?); - (model, device) + Model::M(M::new(&config, vb.pp("transformer"))?) }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs index 0535aa70..ccd924a4 100644 --- a/candle-examples/examples/stable-lm/main.rs +++ b/candle-examples/examples/stable-lm/main.rs @@ -234,13 +234,14 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config = Config::stablelm_3b_4e1t(args.use_flash_attn); + let device = candle_examples::device(args.cpu)?; let (model, device) = if args.quantized { let filename = &filenames[0]; - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?; + let vb = + candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?; let model = QStableLM::new(&config, vb)?; (Model::Quantized(model), Device::Cpu) } else { - let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 } else { diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 5be81f2d..6ea34613 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -557,8 +557,10 @@ fn main() -> Result<()> { println!("loaded mel: {:?}", mel.dims()); let mut model = if args.quantized { - let vb = - candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &weights_filename, + &device, + )?; Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) } else { let vb = diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c872dc60..201af97e 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -15,6 +15,7 @@ const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); const CONV: &str = include_str!("conv.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); +const QUANTIZED: &str = include_str!("quantized.metal"); /// Most kernels apply similarly across the tensors /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the @@ -62,6 +63,8 @@ macro_rules! 
primitive {
     };
 }
 primitive!(usize);
+primitive!(i64);
+primitive!(i32);
 primitive!(u32);
 primitive!(f32);
 
@@ -117,6 +120,7 @@ pub enum Source {
     Reduce,
     Mfa,
     Conv,
+    Quantized,
 }
 
 macro_rules! ops{
@@ -215,17 +219,15 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
 pub struct Kernels {
     libraries: RwLock<Libraries>,
     pipelines: RwLock<Pipelines>,
-    fence: metal::Fence,
 }
 
 impl Kernels {
-    pub fn new(fence: metal::Fence) -> Self {
+    pub fn new() -> Self {
         let libraries = RwLock::new(Libraries::new());
         let pipelines = RwLock::new(Pipelines::new());
         Self {
             libraries,
             pipelines,
-            fence,
         }
     }
 
@@ -239,6 +241,7 @@ impl Kernels {
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
             Source::Conv => CONV,
+            Source::Quantized => QUANTIZED,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -345,7 +348,6 @@ pub fn call_unary_contiguous(
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, input, output));
@@ -354,7 +356,6 @@ pub fn call_unary_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -376,7 +377,6 @@ pub fn call_unary_strided(
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -398,7 +398,6 @@ pub fn call_unary_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -417,7 +416,6 @@ pub fn call_binary_contiguous(
     let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(encoder, (length, left, right, output));
@@ -428,7 +426,6 @@ pub fn call_binary_contiguous(
     encoder.use_resource(right, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -453,7 +450,6 @@ pub fn call_binary_strided(
     let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
     let width: usize = shape.iter().product();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     let length: usize = shape.iter().product();
@@ -478,7 +474,6 @@ pub fn call_binary_strided(
     encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -497,7 +492,6 @@ pub fn call_cast_contiguous(
     let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, (input, input_offset), output)); @@ -506,7 +500,6 @@ pub fn call_cast_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -526,7 +519,6 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -548,7 +540,6 @@ pub fn call_cast_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -568,7 +559,6 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -597,7 +587,6 @@ pub fn call_reduce_contiguous( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -619,7 +608,6 @@ pub fn call_reduce_strided( let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -655,7 +643,6 @@ pub fn call_reduce_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -674,7 +661,6 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -705,7 +691,6 @@ pub fn call_last_softmax( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -725,7 +710,6 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, input, output)); @@ -734,7 +718,6 @@ pub fn call_affine( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -757,7 +740,6 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = 
command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -778,7 +760,6 @@ pub fn call_affine_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -797,7 +778,6 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -806,7 +786,6 @@ pub fn call_powf( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -828,7 +807,6 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -848,7 +826,6 @@ pub fn call_powf_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -867,7 +844,6 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, input, output)); @@ -876,7 +852,6 @@ pub fn call_elu( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -898,7 +873,6 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -918,7 +892,6 @@ pub fn call_elu_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -940,7 +913,6 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -969,7 +941,6 @@ pub fn call_where_cond_strided( encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -996,7 +967,6 @@ pub fn call_index_select( let encoder = command_buffer.new_compute_command_encoder(); - 
encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1019,7 +989,6 @@ pub fn call_index_select( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1048,7 +1017,6 @@ pub fn call_gather( let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1071,7 +1039,6 @@ pub fn call_gather( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1100,7 +1067,6 @@ pub fn call_scatter_add( let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1123,7 +1089,6 @@ pub fn call_scatter_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1153,7 +1118,6 @@ pub fn call_index_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1177,7 +1141,6 @@ pub fn call_index_add( encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -1381,7 +1344,6 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); @@ -1421,12 +1383,10 @@ pub fn call_gemm( height: 1, depth: 1, }; - // println!("grid size {grid_size:?} group size {group_size:?}"); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1451,7 +1411,6 @@ pub fn call_im2col1d_strided( let encoder = command_buffer.new_compute_command_encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1471,7 +1430,6 @@ pub fn call_im2col1d_strided( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) @@ -1501,7 +1459,6 @@ pub fn call_im2col_strided( let encoder = 
command_buffer.new_compute_command_encoder();
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1523,7 +1480,6 @@ pub fn call_im2col_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
 
     Ok(())
@@ -1549,7 +1505,6 @@ pub fn call_upsample_nearest_2d(
     let scale_h = shape[3] as f32 / out_h as f32;
     let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
     set_params!(
         encoder,
@@ -1567,7 +1522,176 @@ pub fn call_upsample_nearest_2d(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum GgmlDType {
+    Q4_0,
+    Q4_1,
+    Q5_0,
+    Q5_1,
+    Q8_0,
+    Q8_1,
+    Q2K,
+    Q3K,
+    Q4K,
+    Q5K,
+    Q6K,
+    Q8K,
+    F16,
+    F32,
+}
+
+pub fn call_quantized_matmul_t(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    dtype: GgmlDType,
+    (b, m, n, k): (usize, usize, usize, usize),
+    lhs: &Buffer,
+    lhs_offset: usize,
+    rhs: &Buffer,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    // Everything is in reverse
+    let ne00 = k as i64;
+    let ne01 = n as i64;
+    let ne02 = b as i64;
+    let ne03 = 1 as i64;
+
+    let nb00 = 0i64;
+    let nb01 = 0 as i64;
+    let nb02 = 0 as i64;
+
+    let ne10 = k as i64;
+    let ne11 = m as i64;
+    let ne12 = b as i64;
+    let ne13 = 1 as i64;
+
+    let nb10 = 0i64;
+    let nb11 = 0i64;
+    let nb12 = 0i64;
+
+    let ne0 = n as i64;
+    let ne1 = m as i64;
+    let r2: u32 = (ne12 / ne02) as u32;
+    let r3: u32 = (ne13 / ne03) as u32;
+
+    let (nth0, nth1, align) = match dtype {
+        GgmlDType::Q4_0
+        | GgmlDType::Q4_1
+        | GgmlDType::Q5_0
+        | GgmlDType::Q5_1
+        | GgmlDType::Q8_0
+        | GgmlDType::Q8_1 => {
+            let nth0 = 8;
+            let nth1 = 8;
+            let align = 8;
+            (nth0, nth1, align)
+        }
+        GgmlDType::Q2K => {
+            // Fixing a bug in Metal for GGML
+            let nth0 = 4;
+            let nth1 = 8;
+            let align = 4;
+            (nth0, nth1, align)
+        }
+        GgmlDType::Q4K => {
+            let nth0 = 4;
+            let nth1 = 8;
+            let align = 4;
+            (nth0, nth1, align)
+        }
+        GgmlDType::Q3K | GgmlDType::Q5K => {
+            let nth0 = 2;
+            let nth1 = 32;
+            let align = 4;
+            (nth0, nth1, align)
+        }
+        GgmlDType::Q6K => {
+            let nth0 = 2;
+            let nth1 = 32;
+            let align = 2;
+            (nth0, nth1, align)
+        }
+        GgmlDType::F16 | GgmlDType::Q8K => {
+            // Original implementation uses rows
+            let nth0 = 32;
+            let nth1 = 1;
+            let align = 8;
+            (nth0, nth1, align)
+        }
+        GgmlDType::F32 => {
+            let nth0 = 32;
+            let nth1 = 1;
+            let align = 8;
+            (nth0, nth1, align)
+        }
+    };
+    let thread_groups_count = MTLSize {
+        width: divide(ne01 as usize, align),
+        height: ne11 as u64,
+        depth: (ne12 * ne13) as u64,
+    };
+    let threads_per_threadgroup = MTLSize {
+        width: nth0,
+        height: nth1,
+        depth: 1,
+    };
+    let name = match dtype {
+        GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
+        GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
+        GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
+        GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
+        GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
+        GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
+        GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
+        GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
+        GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
+        GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
+        GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
+        GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
+        GgmlDType::F16 => "kernel_mul_mv_f16_f32",
+        GgmlDType::F32 => "kernel_mul_mv_f32_f32",
+    };
+
+    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            rhs,
+            (lhs, lhs_offset),
+            output,
+            ne00,
+            ne01,
+            ne02,
+            nb00,
+            nb01,
+            nb02,
+            ne10,
+            ne11,
+            ne12,
+            nb10,
+            nb11,
+            nb12,
+            ne0,
+            ne1,
+            r2,
+            r3
+        )
+    );
+    encoder.set_threadgroup_memory_length(0, 8192);
+    encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
+    encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+
+    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
     encoder.end_encoding();
 
     Ok(())
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
new file mode 100644
index 00000000..9aa7b502
--- /dev/null
+++ b/candle-metal-kernels/src/quantized.metal
@@ -0,0 +1,5107 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define QK4_0 32
+#define QR4_0 2
+typedef struct {
+    half d;                // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
+#define QK4_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+
+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
+#define QK8_0 32
+typedef struct {
+    half d;           // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+    GGML_SORT_ASC,
+    GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_mul( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_div( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float 
*)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_mul_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! 
+ // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +kernel void kernel_soft_max( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +kernel void kernel_soft_max_4( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
pmask[i00] : 0.0f)) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + // MEAN + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float mean = sum[0] / ne00; + + // recenter and VARIANCE + threadgroup_barrier(mem_flags::mem_threadgroup); + device float * y = dst + tgpig*ne00; + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; + sum[tpitg] += y[i00] * y[i00]; + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float variance = sum[0] / ne00; + + const float scale = 1.0f/sqrt(variance + eps); + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + 
uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+
+    float4 sumf = 0;
+    float all_sum = 0;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
+// function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]);
+}
+
+// function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// function for calculating the inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculating the inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// putting these defines in the kernel causes a significant performance penalty
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
+// Note: This is a template, but strictly speaking it only applies to
+//       quantizations where the block size is 32. It also does not
+//       guard against the number of rows not being divisible by
+//       N_DST, so this is another explicit assumption of the implementation.
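+//
+// Worked example of the distribution below (the numbers follow from the defaults
+// above, nothing new is assumed): with N_SIMDGROUP = 2 and N_DST = 4, one
+// threadgroup produces 2*4 = 8 consecutive rows of dst. Within a 32-wide SIMD
+// group, thread t handles the half-block il = (t%2)*8 of block ix = t/2, so a
+// single pass over the inner loop covers nw/2 = 16 blocks, which is why `ib`
+// advances in steps of nw/2.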
+template <typename block_q_type, int nr, int nsg, int nw>
+void mul_vec_q_n_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint r2,
+        uint r3,
+        uint3 tgpig, uint tiisg, uint sgitg) {
+    const int nb = ne00/QK4_0;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
+
+    device const float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        float sumy = 0;
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
+        }
+
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+        }
+
+        yb += QK4_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
+        }
+    }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+
+#define NB_Q8_0 8
+
+void kernel_mul_mv_q8_0_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[NB_Q8_0];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
+
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
+
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += NB_Q8_0 * nw;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        
constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + +#define N_F32_F32 4 + +void kernel_mul_mv_f32_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F32_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const float * x = (device const float *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const float4 * x4 = (device const float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_f32_f32")]] +kernel void kernel_mul_mv_f32_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F16 4 + +kernel void kernel_mul_mv_f16_f16( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, 
+ constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F16; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); + device const half4 * y4 = (device const half4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +void kernel_mul_mv_f16_f32_1row_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_f16_f32_1row")]] +kernel void kernel_mul_mv_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + 
constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F32 4 + +void kernel_mul_mv_f16_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F32; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +[[host_name("kernel_mul_mv_f16_f32")]] +kernel void kernel_mul_mv_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mv_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + 
constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half4 * x4 = (device const half4 *) (src0 + offset0); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +kernel void kernel_alibi_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & m0, + constant float & m1, + constant int & n_heads_log2_floor, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + const int64_t k = i3*ne3 + i2; + + float m_k; + if (k < n_heads_log2_floor) { + m_k = pow(m0, k + 1); + } else { + m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; + device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + const float src_v = *(device float *)(src_row + i00*nb00); + device float * dst_v = (device float *)(dst_row + i00*nb0); + *dst_v = i00 * m_k + src_v; + } +} + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
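+//
+// A restatement of what rope_yarn below computes (no new behaviour, just a
+// summary for readers of the diff): with
+//   ramp(i0)  = 1 - clamp((i0/2 - low) / (high - low), 0, 1)
+//   ramp_mix  = ramp(i0) * ext_factor
+// each dimension pair uses the blended angle
+//   theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix
+// and, when ext_factor != 0, the magnitude correction
+//   mscale *= 1 + 0.1 * ln(1 / freq_scale).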
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,          floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f,  ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
+typedef void (rope_t)(
+        device const     void * src0,
+        device const  int32_t * src1,
+        device          float * dst,
+        constant      int64_t & ne00,
+        constant      int64_t & ne01,
+        constant      int64_t & ne02,
+        constant      int64_t & ne03,
+        constant     uint64_t & nb00,
+        constant     uint64_t & nb01,
+        constant     uint64_t & nb02,
+        constant     uint64_t & nb03,
+        constant      int64_t & ne0,
+        constant      int64_t & ne1,
+        constant      int64_t & ne2,
+        constant      int64_t & ne3,
+        constant     uint64_t & nb0,
+        constant     uint64_t & nb1,
+        constant     uint64_t & nb2,
+        constant     uint64_t & nb3,
+        constant          int & n_past,
+        constant          int & n_dims,
+        constant          int & mode,
+        constant          int & n_orig_ctx,
+        constant        float & freq_base,
+        constant        float & freq_scale,
+        constant        float & ext_factor,
+        constant        float & attn_factor,
+        constant        float & beta_fast,
+        constant        float & beta_slow,
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]);
+
+template <typename T>
+kernel void kernel_rope(
+        device const     void * src0,
+        device const  int32_t * src1,
+        device          float * dst,
+        constant      int64_t & ne00,
+        constant      int64_t & ne01,
+        constant      int64_t & ne02,
+        constant      int64_t & ne03,
+        constant     uint64_t & nb00,
+        constant     uint64_t & nb01,
+        constant     uint64_t & nb02,
+        constant     uint64_t & nb03,
+        constant      int64_t & ne0,
+        constant      int64_t & ne1,
+        constant      int64_t & ne2,
+        constant      int64_t & ne3,
+        constant     uint64_t & nb0,
+        constant     uint64_t & nb1,
+        constant     uint64_t & nb2,
+        constant     uint64_t & nb3,
+        constant          int & n_past,
+        constant          int & n_dims,
+        constant          int & mode,
+        constant          int & n_orig_ctx,
+        constant        float & freq_base,
+        constant        float & freq_scale,
+        constant        float & ext_factor,
+        constant        float & attn_factor,
+        constant        float & beta_fast,
+        constant        float & beta_slow,
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    const bool is_neox = mode & 2;
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const int64_t p = pos[i2];
+
+    const float theta_0 = (float)p;
+    const float inv_ndims = -1.f/n_dims;
+
+    if (!is_neox) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const T x0 = src[0];
+            const T x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+            if (ic < n_dims) {
+                const int64_t ib = 0;
+
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            } else {
+                const int64_t i0 = ic;
+
+                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                dst_data[0] = src[0];
+                dst_data[1] = src[1];
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device        half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+kernel void kernel_upscale_f32(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int32_t & sf,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as a reference
+typedef void (argsort_t)(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template <ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols) return;
+
+    device const float   * x_row   = x   + row * ncols;
+    device       int32_t * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f16_f32(
+        device const half * src0,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        device const float * src0,
+        device        half * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        
constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_q8_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + + device 
block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + +kernel void kernel_cpy_f32_q4_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 
1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q4_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i02 < ne02) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; + src0_ptr += ntg.x*nb00; + } else { + ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; + src1_ptr += ntg.x*nb10; + } + dst_ptr += ntg.x*nb0; + } +} + +//============================================ k-quants ====================================================== + +#ifndef QK_K +#define QK_K 256 +#else +static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); +#endif + +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif + +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_K; +// 84 bytes / block + +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#if QK_K == 64 + uint8_t scales[2]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; + +#if QK_K == 64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_K; +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +#endif + +#if QK_K == 64 +typedef struct { + half d; // super-block scales/mins + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +#else +typedef 
struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +// 176 bytes / block +#endif + +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; +// 210 bytes / block + +//====================================== dot products ========================= + +void kernel_mul_mv_q2_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q2_K) * nb; + +#if QK_K == 256 + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; + } +#else + const int ix = tiisg/2; // 
0...15 + const int it = tiisg%2; // 0...1 + + device const float * y4 = y + ix * QK_K + 8 * it; + + for (int ib = ix; ib < nb; ib += 16) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 16 * QK_K; + } +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device 
const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
+
+    const int shift = 2*il;
+    const float v1 = il == 0 ? 4.f : 64.f;
+    const float v2 = 4.f * v1;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + il;
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    const int step = sizeof(block_q3_K) * nb / 2;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
+
+        for (int l = 0; l < 8; ++l) {
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
+        }
+
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
+
+        for (int row = 0; row < 2; ++row) {
+
+            const float d_all = (float)dh[0];
+
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+            }
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
+
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 
0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += step; + h += step; + a += step; + dh += step; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + if (tiisg == 0) { + for (int row = 0; row < 2; ++row) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + } + } +} +#else +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const int row = 2 * r0 + sgitg; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 + const int iq = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + float2 sum = {0.f, 0.f}; + + for (int i = ix; i < nb; i += 8) { + + const float d_all = (float)(x[i].d); + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const uint16_t * s = (device const uint16_t *)(x[i].scales); + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); + const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; + const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; + const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; + + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> iq; + sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) + + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) + + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) + + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) + + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 
0 : 65536)); + } + + } + const float sumf = sum[0] + sum[1] * 1.f/256.f; + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } + +} +#endif + +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] 
+= yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#else +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int ix = tiisg/4; // 0...7 + const int it = tiisg%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[8]; + float yh[8]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 8 * it; + + uint16_t sc16[4]; + + for (int ib = ix; ib < nb; ib += 8) { + + float2 sumy = {0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i] = y4[i+ 0]; sumy[0] += yl[i]; + yh[i] = y4[i+32]; sumy[1] += yh[i]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & 0x000f; + sc16[1] = sc[0] & 0x0f00; + sc16[2] = sc[0] & 0x00f0; + sc16[3] = sc[0] & 0xf000; + + float2 acc1 = {0.f, 0.f}; + float2 acc2 = {0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); + acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); + acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); + acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + + (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - + dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); + + qs += step; + sc += step; + dh += step; + } + + y4 += 8 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#endif + +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + 
constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q5_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf[2]={0.f}; + + const int step = sizeof(block_q5_K) * nb; + +#if QK_K == 256 +# + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int iq = tid/4; + const int ir = tid%4; + const int n = 8; + + const int l0 = n*ir; + const int q_offset = 32*iq + l0; + const int y_offset = 64*iq + l0; + + const uint8_t hm1 = 1u << (2*iq); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? 
yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; + + } + + y1 += 4 * QK_K; + + } +#else + float yl[8], yh[8]; + + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 + const int ix = tiisg%8; + const int iq = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + device const float * y = yy + ix*QK_K + il; + + for (int i = ix; i < nb; i += 8) { + + for (int l = 0; l < 4; ++l) { + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; + } + + device const half * dh = &x[i].d; + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const int8_t * s = x[i].scales; + + for (int row = 0; row < 2; ++row) { + + const float d = dh[0]; + + float2 acc = {0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> iq; + acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); + acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) + + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); + } + sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); + + q += step; + h += step; + s += step; + dh += step/2; + + } + + y += 8 * QK_K; + } +#endif + + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q6_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int row = 2 * r0 + sgitg; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q6_K * x = (device 
const block_q6_K *) src0 + row * nb + offset0;
+    device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float sumf = 0;
+
+#if QK_K == 256
+    const int tid = tiisg/2;
+    const int ix  = tiisg%2;
+    const int ip  = tid/8;  // 0 or 1
+    const int il  = tid%8;
+    const int n   = 4;
+    const int l0  = n*il;
+    const int is  = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
+
+    for (int i = ix; i < nb; i += 2) {
+
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        const float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+#else
+    const int ix = tiisg/4;
+    const int il = 4*(tiisg%4);
+
+    for (int i = ix; i < nb; i += 8) {
+        device const float * y = yy + i * QK_K + il;
+        device const uint8_t * ql = x[i].ql + il;
+        device const uint8_t * qh = x[i].qh + il;
+        device const int8_t  * s  = x[i].scales;
+
+        const float d = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 4; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
+            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+    }
+
+#endif
+
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+    }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+//============================= templates and their specializations =============================
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    half4x4 temp = *(((device half4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
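+// q4_0/q4_1 pack two 4-bit values per byte; reading qs as uint16 lets each iteration
+// produce two outputs, with d2 = d1/256 folding the high-byte shift into the scale
+// instead of an explicit >> 8.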
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m  = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i=0;i<8;i++) {
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i = 0; i < 16; i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
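+// The k-quant dequantizers below share a pattern: il selects which 16-weight chunk of
+// the super-block to unpack (callers such as kernel_get_rows pass ind % nl), and the
+// coef/mask ternaries pick the bit-field inside each packed byte for that chunk.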
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const float d = xb->d;
+    const float min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    float dl, ml;
+    uint8_t sc = xb->scales[il];
+
+#if QK_K == 256
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
+#endif
+    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
+
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
+
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+    }
+#else
+    float    kcoef = il&1 ? 1.f/16.f : 1.f;
+    uint16_t kmask = il&1 ? 0xF0     : 0x0F;
+    float    dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
+    float    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uint8_t  mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    uint8_t  m = 1<<(il*2);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
+    }
+#endif
+}
+
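+// q4_K/q5_K store eight 6-bit (scale, min) pairs in 12 bytes: the first four pairs live
+// in the low 6 bits of bytes 0..7, and the last four are reassembled from the low nibbles
+// of bytes 8..11 plus the top two bits of bytes 0..7 - get_scale_min_k4_just2 undoes that.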
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+    device const uchar * q = xb->qs;
+
+#if QK_K == 256
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+#else
+    q = q + 16 * (il&1);
+    device const uint8_t * s = xb->scales;
+    device const half2 * dh = (device const half2 *)xb->d;
+    const float2 d = (float2)dh[0];
+    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
+#endif
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
+
+#if QK_K == 256
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
+#else
+    q = q + 16 * (il&1);
+    device const int8_t * s = xb->scales;
+    const float dl = xb->d * s[il];
+    uint8_t m = 1<<(il*2);
+    const float  coef = il<2 ? 1.f  : 1.f/16.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
+    }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
+#else
+    ql = ql + 16 * (il&1);
+    half sc = scales[il];
+#endif
+    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t kmask2 = il>1 ? 0xF0              : 0x0F;
+    const half     coef   = il>1 ? 1.f/16.h          : 1.h;
+    const half ml = d_all * sc * 32.h;
+    const half dl = d_all * sc * coef;
+    for (int i = 0; i < 16; ++i) {
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
+    }
+}
+
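+// get_rows: one threadgroup per (i10, i11) index pair; the int32 entries of src1 select
+// the source row, and each thread dequantizes 16 weights per iteration into dst.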
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows(
+        device const void * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
+
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+        float4x4 temp;
+        dequantize_func(
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const void * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const void * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+#define BLOCK_SIZE_M 64  // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32  // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4   // each thread takes 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2   // each thread takes 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64   // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
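+// Each threadgroup computes a 64x32 (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of dst, staging a
+// 64x32 tile of A as half at sa (64*32*2 = 4096 bytes, hence the sb offset below) and a
+// 32x32 tile of B as float at sb, then accumulating 8x8 simdgroup matrices.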
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_impl(device const uchar * src0,
+                        device const uchar * src1,
+                        device float * dst,
+                        constant int64_t & ne00,
+                        constant int64_t & ne02,
+                        constant uint64_t & nb01,
+                        constant uint64_t & nb02,
+                        constant int64_t & ne12,
+                        constant uint64_t & nb10,
+                        constant uint64_t & nb11,
+                        constant uint64_t & nb12,
+                        constant int64_t & ne0,
+                        constant int64_t & ne1,
+                        constant uint & r2,
+                        constant uint & r3,
+                        threadgroup uchar * shared_memory [[threadgroup(0)]],
+                        uint3 tgpig[[threadgroup_position_in_grid]],
+                        uint  tiitg[[thread_index_in_threadgroup]],
+                        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        #pragma unroll(16)
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        #pragma unroll(4)
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            #pragma unroll(8)
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+        }
+    } else {
+        // block is smaller than 64x32; avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                     + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
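+// The _id variant below serves mixture-of-experts matmuls: only the src1 rows listed in
+// src1ids take part, and ne1 is therefore a runtime value rather than a constant.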
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        device const uchar * src0,
+        device const uchar * src1,
+        thread short * src1ids,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        int64_t ne1,
+        constant uint & r2,
+        constant uint & r3,
+        threadgroup uchar * shared_memory,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                     + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
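+// Thin kernel wrapper; the [[host_name(...)]] specializations further down bind one
+// concrete (block type, nl, dequantize function) triple per quantization format.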
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+                          device const uchar * src1,
+                          device float * dst,
+                          constant int64_t & ne00,
+                          constant int64_t & ne02,
+                          constant uint64_t & nb01,
+                          constant uint64_t & nb02,
+                          constant int64_t & ne12,
+                          constant uint64_t & nb10,
+                          constant uint64_t & nb11,
+                          constant uint64_t & nb12,
+                          constant int64_t & ne0,
+                          constant int64_t & ne1,
+                          constant uint & r2,
+                          constant uint & r3,
+                          threadgroup uchar * shared_memory [[threadgroup(0)]],
+                          uint3 tgpig[[threadgroup_position_in_grid]],
+                          uint  tiitg[[thread_index_in_threadgroup]],
+                          uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
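+// tgpig.z encodes (expert, batch): the kernel first gathers the rows of src1 routed to
+// this expert (at most 512, the capacity of src1ids) and then multiplies them against
+// that expert's weights, src0s[id].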
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+        device const uchar * ids,
+        device const uchar * src1,
+        device float * dst,
+        constant uint64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const uchar * src00,
+        device const uchar * src01,
+        device const uchar * src02,
+        device const uchar * src03,
+        device const uchar * src04,
+        device const uchar * src05,
+        device const uchar * src06,
+        device const uchar * src07,
+        threadgroup uchar * shared_memory [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    // expert id
+    const int32_t id = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    // row indices of src1 for expert id
+    int64_t _ne1 = 0;
+    short src1ids[512];
+
+    for (int64_t i1 = 0; i1 < ne1; i1++) {
+        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+            src1ids[_ne1++] = i1;
+        }
+    }
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0s[id],
+        src1,
+        src1ids,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        _ne1,
+        r2,
+        r3,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+#if QK_K == 256
+#define QK_NL 16
+#else
+#define QK_NL 4
+#endif
+
+//
+// get rows
+//
+
+typedef void (get_rows_t)(
+        device const void * src0,
+        device const char * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        uint3, uint, uint3);
+
+//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef void (mat_mm_t)(
+        device const uchar * src0,
+        device const uchar * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne12,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & r2,
+        constant uint & r3,
+        threadgroup uchar *,
+        uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+
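+// The indirect (_id) declarations below take the per-expert weight buffers as eight
+// separate arguments, src00..src07, so a dispatch can address at most 8 experts.
+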
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef void (mat_mm_id_t)(
+        device const uchar * ids,
+        device const uchar * src1,
+        device float * dst,
+        constant uint64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const uchar * src00,
+        device const uchar * src01,
+        device const uchar * src02,
+        device const uchar * src03,
+        device const uchar * src04,
+        device const uchar * src05,
+        device const uchar * src06,
+        device const uchar * src07,
+        threadgroup uchar *,
+        uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<float4x4,   1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<half4x4,    1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+        device const char * ids,
+        device const char * src1,
+        device float * dst,
+        constant uint64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint64_t & nb1,
+        constant uint & r2,
+        constant uint & r3,
+        constant int & idx,
+        device const char * src00,
+        device const char * src01,
+        device const char * src02,
+        device const char * src03,
+        device const char * src04,
+        device const char * src05,
+        device const char * src06,
+        device const char * src07,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f32_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+        device const char * ids,
+        device const char * src1,
+        device float * dst,
+        constant uint64_t & nbi1,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant uint64_t & nb00,
+        constant uint64_t & 
nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f16_f32_impl( + src0[id], + src1 + bid*nb11, + dst + bid*ne0, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_q8_0_f32")]] +kernel void kernel_mul_mv_id_q8_0_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q8_0_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_0_f32")]] +kernel void kernel_mul_mv_id_q4_0_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + 
device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_1_f32")]] +kernel void kernel_mul_mv_id_q4_1_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_0_f32")]] +kernel void kernel_mul_mv_id_q5_0_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid 
= tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_1_f32")]] +kernel void kernel_mul_mv_id_q5_1_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q2_K_f32")]] +kernel void kernel_mul_mv_id_q2_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q2_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q3_K_f32")]] +kernel void kernel_mul_mv_id_q3_K_f32( + device const char * ids, + device const char * src1, 
+ device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q3_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_K_f32")]] +kernel void kernel_mul_mv_id_q4_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q4_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_K_f32")]] +kernel void kernel_mul_mv_id_q5_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & 
ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q5_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q6_K_f32")]] +kernel void kernel_mul_mv_id_q6_K_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q6_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 87f8ac45..787a7d45 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -37,8 +37,7 @@ fn approx_bf16(v: Vec, digits: i32) -> Vec { fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); @@ -60,8 +59,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { let device = device(); - let fence = device.new_fence(); - let kernels = Kernels::new(fence); + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -96,8 +94,7 @@ fn 
run_strided<T: Clone>(
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
     let output = new_buffer(&device, v);
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     call_unary_strided(
         &device,
         command_buffer,
@@ -278,8 +275,7 @@ fn binary_ops_bf16() {
 
 fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -409,8 +405,7 @@ fn it_cast_f16_bf16() {
 
 fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
 
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
@@ -445,8 +440,7 @@ fn run_affine_strided<T: Clone>(
     add: f64,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
 
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
@@ -595,8 +589,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
     let dst_el = ids.len() * left_size * right_size;
     let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
 
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     call_index_select(
         &device,
         &command_buffer,
@@ -631,8 +624,7 @@ fn cos_f16() {
 
 fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -662,8 +654,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
 
 fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let input = new_buffer(&device, v);
@@ -782,8 +773,7 @@ fn run_where_cond<I: Clone, T: Clone>(
     name: &'static str,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
@@ -859,8 +849,7 @@ fn run_gemm<T: Clone>(
     rhs_offset: usize,
 ) -> Vec<T> {
     let device = device();
-    let fence = device.new_fence();
-    let kernels = Kernels::new(fence);
+    let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
     let options = MTLResourceOptions::StorageModeManaged;
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index dcf803d8..7add58fd 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -117,7 +117,6 @@ UNARY_OP(erf)
 UNARY_OP(tanh)
 UNARY_OP(recip)
 UNARY_OP(relu)
-
 UNARY(id, float, copy_f32, copy_f32_strided)
 UNARY(id, half, copy_f16, copy_f16_strided)
 UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -136,6 +135,7 @@ BFLOAT_UNARY_OP(neg)
 BFLOAT_UNARY_OP(exp)
 BFLOAT_UNARY_OP(log)
 BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(abs)
 BFLOAT_UNARY_OP(ceil)
 BFLOAT_UNARY_OP(floor)
 BFLOAT_UNARY_OP(round)
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 68d384a6..001be116 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -222,7 +222,10 @@ impl Benchmark for QMatMul {
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
         let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
-        let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
+        let mm = candle::quantized::QTensor::new(
+            candle::quantized::QStorage::Cpu(Box::new(zeros)),
+            (4096, 11008),
+        )?;
         let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
         let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
         Ok((mm, arg))
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 4ee51c29..c9a9f9f3 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -33,7 +33,9 @@ def has_mkl() -> bool:
     pass
 
 @staticmethod
-def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
+def load_ggml(
+    path: Union[str, PathLike], device: Optional[Device] = None
+) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
     """
     Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
     a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@@ -41,7 +43,9 @@ def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
     pass
 
 @staticmethod
-def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
+def load_gguf(
+    path: Union[str, PathLike], device: Optional[Device] = None
+) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
     """
     Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
     and the second maps metadata keys to metadata values.
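The stub changes above mirror a device argument being threaded through the Rust loaders in the patches that follow. A minimal sketch of how the pieces fit together once applied (the model.gguf file name and the printing loop are illustrative, not part of the patch):

use candle::quantized::gguf_file;
use candle::Device;

fn main() -> candle::Result<()> {
    // Pick the target device up front; quantized tensors are created
    // directly on it instead of always being staged on the CPU.
    let device = Device::new_metal(0)?;
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    for name in content.tensor_infos.keys() {
        // Content::tensor now takes the device, per the diffs below.
        let qtensor = content.tensor(&mut file, name, &device)?;
        println!("{name}: {:?}", qtensor.shape());
    }
    Ok(())
}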
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 90826b98..ca406876 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1074,20 +1074,20 @@ impl PyTensor { fn quantize(&self, quantized_dtype: &str) -> PyResult { use ::candle::quantized; let res = match quantized_dtype.to_lowercase().as_str() { - "q2k" => quantized::QTensor::quantize::(self), - "q3k" => quantized::QTensor::quantize::(self), - "q4_0" => quantized::QTensor::quantize::(self), - "q4_1" => quantized::QTensor::quantize::(self), - "q4k" => quantized::QTensor::quantize::(self), - "q5_0" => quantized::QTensor::quantize::(self), - "q5_1" => quantized::QTensor::quantize::(self), - "q5k" => quantized::QTensor::quantize::(self), - "q6k" => quantized::QTensor::quantize::(self), - "q8_0" => quantized::QTensor::quantize::(self), - "q8_1" => quantized::QTensor::quantize::(self), - "q8k" => quantized::QTensor::quantize::(self), - "f16" => quantized::QTensor::quantize::(self), - "f32" => quantized::QTensor::quantize::(self), + "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K), + "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K), + "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0), + "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1), + "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K), + "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0), + "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1), + "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K), + "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K), + "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0), + "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1), + "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K), + "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16), + "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32), dt => { return Err(PyErr::new::(format!( "unknown quantized-dtype {dt}" @@ -1278,13 +1278,19 @@ fn save_safetensors( } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike])")] +#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] -fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { +fn load_ggml( + path: &str, + device: Option, + py: Python<'_>, +) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; - let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let ggml = + ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?; let tensors = ggml .tensors .into_iter() @@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike])")] +#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] /// Loads a GGUF file. 
Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
-fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_gguf(
+    path: &str,
+    device: Option<PyDevice>,
+    py: Python<'_>,
+) -> PyResult<(PyObject, PyObject)> {
+    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
     use ::candle::quantized::gguf_file;
     fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
         let v: PyObject = match v {
@@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
         .tensor_infos
         .keys()
         .map(|key| {
-            let qtensor = gguf.tensor(&mut file, key)?;
+            let qtensor = gguf.tensor(&mut file, key, &device)?;
             Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
         })
         .collect::<::candle::Result<Vec<_>>>()
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 1fb2d9e2..8aa06088 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -356,6 +356,7 @@ impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
         reader: &mut R,
+        device: &Device,
     ) -> Result<Self> {
         let cpu = &Device::Cpu;
         let md_get = |s: &str| match ct.metadata.get(s) {
@@ -383,21 +384,28 @@ impl ModelWeights {
             .unwrap_or(10000f32);
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
+        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
-        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
-        let output = ct.tensor(reader, "output.weight")?;
+        let norm = RmsNorm::new(
+            ct.tensor(reader, "output_norm.weight", device)?,
+            rms_norm_eps,
+        )?;
+        let output = ct.tensor(reader, "output.weight", device)?;
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
-            let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
+            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
+            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
+            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
+            let attention_wo =
+                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
             let mlp_or_moe = if n_expert <= 1 {
-                let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
-                let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
-                let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+                let feed_forward_w1 =
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+                let feed_forward_w2 =
+                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                let feed_forward_w3 =
+                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
                 MlpOrMoe::Mlp(Mlp {
                     feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                     feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -405,15 +413,15 @@ impl ModelWeights {
                 })
             } else {
                 let feed_forward_gate_inp =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
                 let mut experts = Vec::with_capacity(n_expert);
                 for i in 0..n_expert {
                     let feed_forward_w1 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                     let feed_forward_w2 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
                     let feed_forward_w3 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
+                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
                     experts.push(Mlp {
                         feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
                         feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -426,8 +434,9 @@ impl ModelWeights {
                     experts,
                 }
             };
-            let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
-            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
+            let attention_norm =
+                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
+            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
             let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index 1a3cd4ac..882f4cf8 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -311,7 +311,7 @@ impl MixFormerSequentialForCausalLM {
         let mut blocks = Vec::new();
         for i in 0..cfg.n_layer {
             let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
-            blocks.push(block)
+            blocks.push(block);
         }
         let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
         Ok(Self {
@@ -332,7 +332,7 @@ impl MixFormerSequentialForCausalLM {
             Some(get_mask(seq_len, xs.device())?)
         };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs, mask.as_ref())?
+            xs = block.forward(&xs, mask.as_ref())?;
         }
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs
index 63101f4c..bfd0629f 100644
--- a/candle-transformers/src/quantized_var_builder.rs
+++ b/candle-transformers/src/quantized_var_builder.rs
@@ -10,33 +10,33 @@ pub struct VarBuilder {
 }
 
 impl VarBuilder {
-    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
         let mut file = std::fs::File::open(p)?;
         let content = candle::quantized::gguf_file::Content::read(&mut file)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut file, tensor_name)?;
+            let tensor = content.tensor(&mut file, tensor_name, device)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
             data: Arc::new(data),
             path: Vec::new(),
-            device: Device::Cpu,
+            device: device.clone(),
         })
     }
 
-    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+    pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
         let mut cursor = std::io::Cursor::new(buffer);
         let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
         let mut data = std::collections::HashMap::new();
         for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
             data.insert(tensor_name.to_string(), Arc::new(tensor));
         }
         Ok(Self {
             data: Arc::new(data),
             path: Vec::new(),
-            device: Device::Cpu,
+            device: device.clone(),
         })
     }
diff --git a/candle-wasm-examples/blip/src/bin/m.rs b/candle-wasm-examples/blip/src/bin/m.rs
index 660bb717..e2ba4fed 100644
--- a/candle-wasm-examples/blip/src/bin/m.rs
+++ b/candle-wasm-examples/blip/src/bin/m.rs
@@ -61,7 +61,7 @@ impl Model {
         let start = Date::now();
         let model: SelectedModel = if quantized {
-            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
+            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
             let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
             SelectedModel::Q(model)
         } else {
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs
index 999f276d..859e58cb 100644
--- a/candle-wasm-examples/phi/src/bin/m.rs
+++ b/candle-wasm-examples/phi/src/bin/m.rs
@@ -41,6 +41,7 @@ impl Model {
     ) -> Result<Model, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
+        let device = Device::Cpu;
         let name: ModelName = serde_json::from_slice(&config)?;
         let config: Config = serde_json::from_slice(&config)?;
 
@@ -50,8 +51,9 @@ impl Model {
         let start = Date::now();
         console_log!("weights len: {:?}", weights.len());
         let model = if quantized {
-            let vb =
-                candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
+            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
+                &weights, &device,
+            )?;
             console_log!("weights loaded");
             if name._name_or_path == "microsoft/phi-2" {
                 let model = QMixFormer::new_v2(&config, vb)?;
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
index 2f490b84..3b99a275 100644
--- a/candle-wasm-examples/t5/src/bin/m-quantized.rs
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
 use candle_wasm_example_t5::console_log;
 use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
+const DEVICE: Device = Device::Cpu; #[wasm_bindgen] pub struct ModelEncoder { @@ -31,7 +32,7 @@ impl ModelConditionalGeneration { ) -> Result { console_error_panic_hook::set_once(); console_log!("loading model"); - let vb = VarBuilder::from_gguf_buffer(&weights)?; + let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?; let mut config: Config = serde_json::from_slice(&config)?; let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; @@ -46,7 +47,7 @@ impl ModelConditionalGeneration { pub fn decode(&mut self, input: JsValue) -> Result { let input: ConditionalGenerationParams = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; - let device = &Device::Cpu; + let device = &DEVICE; self.model.clear_kv_cache(); let mut output_token_ids = [self.config.pad_token_id as u32].to_vec(); let prompt = input.prompt; @@ -128,7 +129,7 @@ impl ModelEncoder { ) -> Result { console_error_panic_hook::set_once(); console_log!("loading model"); - let vb = VarBuilder::from_gguf_buffer(&weights)?; + let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?; let mut config: Config = serde_json::from_slice(&config)?; config.use_cache = false; let tokenizer = @@ -138,7 +139,7 @@ impl ModelEncoder { } pub fn decode(&mut self, input: JsValue) -> Result { - let device = &Device::Cpu; + let device = &DEVICE; let input: DecoderParams = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index fd91fa8c..898996a7 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -315,6 +315,7 @@ impl Decoder { let model = if md.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer( &md.weights, + &device, )?; Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) } else { diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index e5fa7dec..fc107e61 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> { ] ); - let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; + let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?; let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; assert_eq!( From 80b1c689f923473784945e3bcf1ea2286dde7e0d Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:09:28 +0100 Subject: [PATCH 45/46] Revert public EncoderParam --- candle-metal-kernels/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 58569e6b..fe969372 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -47,7 +47,7 @@ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -pub trait EncoderParam { +trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! 
primitive { From 17e6e2d7ee111abac7a2cf806cd888ff780da452 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 17 Jan 2024 15:47:08 -0300 Subject: [PATCH 46/46] Fixes metal kernel u8 type --- candle-metal-kernels/src/binary.metal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index eb560f16..ae11286a 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); #define INT64_BINARY_OP_OUT(NAME, FN) \ -BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided); +BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided); BINARY_OP(x + y, add) BINARY_OP(x - y, sub)
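This last one-character fix matters because candle represents boolean masks as u8: the surrounding context lines show the other comparison kernels in binary.metal already emitting uint8_t, and the i64 variants were the odd ones out with int8_t. A minimal sketch of the behaviour the fix lines up, assuming a Metal-capable machine and that the i64 comparison kernels generated by INT64_BINARY_OP_OUT are the ones dispatched here:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let device = Device::new_metal(0)?;
    let a = Tensor::new(&[1i64, 2, 3], &device)?;
    let b = Tensor::new(&[2i64, 2, 2], &device)?;
    // Comparison kernels write one byte per element; the mask dtype is U8.
    let mask = a.le(&b)?;
    assert_eq!(mask.dtype(), DType::U8);
    Ok(())
}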