Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)

Commit: Not implementing quantized.
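The change stubs out `call_quantized_matmul_t`: the function keeps its signature but its body is replaced by a `todo!`, and the embedded `quantized.metal` source is dropped. A `todo!` body type-checks against any return type (it evaluates to `!`), so the crate still compiles while the quantized path panics at runtime. A minimal sketch of that behavior, with the error type and parameters reduced to placeholders:

    // Hypothetical reduction of the post-commit stub; the real function takes
    // device, kernel, and buffer arguments as shown in the diff below.
    #[derive(Debug)]
    struct MetalKernelError;

    fn call_quantized_matmul_t() -> Result<(), MetalKernelError> {
        // `todo!` evaluates to `!`, which coerces to the `Result` return
        // type, so this compiles; calling it panics with the given message.
        todo!("Not implemented yet")
    }

    fn main() {
        // Demonstrate the runtime behavior without aborting the process.
        assert!(std::panic::catch_unwind(call_quantized_matmul_t).is_err());
    }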
@@ -15,7 +15,6 @@ const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 const CONV: &str = include_str!("conv.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
-const QUANTIZED: &str = include_str!("quantized.metal");
 
 /// Most kernels apply similarly across the tensors
 /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
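For context, the constants touched above embed the Metal shader sources into the Rust binary at compile time: `include_str!` pulls in MSL text files, while `include_bytes!` embeds the precompiled `libMetalFlashAttention.metallib` as raw bytes. A minimal sketch of the pattern (the file name here is hypothetical, and the file must exist next to the source file for this to compile):

    // Assumes a file `example.metal` sits in the same directory as this
    // source file; `include_str!` resolves paths relative to the includer.
    const EXAMPLE: &str = include_str!("example.metal");

    fn main() {
        // The embedded MSL source is an ordinary &'static str at runtime,
        // ready to hand to the Metal shader compiler.
        println!("embedded {} bytes of MSL source", EXAMPLE.len());
    }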
@@ -241,7 +240,6 @@ impl Kernels {
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
             Source::Conv => CONV,
-            Source::Quantized => QUANTIZED,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
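The arm removed above belongs to the lookup that maps a `Source` variant to its embedded shader source (MFA is excluded because it ships as a precompiled metallib rather than MSL text). A hedged, self-contained sketch of that mapping, with variant names taken from the diff and the source strings shortened to placeholders:

    #[derive(Clone, Copy)]
    enum Source {
        Cast,
        Reduce,
        Conv,
        Mfa,
    }

    const CAST: &str = "/* cast kernels (MSL) */";
    const REDUCE: &str = "/* reduce kernels (MSL) */";
    const CONV: &str = "/* conv kernels (MSL) */";

    fn get_library_source(source: Source) -> &'static str {
        match source {
            Source::Cast => CAST,
            Source::Reduce => REDUCE,
            Source::Conv => CONV,
            // The MFA library is precompiled bytes, not source text.
            Source::Mfa => panic!("Invalid lib"),
        }
    }

    fn main() {
        println!("{}", get_library_source(Source::Reduce));
    }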
@@ -1556,145 +1554,7 @@ pub fn call_quantized_matmul_t(
     rhs: &Buffer,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    // Everything is in reverse
-    let ne00 = k as i64;
-    let ne01 = n as i64;
-    let ne02 = b as i64;
-    let ne03 = 1 as i64;
-
-    let nb00 = 0i64;
-    let nb01 = 0 as i64;
-    let nb02 = 0 as i64;
-
-    let ne10 = k as i64;
-    let ne11 = m as i64;
-    let ne12 = b as i64;
-    let ne13 = 1 as i64;
-
-    let nb10 = 0i64;
-    let nb11 = 0i64;
-    let nb12 = 0i64;
-
-    let ne0 = n as i64;
-    let ne1 = m as i64;
-    let r2: u32 = (ne12 / ne02) as u32;
-    let r3: u32 = (ne13 / ne03) as u32;
-
-    let (nth0, nth1, align) = match dtype {
-        GgmlDType::Q4_0
-        | GgmlDType::Q4_1
-        | GgmlDType::Q5_0
-        | GgmlDType::Q5_1
-        | GgmlDType::Q8_0
-        | GgmlDType::Q8_1 => {
-            let nth0 = 8;
-            let nth1 = 8;
-            let align = 8;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q2K => {
-            // Fixing a bug in Metal for GGML
-            let nth0 = 4;
-            let nth1 = 8;
-            let align = 4;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q4K => {
-            let nth0 = 4;
-            let nth1 = 8;
-            let align = 4;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q3K | GgmlDType::Q5K => {
-            let nth0 = 2;
-            let nth1 = 32;
-            let align = 4;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q6K => {
-            let nth0 = 2;
-            let nth1 = 32;
-            let align = 2;
-            (nth0, nth1, align)
-        }
-        GgmlDType::F16 | GgmlDType::Q8K => {
-            // Original implem uses rows
-            let nth0 = 32;
-            let nth1 = 1;
-            let align = 8;
-            (nth0, nth1, align)
-        }
-        GgmlDType::F32 => {
-            let nth0 = 32;
-            let nth1 = 1;
-            let align = 8;
-            (nth0, nth1, align)
-        }
-    };
-    let thread_groups_count = MTLSize {
-        width: divide(ne01 as usize, align),
-        height: ne11 as u64,
-        depth: (ne12 * ne13) as u64,
-    };
-    let threads_per_threadgroup = MTLSize {
-        width: nth0,
-        height: nth1,
-        depth: 1,
-    };
-    let name = match dtype {
-        GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
-        GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
-        GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
-        GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
-        GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
-        GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
-        GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
-        GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
-        GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
-        GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
-        GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
-        GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
-        GgmlDType::F16 => "kernel_mul_mv_f16_f32",
-        GgmlDType::F32 => "kernel_mul_mv_f32_f32",
-    };
-
-    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    set_params!(
-        encoder,
-        (
-            rhs,
-            (lhs, lhs_offset),
-            output,
-            ne00,
-            ne01,
-            ne02,
-            nb00,
-            nb01,
-            nb02,
-            ne10,
-            ne11,
-            ne12,
-            nb10,
-            nb11,
-            nb12,
-            ne0,
-            ne1,
-            r2,
-            r3
-        )
-    );
-    encoder.set_threadgroup_memory_length(0, 8192);
-    encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
-    encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
-
-    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
-    encoder.end_encoding();
-
-    Ok(())
+    todo!("Not implemented yet");
 }
 
 fn divide(m: usize, b: usize) -> NSUInteger {
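The removed body sized the dispatch grid with `divide(ne01 as usize, align)`, and the `divide` helper itself survives the change. Its use as a grid-width calculation suggests it is a ceiling division (an assumption, since its body lies outside the hunk); a sketch with a worked example from the removed Q4_0 path, where `align = 8`:

    // Metal's NSUInteger is a 64-bit unsigned integer on modern platforms;
    // aliased here so the sketch is self-contained.
    type NSUInteger = u64;

    // Assumed implementation: ceiling division, so `m` rows split into
    // chunks of `b` always get enough threadgroups to cover the matrix.
    fn divide(m: usize, b: usize) -> NSUInteger {
        ((m + b - 1) / b) as NSUInteger
    }

    fn main() {
        // Q4_0 in the removed code used align = 8: a 4096-row weight matrix
        // (ne01 = 4096) needs ceil(4096 / 8) = 512 threadgroups along width.
        assert_eq!(divide(4096, 8), 512);
        // Row counts not a multiple of `align` round up: 4097 -> 513.
        assert_eq!(divide(4097, 8), 513);
    }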
(File diff suppressed because it is too large)