mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Not implementing quantized.
This commit is contained in:
@ -15,7 +15,6 @@ const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
@ -241,7 +240,6 @@ impl Kernels {
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -1556,145 +1554,7 @@ pub fn call_quantized_matmul_t(
|
||||
rhs: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// Everything is in reverse
|
||||
let ne00 = k as i64;
|
||||
let ne01 = n as i64;
|
||||
let ne02 = b as i64;
|
||||
let ne03 = 1 as i64;
|
||||
|
||||
let nb00 = 0i64;
|
||||
let nb01 = 0 as i64;
|
||||
let nb02 = 0 as i64;
|
||||
|
||||
let ne10 = k as i64;
|
||||
let ne11 = m as i64;
|
||||
let ne12 = b as i64;
|
||||
let ne13 = 1 as i64;
|
||||
|
||||
let nb10 = 0i64;
|
||||
let nb11 = 0i64;
|
||||
let nb12 = 0i64;
|
||||
|
||||
let ne0 = n as i64;
|
||||
let ne1 = m as i64;
|
||||
let r2: u32 = (ne12 / ne02) as u32;
|
||||
let r3: u32 = (ne13 / ne03) as u32;
|
||||
|
||||
let (nth0, nth1, align) = match dtype {
|
||||
GgmlDType::Q4_0
|
||||
| GgmlDType::Q4_1
|
||||
| GgmlDType::Q5_0
|
||||
| GgmlDType::Q5_1
|
||||
| GgmlDType::Q8_0
|
||||
| GgmlDType::Q8_1 => {
|
||||
let nth0 = 8;
|
||||
let nth1 = 8;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
// Fixing a bug in Metal for GGML
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q3K | GgmlDType::Q5K => {
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let align = 2;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::F16 | GgmlDType::Q8K => {
|
||||
// Original implem uses rows
|
||||
let nth0 = 32;
|
||||
let nth1 = 1;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
GgmlDType::F32 => {
|
||||
let nth0 = 32;
|
||||
let nth1 = 1;
|
||||
let align = 8;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
};
|
||||
let thread_groups_count = MTLSize {
|
||||
width: divide(ne01 as usize, align),
|
||||
height: ne11 as u64,
|
||||
depth: (ne12 * ne13) as u64,
|
||||
};
|
||||
let threads_per_threadgroup = MTLSize {
|
||||
width: nth0,
|
||||
height: nth1,
|
||||
depth: 1,
|
||||
};
|
||||
let name = match dtype {
|
||||
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
|
||||
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
|
||||
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
|
||||
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
|
||||
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
|
||||
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
|
||||
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
|
||||
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
|
||||
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
|
||||
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
|
||||
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
|
||||
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
|
||||
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
|
||||
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
|
||||
};
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
rhs,
|
||||
(lhs, lhs_offset),
|
||||
output,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3
|
||||
)
|
||||
);
|
||||
encoder.set_threadgroup_memory_length(0, 8192);
|
||||
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
todo!("Not implemented yet");
|
||||
}
|
||||
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user