mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
5 Commits
0.8.0
...
metal-gemm
Author | SHA1 | Date | |
---|---|---|---|
9105aa4390 | |||
2a2a349fd4 | |||
c87dd386a9 | |||
f4b1597b5d | |||
ea578478d4 |
@ -2,11 +2,11 @@ mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
//benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::unary::benches
|
||||
//benchmarks::random::benches,
|
||||
//benchmarks::where_cond::benches,
|
||||
//benchmarks::conv_transpose2d::benches,
|
||||
//benchmarks::qmatmul::benches,
|
||||
//benchmarks::unary::benches
|
||||
);
|
||||
|
1
candle-metal-kernels/.gitignore
vendored
Normal file
1
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
src/air
|
141
candle-metal-kernels/build.rs
Normal file
141
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,141 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use std::process::Command;
|
||||
use std::{env, str};
|
||||
|
||||
const COMPILED_KERNELS: [&str; 3] = ["event", "matrix_storage", "gemm"];
|
||||
|
||||
enum Platform {
|
||||
MacOS,
|
||||
IOS,
|
||||
}
|
||||
|
||||
impl Platform {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Platform::MacOS => "macosx",
|
||||
Platform::IOS => "iphoneos",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_xcode_sdk_path(platform: Platform) -> Result<String, String> {
|
||||
let xcrun_output = Command::new("xcrun")
|
||||
.args(["--sdk", platform.as_str(), "--show-sdk-path"])
|
||||
.output()
|
||||
.expect("xcrun command failed to start");
|
||||
|
||||
Ok(str::from_utf8(&xcrun_output.stdout)
|
||||
.expect("Invalid UTF-8 from xcrun")
|
||||
.replace('\n', ""))
|
||||
}
|
||||
|
||||
fn compile_candle_metallib(sdk_path: String, bfloat_support: bool) -> Result<(), String> {
|
||||
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||
let out_dir = current_dir.join("src/libraries");
|
||||
let air_dir = current_dir.join("src/air");
|
||||
let working_directory = air_dir.display();
|
||||
let sources = current_dir.join("src/kernels");
|
||||
|
||||
// Compile metal to air
|
||||
let mut compile_air_cmd = Command::new("xcrun");
|
||||
compile_air_cmd
|
||||
.arg("metal")
|
||||
.arg(format!("-working-directory={working_directory}"))
|
||||
.arg("-Wall")
|
||||
.arg("-Wextra")
|
||||
.arg("-O3")
|
||||
.arg("-c")
|
||||
.arg("-w");
|
||||
for metal_file in COMPILED_KERNELS {
|
||||
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
|
||||
}
|
||||
compile_air_cmd.arg(sources.join("utils.metal"));
|
||||
compile_air_cmd.spawn().expect("Failed to compile air");
|
||||
|
||||
let mut child = compile_air_cmd.spawn().expect("Failed to compile air");
|
||||
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling metal -> air failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
let status = child
|
||||
.wait()
|
||||
.expect("Compiling metal -> air failed while waiting for result");
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling metal -> air failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("Compiling metal -> air failed: {:?}", e),
|
||||
}
|
||||
|
||||
// Compile air to metallib
|
||||
let metallib = out_dir.join("candle.metallib");
|
||||
|
||||
let mut compile_metallib_cmd = Command::new("xcrun");
|
||||
compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);
|
||||
|
||||
for metal_file in COMPILED_KERNELS {
|
||||
compile_metallib_cmd.arg(air_dir.join(format!("{metal_file}.air")));
|
||||
}
|
||||
compile_metallib_cmd.arg(air_dir.join("utils.air"));
|
||||
|
||||
let mut child = compile_metallib_cmd
|
||||
.spawn()
|
||||
.expect("Failed to compile air -> metallib");
|
||||
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling air -> metallib failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
let status = child
|
||||
.wait()
|
||||
.expect("Compiling air -> metallib failed while waiting for result");
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling air -> metallib failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("Compiling air -> metallib failed: {:?}", e),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), String> {
|
||||
println!("cargo::rerun-if-changed=build.rs");
|
||||
|
||||
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||
let sources = current_dir.join("src/kernels");
|
||||
|
||||
for metal_file in COMPILED_KERNELS {
|
||||
println!(
|
||||
"cargo::rerun-if-changed={}",
|
||||
sources.join(format!("{metal_file}.metal")).display()
|
||||
);
|
||||
}
|
||||
|
||||
let macos_sdk = get_xcode_sdk_path(Platform::MacOS).expect("Failed to get MacOS SDK path");
|
||||
let iphoneos_sdk = get_xcode_sdk_path(Platform::IOS).expect("Failed to get IOS SDK path");
|
||||
|
||||
compile_candle_metallib(macos_sdk, false)?;
|
||||
|
||||
Ok(())
|
||||
}
|
65
candle-metal-kernels/src/ffi.rs
Normal file
65
candle-metal-kernels/src/ffi.rs
Normal file
@ -0,0 +1,65 @@
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
|
||||
use core::ffi::{c_char, c_int, c_uint, c_void};
|
||||
|
||||
pub type CFTypeRef = *const c_void;
|
||||
pub type CFAllocatorRef = *const c_void;
|
||||
pub type CFMutableDictionaryRef = *mut c_void;
|
||||
pub type CFStringRef = *const c_void;
|
||||
pub type CFNumberRef = *const c_void;
|
||||
|
||||
pub type mach_port_t = c_uint;
|
||||
pub type kern_return_t = c_int;
|
||||
pub type io_object_t = mach_port_t;
|
||||
pub type io_iterator_t = io_object_t;
|
||||
pub type io_registry_entry_t = io_object_t;
|
||||
|
||||
pub type IOOptionBits = u32;
|
||||
pub type CFNumberType = u32;
|
||||
|
||||
pub const kIOMainPortDefault: mach_port_t = 0;
|
||||
pub const kIOServicePlane: &str = "IOService\0";
|
||||
pub const kCFNumberSInt64Type: CFNumberType = 4;
|
||||
|
||||
pub const MACH_PORT_NULL: i32 = 0;
|
||||
|
||||
#[link(name = "IOKit", kind = "framework")]
|
||||
extern "C" {
|
||||
pub fn IOServiceGetMatchingServices(
|
||||
mainPort: mach_port_t,
|
||||
matching: CFMutableDictionaryRef,
|
||||
existing: *mut io_iterator_t,
|
||||
) -> kern_return_t;
|
||||
|
||||
pub fn IOServiceMatching(a: *const c_char) -> CFMutableDictionaryRef;
|
||||
|
||||
pub fn IOIteratorNext(iterator: io_iterator_t) -> io_object_t;
|
||||
|
||||
pub fn IOObjectRelease(obj: io_object_t) -> kern_return_t;
|
||||
|
||||
pub fn IORegistryEntrySearchCFProperty(
|
||||
entry: io_registry_entry_t,
|
||||
plane: *const c_char,
|
||||
key: CFStringRef,
|
||||
allocator: CFAllocatorRef,
|
||||
options: IOOptionBits,
|
||||
) -> CFTypeRef;
|
||||
}
|
||||
|
||||
#[link(name = "CoreFoundation", kind = "framework")]
|
||||
extern "C" {
|
||||
pub fn CFNumberGetValue(
|
||||
number: CFNumberRef,
|
||||
theType: CFNumberType,
|
||||
valuePtr: *mut c_void,
|
||||
) -> bool;
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
fn __CFStringMakeConstantString(c_str: *const c_char) -> CFStringRef;
|
||||
}
|
||||
|
||||
pub fn cfstr(val: &str) -> CFStringRef {
|
||||
unsafe { __CFStringMakeConstantString(val.as_ptr().cast()) }
|
||||
}
|
122
candle-metal-kernels/src/gpu.rs
Normal file
122
candle-metal-kernels/src/gpu.rs
Normal file
@ -0,0 +1,122 @@
|
||||
use core::ffi::c_void;
|
||||
use metal::Device;
|
||||
|
||||
use crate::ffi::*;
|
||||
|
||||
const GPU_CORE_COUNT_KEY: &str = "gpu-core-count\0";
|
||||
const AGXACCELERATOR_KEY: &str = "AGXAccelerator\0";
|
||||
|
||||
struct IOIterator(io_iterator_t);
|
||||
|
||||
impl IOIterator {
|
||||
fn new(it: io_iterator_t) -> Self {
|
||||
IOIterator(it)
|
||||
}
|
||||
|
||||
fn next(&self) -> Option<io_object_t> {
|
||||
let result = unsafe { IOIteratorNext(self.0) };
|
||||
if result == MACH_PORT_NULL as u32 {
|
||||
return None;
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for IOIterator {
|
||||
fn drop(&mut self) {
|
||||
unsafe { IOObjectRelease(self.0 as _) };
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_io_service_matching(val: &str) -> Result<CFMutableDictionaryRef, String> {
|
||||
let matching = IOServiceMatching(val.as_ptr().cast());
|
||||
if matching.is_null() {
|
||||
return Err(format!("IOServiceMatching call failed, `{val}` not found"));
|
||||
}
|
||||
Ok(matching)
|
||||
}
|
||||
|
||||
unsafe fn get_matching_services(
|
||||
main_port: mach_port_t,
|
||||
matching: CFMutableDictionaryRef,
|
||||
) -> Result<IOIterator, String> {
|
||||
let mut iterator: io_iterator_t = 0;
|
||||
let result = IOServiceGetMatchingServices(main_port, matching, &mut iterator);
|
||||
if result != 0 {
|
||||
return Err("Error getting matching services".to_string());
|
||||
}
|
||||
Ok(IOIterator::new(iterator))
|
||||
}
|
||||
|
||||
unsafe fn get_gpu_io_service() -> Result<io_object_t, String> {
|
||||
let matching = get_io_service_matching(AGXACCELERATOR_KEY)?;
|
||||
let iterator = get_matching_services(kIOMainPortDefault, matching)?;
|
||||
iterator
|
||||
.next()
|
||||
.ok_or("Error getting GPU IO Service".to_string())
|
||||
}
|
||||
|
||||
unsafe fn get_property_by_key(
|
||||
entry: io_registry_entry_t,
|
||||
plane: &str,
|
||||
key: &str,
|
||||
allocator: CFAllocatorRef,
|
||||
options: IOOptionBits,
|
||||
) -> Result<CFTypeRef, String> {
|
||||
let result = IORegistryEntrySearchCFProperty(
|
||||
entry,
|
||||
plane.as_ptr().cast(),
|
||||
cfstr(key),
|
||||
allocator,
|
||||
options,
|
||||
);
|
||||
if result.is_null() {
|
||||
return Err(format!("Error getting {key} property"));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
unsafe fn get_int_value(number: CFNumberRef) -> Result<i64, String> {
|
||||
let mut value: i64 = 0;
|
||||
let result = CFNumberGetValue(
|
||||
number,
|
||||
kCFNumberSInt64Type,
|
||||
&mut value as *mut i64 as *mut c_void,
|
||||
);
|
||||
if !result {
|
||||
return Err("Error getting int value".to_string());
|
||||
}
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
unsafe fn find_core_count() -> Result<usize, String> {
|
||||
let gpu_io_service = get_gpu_io_service()?;
|
||||
let gpu_core_count = get_property_by_key(
|
||||
gpu_io_service,
|
||||
kIOServicePlane,
|
||||
GPU_CORE_COUNT_KEY,
|
||||
core::ptr::null(),
|
||||
0,
|
||||
)?;
|
||||
let value = get_int_value(gpu_core_count as CFNumberRef)?;
|
||||
Ok(value as usize)
|
||||
}
|
||||
|
||||
pub(crate) fn get_device_core_count(device: &Device) -> usize {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
unsafe { find_core_count().expect("Retrieving gpu core count failed") }
|
||||
}
|
||||
#[cfg(target_os = "ios")]
|
||||
{
|
||||
if device.name().starts_with("A") {
|
||||
if device.supports_family(MTLGPUFamily::Apple9) {
|
||||
6
|
||||
} else {
|
||||
5
|
||||
}
|
||||
} else {
|
||||
10
|
||||
}
|
||||
}
|
||||
}
|
226
candle-metal-kernels/src/kernels/event.metal
Normal file
226
candle-metal-kernels/src/kernels/event.metal
Normal file
@ -0,0 +1,226 @@
|
||||
// -*- Metal -*-
|
||||
//===-- metal_simdgroup_event ---------------------------------------------===//
|
||||
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_EVENT
|
||||
#define __METAL_SIMDGROUP_EVENT
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// %struct._simdgroup_event_t = type opaque
|
||||
//
|
||||
struct _simdgroup_event_t;
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// Bitcode: TBD
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_1d(
|
||||
ulong, ulong, threadgroup void *, const device void *, ulong)
|
||||
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// Bitcode: TBD
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_1d(
|
||||
ulong, ulong, device void *, const threadgroup void *, ulong)
|
||||
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: argmemonly convergent nounwind
|
||||
// declare %struct._simdgroup_event_t*
|
||||
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
|
||||
// i64, i64,
|
||||
// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
|
||||
// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
|
||||
// <2 x i64>, i32)
|
||||
// local_unnamed_addr #4
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_2d(
|
||||
ulong, ulong,
|
||||
threadgroup void *, ulong, ulong, ulong2,
|
||||
const device void *, ulong, ulong, ulong2,
|
||||
long2, int)
|
||||
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: argmemonly convergent nounwind
|
||||
// declare %struct._simdgroup_event_t*
|
||||
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
|
||||
// i64, i64,
|
||||
// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
|
||||
// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
|
||||
// <2 x i64>, i32)
|
||||
// local_unnamed_addr #4
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_2d(
|
||||
ulong, ulong,
|
||||
device void *, ulong, ulong, ulong2,
|
||||
const threadgroup void *, ulong, ulong, ulong2,
|
||||
long2, int)
|
||||
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: convergent nounwind
|
||||
// declare void
|
||||
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
|
||||
// local_unnamed_addr #3
|
||||
//
|
||||
void __metal_wait_simdgroup_events(
|
||||
int, thread _simdgroup_event_t**)
|
||||
__asm("air.wait_simdgroup_events");
|
||||
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
enum class simdgroup_async_copy_clamp_mode {
|
||||
clamp_to_zero = 0,
|
||||
clamp_to_edge = 1
|
||||
};
|
||||
|
||||
struct simdgroup_event {
|
||||
METAL_FUNC simdgroup_event() thread {}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
threadgroup T *dst,
|
||||
const device T *src,
|
||||
ulong n_elements
|
||||
) thread {
|
||||
event = __metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
reinterpret_cast<const device void *>(src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
device T *dst,
|
||||
const threadgroup T *src,
|
||||
ulong n_elements
|
||||
) thread {
|
||||
event = __metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
reinterpret_cast<const threadgroup void *>(src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
threadgroup T *dst,
|
||||
ushort dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const device T *src,
|
||||
uint src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false,
|
||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
event = __metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
ushort(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const device void *>(src),
|
||||
uint(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
static_cast<int>(clamp_mode));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
device T *dst,
|
||||
uint dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const threadgroup T *src,
|
||||
ushort src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
event = __metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
uint(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const threadgroup void *>(src),
|
||||
ushort(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
0);
|
||||
}
|
||||
|
||||
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
|
||||
__metal_wait_simdgroup_events(
|
||||
count, reinterpret_cast<thread _simdgroup_event_t**>(events));
|
||||
}
|
||||
|
||||
private:
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
|
||||
//
|
||||
thread _simdgroup_event_t* event;
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif
|
538
candle-metal-kernels/src/kernels/gemm.metal
Normal file
538
candle-metal-kernels/src/kernels/gemm.metal
Normal file
@ -0,0 +1,538 @@
|
||||
// Heavily inspired by the GEMM kernels by Philip Turner:
|
||||
// https://github.com/philipturner/metal-flash-attention
|
||||
// This implementation uses generics and specialization to generate kernels for different data types instead of code generation.
|
||||
#include <metal_stdlib>
|
||||
#include "event.metal"
|
||||
#include "matrix_storage.metal"
|
||||
using namespace metal;
|
||||
|
||||
// Dimensions of each matrix.
|
||||
// - Limitations to matrix size:
|
||||
// - 2^32 in each dimension (M/N/K).
|
||||
// - TODO: Test whether the maximum dimension with correct execution is
|
||||
// actually 2^16. This will require a testing setup with non-square
|
||||
// matrices, as 65536^3 is uncomputable.
|
||||
// - Extending to 2^64 may require changing 'uint' to 'ulong'. There is a
|
||||
// good chance this will significantly degrade performance, and require
|
||||
// changing the data type of several variables that process addresses. The
|
||||
// client is responsible for ensuring correctness and performance with
|
||||
// matrices spanning several billion elements in one direction.
|
||||
// - The matrix dimensions must be known at compile time, via function
|
||||
// constants. Dynamic matrix shapes are beyond the scope of this reference
|
||||
// implementation. Dynamic shapes cause a non-negligible regression to
|
||||
// shader execution speed. However, they could minimize a compilation
|
||||
// latency bottleneck in some use cases.
|
||||
// - Limitations to batch size:
|
||||
// - Dictated by how the client modifies the code to implement batching.
|
||||
// - Dynamic batch shapes would likely not harm performance much. For example,
|
||||
// someone could enter an array of pointers/memory offsets to different
|
||||
// matrices in the batch. Each slice of a 3D thread grid could read a
|
||||
// different pointer from memory, and use that pointer as the A/B/C matrix.
|
||||
// Another approach is to restrict the input format, so all matrices are
|
||||
// stored contiguously in memory. Then, the memory offset could be computed
|
||||
// analytically from matrix size and the Z dimension in a 3D thread grid.
|
||||
//
|
||||
// Another note:
|
||||
// - The rows of the matrix must be contiguous in memory. Supporting strides
|
||||
// that differ from the actual matrix dimensions should not be difficult, but
|
||||
// it is out of scope for this reference kernel.
|
||||
constant uint M [[function_constant(0)]];
|
||||
constant uint N [[function_constant(1)]];
|
||||
constant uint K [[function_constant(2)]];
|
||||
|
||||
// Whether each matrix is transposed.
|
||||
constant bool A_trans [[function_constant(10)]];
|
||||
constant bool B_trans [[function_constant(11)]];
|
||||
|
||||
constant bool prefer_async_copy [[function_constant(206)]];
|
||||
constant bool ideal_grouping [[function_constant(207)]];
|
||||
|
||||
constant bool batched [[function_constant(100)]];
|
||||
|
||||
constant ushort A_leading_dim = A_trans ? M : K;
|
||||
constant ushort B_leading_dim = B_trans ? K : N;
|
||||
|
||||
// The layout of threads within a SIMD matrix.
|
||||
//
|
||||
// 0 0 1 1 8 8 9 9
|
||||
// 2 2 3 3 10 10 11 11
|
||||
// 4 4 5 5 12 12 13 13
|
||||
// 6 6 7 7 14 14 15 15
|
||||
// 16 16 17 17 24 24 25 25
|
||||
// 18 18 19 19 26 26 27 27
|
||||
// 20 20 21 21 28 28 29 29
|
||||
// 22 22 23 23 30 30 31 31
|
||||
//
|
||||
// This is Morton order, a method for coalescing data accesses. It is used
|
||||
// in a variety of contexts, from ray tracing acceleration structures, to
|
||||
// nodal-point Laplacians, to sorting large lattices of atoms.
|
||||
//
|
||||
// Source: https://patents.google.com/patent/US11256518B2
|
||||
METAL_FUNC ushort2 morton_order(ushort thread_index_in_simdgroup) {
|
||||
ushort lane_id = thread_index_in_simdgroup;
|
||||
ushort quad_id = lane_id / 4;
|
||||
|
||||
constexpr ushort QUADRANT_SPAN_M = 4;
|
||||
constexpr ushort THREADS_PER_QUADRANT = 8;
|
||||
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
|
||||
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
|
||||
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
|
||||
|
||||
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
|
||||
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
|
||||
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
|
||||
|
||||
return ushort2(N_in_simd, M_in_simd);
|
||||
}
|
||||
|
||||
// Indexes into an array of registers.
|
||||
//
|
||||
// Calls to this function are expected to be evaluated at compile time. The
|
||||
// array indices transform into register offsets, which are embedded into the
|
||||
// assembly code.
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* get_sram(
|
||||
thread simdgroup_matrix_storage<T> *sram,
|
||||
ushort sram_leading_dim,
|
||||
ushort2 matrix_origin
|
||||
) {
|
||||
return sram + (matrix_origin.y / 8) * (sram_leading_dim / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
// One multiply-accumulate loop iteration, or 8 dot products.
|
||||
template<
|
||||
typename T,
|
||||
typename U = T,
|
||||
ushort M_register,
|
||||
ushort N_register
|
||||
>
|
||||
METAL_FUNC void multiply_accumulate(
|
||||
const device T *A_src,
|
||||
const device U *B_src,
|
||||
thread simdgroup_matrix_storage<T> *A_sram,
|
||||
thread simdgroup_matrix_storage<U> *B_sram,
|
||||
thread simdgroup_matrix_storage<U> *C_sram,
|
||||
ushort k
|
||||
) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
ushort2 origin(0, m);
|
||||
auto A = get_sram(A_sram, 8, origin);
|
||||
A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, 0);
|
||||
auto B = get_sram(B_sram, N_register, origin);
|
||||
B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
auto A = get_sram(A_sram, 8, ushort2(0, m));
|
||||
auto B = get_sram(B_sram, N_register, ushort2(n, 0));
|
||||
auto C = get_sram(C_sram, N_register, ushort2(n, m));
|
||||
C->multiply(*A, *B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// One multiply-accumulate loop iteration, or 8 dot products.
|
||||
template<
|
||||
typename T,
|
||||
typename U = T,
|
||||
ushort M_register,
|
||||
ushort N_register
|
||||
>
|
||||
METAL_FUNC void multiply_accumulate(
|
||||
const threadgroup T *A_src,
|
||||
const threadgroup U *B_src,
|
||||
thread simdgroup_matrix_storage<T> *A_sram,
|
||||
thread simdgroup_matrix_storage<U> *B_sram,
|
||||
thread simdgroup_matrix_storage<U> *C_sram,
|
||||
ushort k
|
||||
) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
ushort2 origin(0, m);
|
||||
auto A = get_sram(A_sram, 8, origin);
|
||||
A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, 0);
|
||||
auto B = get_sram(B_sram, N_register, origin);
|
||||
B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
auto A = get_sram(A_sram, 8, ushort2(0, m));
|
||||
auto B = get_sram(B_sram, N_register, ushort2(n, 0));
|
||||
auto C = get_sram(C_sram, N_register, ushort2(n, m));
|
||||
C->multiply(*A, *B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Metal function arguments.
|
||||
//
|
||||
// A: the left-hand side matrix
|
||||
// - dimensions: M x K
|
||||
// K x M (transposed)
|
||||
// - memory precision: T
|
||||
// - register precision: T
|
||||
//
|
||||
// B: the right-hand side matrix
|
||||
// - dimensions: K x N
|
||||
// N x K (transposed)
|
||||
// - memory precision: U
|
||||
// - register precision: U
|
||||
//
|
||||
// C: the output matrix, alternatively the dot product accumulator
|
||||
// - dimensions: M x N
|
||||
// - memory precision: V
|
||||
// - register precision: V
|
||||
//
|
||||
// threadgroup_block: the chunk of threadgroup memory allocated at runtime
|
||||
// - ideally 10 KB or less
|
||||
// - precision: void/8-bit integer to make the pointer arithmetic more legible
|
||||
template <
|
||||
typename T,
|
||||
typename U = T,
|
||||
typename V = U,
|
||||
ushort M_group,
|
||||
ushort N_group,
|
||||
ushort K_group,
|
||||
ushort M_splits,
|
||||
ushort N_splits,
|
||||
ushort M_register = M_group / M_splits,
|
||||
ushort N_register = N_group / N_splits
|
||||
>
|
||||
void gemm_impl(
|
||||
device T *A [[buffer(0)]],
|
||||
device U *B [[buffer(1)]],
|
||||
device V *C [[buffer(2)]],
|
||||
|
||||
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
|
||||
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
const ushort A_leading_block_dim = A_trans ? M_group : K_group;
|
||||
const ushort B_leading_block_dim = B_trans ? K_group : N_group;
|
||||
|
||||
// Thresholds that mark the matrix edge.
|
||||
const uint M_edge = M - (M % M_group);
|
||||
const uint N_edge = N - (N % N_group);
|
||||
|
||||
const ushort async_iter_start = prefer_async_copy ? 0 : (K - (K % K_group));
|
||||
|
||||
// Find the number of elements in the final block. If the matrix
|
||||
// dimensions are perfectly divisibly by block dimensions, we don't want
|
||||
// this value to be zero. The final block is a full block.
|
||||
const uint M_remainder = (M % M_register == 0)
|
||||
? M_register : M % M_register;
|
||||
const ushort N_remainder = (N % N_register == 0)
|
||||
? N_register : N % N_register;
|
||||
const ushort K_remainder = (K % K_group == 0)
|
||||
? K_group : K % K_group;
|
||||
const ushort K_remainder_padded = (K_remainder + 7) / 8 * 8;
|
||||
|
||||
// Shift the final block, so it doesn't access out-of-bounds memory.
|
||||
const ushort M_shift = (M < M_group) ? 0 : M_register - M_remainder;
|
||||
const ushort N_shift = (N < N_group) ? 0 : N_register - N_remainder;
|
||||
|
||||
if (batched) {
|
||||
ulong3 offsets = matrix_offsets[0].xyz * gid.z;
|
||||
A = (device T*)((device uchar*)A + offsets[0]);
|
||||
B = (device U*)((device uchar*)B + offsets[1]);
|
||||
C = (device V*)((device uchar*)C + offsets[2]);
|
||||
}
|
||||
|
||||
auto A_block = (threadgroup T*)(threadgroup_block);
|
||||
auto B_block = (threadgroup U*)(threadgroup_block + (M * K));
|
||||
ushort2 sid(sidx % N_splits, sidx / N_splits);
|
||||
ushort2 morton_offset = morton_order(lane_id);
|
||||
|
||||
// Return early if the SIMD is out of bounds.
|
||||
//
|
||||
// There could be some threadgroups where the matrix edge cuts straight
|
||||
// through the middle of the block. SIMDs on the right or bottom of the
|
||||
// dividing line must be stopped from causing out-of-bounds accesses. This is
|
||||
// the reason for the early exit.
|
||||
uint M_offset = gid.y * M_group;
|
||||
uint N_offset = gid.x * N_group;
|
||||
if (M_offset + sid.y * M_register >= M ||
|
||||
N_offset + sid.x * N_register >= N) {
|
||||
return;
|
||||
}
|
||||
ushort2 offset_in_group(sid.x * N_register + morton_offset.x,
|
||||
sid.y * M_register + morton_offset.y);
|
||||
|
||||
// Shift the matrix block within bounds, if possible.
|
||||
if ((M_shift != 0) && (gid.y * M_group >= M_edge)) {
|
||||
M_offset -= M_shift;
|
||||
}
|
||||
if ((N_shift != 0) && (gid.x * N_group >= N_edge)) {
|
||||
N_offset -= N_shift;
|
||||
}
|
||||
|
||||
simdgroup_matrix_storage<V> C_sram[(M_register / 8) * (N_register / 8)];
|
||||
|
||||
// Initialize the accumulator.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(m, n);
|
||||
auto C = get_sram(C_sram, N_register, origin);
|
||||
*C = simdgroup_matrix_storage<V>(0);
|
||||
}
|
||||
}
|
||||
// Perform the iterations where async copy is avoided.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint k = 0; k < async_iter_start; k += 8) {
|
||||
uint2 A_offset(k, M_offset);
|
||||
uint2 B_offset(N_offset, k);
|
||||
A_offset += uint2(morton_offset.x, offset_in_group.y);
|
||||
B_offset += uint2(offset_in_group.x, morton_offset.y);
|
||||
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
simdgroup_matrix_storage<T> A_sram[M_register / 8];
|
||||
simdgroup_matrix_storage<U> B_sram[N_register / 8];
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_src, B_src, A_sram, B_sram, C_sram, 0);
|
||||
}
|
||||
if (!prefer_async_copy) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint k = 0; k < K; k += K_group) {
|
||||
uint2 A_offset(k, M_offset);
|
||||
uint2 B_offset(N_offset, k);
|
||||
A_offset += uint2(morton_offset.x, offset_in_group.y);
|
||||
B_offset += uint2(offset_in_group.x, morton_offset.y);
|
||||
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
simdgroup_matrix_storage<T> A_sram[M_register / 8];
|
||||
simdgroup_matrix_storage<U> B_sram[N_register / 8];
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_src, B_src, A_sram, B_sram, C_sram, 0);
|
||||
}
|
||||
} else {
|
||||
// Perform the iterations where async copy is used.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint k = async_iter_start; k < K; k += K_group) {
|
||||
// Launch an async copy from device to threadgroup memory.
|
||||
if (sidx == 0) {
|
||||
uint2 A_offset(k, M_offset);
|
||||
uint2 B_offset(N_offset, k);
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
ushort M_tile_dimension = min(uint(M_group), M - M_offset);
|
||||
ushort N_tile_dimension = min(uint(N_group), N - N_offset);
|
||||
ushort K_tile_dimension = min(uint(K_group), K - k);
|
||||
ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k);
|
||||
|
||||
ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
|
||||
ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
|
||||
ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
|
||||
ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
|
||||
|
||||
simdgroup_event events[2];
|
||||
events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
|
||||
events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
|
||||
simdgroup_event::wait(2, events);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
|
||||
ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
|
||||
auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_leading_block_dim, A_block_offset, A_trans);
|
||||
auto B_block_src = simdgroup_matrix_storage<U>::apply_offset(B_block, B_leading_block_dim, B_block_offset, B_trans);
|
||||
|
||||
simdgroup_matrix_storage<T> A_sram[(M_register / 8) * (K_group / 8)];
|
||||
simdgroup_matrix_storage<U> B_sram[(K_group / 8) * (N_register / 8)];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = 0; k < K_remainder_padded; k += 8) {
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
|
||||
}
|
||||
|
||||
// Will there be any iterations after this one?
|
||||
if (k + K_group < K) {
|
||||
// If so, we haven't reached the edge of either input matrix yet.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = K_remainder_padded; k < K_group; k += 8) {
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!prefer_async_copy && (M >= M_group) && (N >= N_group)) {
|
||||
// Fast path for matrices that qualify.
|
||||
uint2 C_offset(N_offset + offset_in_group.x,
|
||||
M_offset + offset_in_group.y);
|
||||
auto C_dst = simdgroup_matrix_storage<U>::apply_offset(
|
||||
C, N, C_offset);
|
||||
|
||||
// Write the accumulator to device memory.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto C = get_sram(C_sram, N_register, origin);
|
||||
C->store(C_dst, N, origin);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Slow path for when memory must be handled more carefully.
|
||||
auto C_block = (threadgroup V*)(threadgroup_block);
|
||||
auto C_block_dst = simdgroup_matrix_storage<V>::apply_offset(C_block, N_group, offset_in_group);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write the accumulator to threadgroup memory.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto C = get_sram(C_sram, N_register, origin);
|
||||
C->store(C_block_dst, N_group, origin);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Launch the async copy from threadgroup to device memory.
|
||||
if (sidx == 0) {
|
||||
uint2 C_offset(gid.x * N_group, gid.y * M_group);
|
||||
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
|
||||
min(uint(M_group), M - C_offset.y));
|
||||
auto C_dst = simdgroup_matrix_storage<V>::apply_offset(C, N, C_offset);
|
||||
|
||||
// If we shift successfully, the garbage zone moves from the bottom right
|
||||
// to the top left.
|
||||
if ((M_shift != 0) || (N_shift != 0)) {
|
||||
ushort2 C_block_shift(0, 0);
|
||||
if ((M_shift != 0) && (C_offset.y >= M_edge)) {
|
||||
C_block_shift.y = M_shift;
|
||||
}
|
||||
if ((N_shift != 0) && (C_offset.x >= N_edge)) {
|
||||
C_block_shift.x = N_shift;
|
||||
}
|
||||
C_block = simdgroup_matrix_storage<V>::apply_offset(C_block, N_group, C_block_shift);
|
||||
}
|
||||
|
||||
simdgroup_event event;
|
||||
event.async_copy(C_dst, N, C_tile, C_block, N_group, C_tile);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void hgemm(
|
||||
device half *A [[buffer(0)]],
|
||||
device half *B [[buffer(1)]],
|
||||
device half *C [[buffer(2)]],
|
||||
|
||||
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
|
||||
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
if (ideal_grouping) {
|
||||
gemm_impl<half, half, half, 32, 32, 32, 1, 1>(
|
||||
A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
|
||||
);
|
||||
} else {
|
||||
gemm_impl<half, half, half, 48, 48, 32, 1, 1>(
|
||||
A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void sgemm(
|
||||
device float *A [[buffer(0)]],
|
||||
device float *B [[buffer(1)]],
|
||||
device float *C [[buffer(2)]],
|
||||
|
||||
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
|
||||
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
gemm_impl<float, float, float, 32, 32, 32, 2, 2>(
|
||||
A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
|
||||
);
|
||||
/*
|
||||
if (prefer_async_copy) {
|
||||
constexpr ushort M_split = 1;
|
||||
constexpr ushort N_split = 1;
|
||||
if (ideal_grouping) {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
} else {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
48,
|
||||
48,
|
||||
24,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
}
|
||||
} else {
|
||||
constexpr ushort M_split = 2;
|
||||
constexpr ushort N_split = 2;
|
||||
if (ideal_grouping) {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
} else {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
100,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
243
candle-metal-kernels/src/kernels/matrix_storage.metal
Normal file
243
candle-metal-kernels/src/kernels/matrix_storage.metal
Normal file
@ -0,0 +1,243 @@
|
||||
// -*- Metal -*-
|
||||
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
|
||||
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
#define __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
template <typename T>
|
||||
struct simdgroup_matrix_storage {
|
||||
typedef vec<T, 64> storage_type;
|
||||
|
||||
storage_type t;
|
||||
|
||||
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
|
||||
return reinterpret_cast<thread vec<T, 2>*>(&t);
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage() thread = default;
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
|
||||
*(this->thread_elements()) = thread_elements;
|
||||
}
|
||||
|
||||
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
|
||||
} else {
|
||||
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
|
||||
} else {
|
||||
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
vec<U, 2> memoryForm = *(const device vec<U, 2>*)(src + combinedAddress);
|
||||
*(thread_elements()) = vec<T, 2>(memoryForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void load_bfloat(const device bfloat *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
bfloat memoryForm0 = src[address0];
|
||||
bfloat memoryForm1 = src[address1];
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[1] = memoryForm0;
|
||||
registerForm[3] = memoryForm1;
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
bfloat2 memoryForm = *(const device packed_bfloat2*)(src + combinedAddress);
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm);
|
||||
((thread bfloat*)®isterForm)[1] = memoryForm[0];
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
vec<U, 2> memoryForm = *(const threadgroup vec<U, 2>*)(src + combinedAddress);
|
||||
*(thread_elements()) = vec<T, 2>(memoryForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void load_bfloat(const threadgroup bfloat *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
bfloat memoryForm0 = src[address0];
|
||||
bfloat memoryForm1 = src[address1];
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[1] = memoryForm0;
|
||||
registerForm[3] = memoryForm1;
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
bfloat2 memoryForm = *(const threadgroup packed_bfloat2*)(src + combinedAddress);
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm);
|
||||
((thread bfloat*)®isterForm)[1] = memoryForm[0];
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
vec<T, 2> registerForm = *(thread_elements());
|
||||
*(device vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void store_bfloat(device bfloat *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
float memoryForm = ((thread float*)®isterForm)[1];
|
||||
*(device float*)(dst + combinedAddress) = memoryForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
vec<T, 2> registerForm = *(thread_elements());
|
||||
*(threadgroup vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void store_bfloat(threadgroup bfloat *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
float memoryForm = ((thread float*)®isterForm)[1];
|
||||
*(threadgroup float*)(dst + combinedAddress) = memoryForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V>
|
||||
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
|
||||
if (!accumulate) {
|
||||
*(thread_elements()) = vec<T, 2>(0);
|
||||
}
|
||||
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
|
||||
}
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif
|
@ -104,7 +104,7 @@ METAL_FUNC void argmax(
|
||||
threadgroup T * shared_memory,
|
||||
threadgroup uint * shared_indices
|
||||
) {
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
@ -173,7 +173,7 @@ METAL_FUNC void reduce(
|
||||
threadgroup T * shared_memory,
|
||||
T (*fn)(T, T)
|
||||
) {
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
47
candle-metal-kernels/src/kernels/utils.metal
Normal file
47
candle-metal-kernels/src/kernels/utils.metal
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint nonzero(uint n) {
|
||||
return n == 0 ? 1 : n;
|
||||
}
|
||||
|
||||
template<uint N>
|
||||
constexpr uint nonzero() {
|
||||
return N == 0 ? 1 : N;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr ushort granularity() {
|
||||
return nonzero<vec_elements<T>::value>();
|
||||
}
|
||||
|
||||
METAL_FUNC uint next_p2(uint x) {
|
||||
return 1 << (32 - clz(x - 1));
|
||||
}
|
||||
|
||||
METAL_FUNC uint prev_p2(uint x) {
|
||||
return 1 << (31 - clz(x));
|
||||
}
|
||||
|
||||
constant uint MAX_SHARED_MEM = 32767;
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC uint max_shared_mem(uint n) {
|
||||
return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
|
||||
}
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant const uint &num_dims,
|
||||
constant const size_t *dims,
|
||||
constant const size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
@ -1,30 +1,37 @@
use metal::{
    Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
    FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
    FunctionConstantValues, Library, MTLDataType, MTLGPUFamily, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

#[cfg(target_os = "macos")]
mod ffi;
mod gpu;
use gpu::get_device_core_count;

mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split};

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");
const AFFINE: &str = include_str!("kernels/affine.metal");
const INDEXING: &str = include_str!("kernels/indexing.metal");
const UNARY: &str = include_str!("kernels/unary.metal");
const BINARY: &str = include_str!("kernels/binary.metal");
const TERNARY: &str = include_str!("kernels/ternary.metal");
const CAST: &str = include_str!("kernels/cast.metal");
const CONV: &str = include_str!("kernels/conv.metal");
const REDUCE: &str = include_str!("kernels/reduce.metal");
const RANDOM: &str = include_str!("kernels/random.metal");
const QUANTIZED: &str = include_str!("kernels/quantized.metal");
const SORT: &str = include_str!("kernels/sort.metal");
const MFA: &[u8] = include_bytes!("libraries/libMetalFlashAttention.metallib");
const CANDLE: &[u8] = include_bytes!("libraries/candle.metallib");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    Candle,
    Affine,
    Indexing,
    Unary,
@ -200,7 +207,7 @@ impl Kernels {
            Source::Random => RANDOM,
            Source::Quantized => QUANTIZED,
            Source::Sort => SORT,
            Source::Mfa => panic!("Invalid lib"),
            _ => panic!("Invalid lib"),
        }
    }

@ -216,14 +223,16 @@ impl Kernels {
            Ok(lib.clone())
        } else {
            let lib = match source {
                Source::Mfa => {
                    let source_data = MFA;
                    device.new_library_with_data(source_data).map_err(|e| {
                        MetalKernelError::LoadLibraryError(format!(
                            "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
                        ))
                    })?
                }
                Source::Candle => device.new_library_with_data(CANDLE).map_err(|e| {
                    MetalKernelError::LoadLibraryError(format!(
                        "Candle metal requires macosx > 13.0 or higher, cannot load candle: {e}"
                    ))
                })?,
                Source::Mfa => device.new_library_with_data(MFA).map_err(|e| {
                    MetalKernelError::LoadLibraryError(format!(
                        "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
                    ))
                })?,
                source => {
                    let source_content = self.get_library_source(source);
                    device
@ -1465,6 +1474,29 @@ pub fn call_gemm(
    rhs_buffer: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let prefer_async_copy = !device.supports_family(MTLGPUFamily::Apple9);

    let mut actual_groups: usize = 1;
    actual_groups *= divide(m, 48) as usize;
    actual_groups *= divide(n, 48) as usize;
    actual_groups *= b;

    let core_count = get_device_core_count(device);
    let ideal_grouping = if name == "sgemm" {
        actual_groups <= core_count * 6
    } else {
        actual_groups <= core_count * 9
    };

    let mut blockdim = (32, 32, 32);
    if !ideal_grouping {
        if name == "sgemm" {
            blockdim = (48, 48, 24);
        } else {
            blockdim = (48, 48, 32);
        }
    }
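The new preamble in call_gemm sizes the launch: it counts how many 48x48 output tiles the problem produces, compares that against the device core count reported by the new get_device_core_count helper, and only switches to the larger block dimensions when there are not enough threadgroups to keep every core busy; prefer_async_copy is likewise derived from whether the GPU predates the Apple9 family. A self-contained Rust sketch of that heuristic (divide is assumed to be a ceiling division helper; the 48x48 tile size, the 6x/9x thresholds and the fallback block dims are copied from the diff, but they are tuning choices, not a spec):

// Sketch of the occupancy-based block-dimension heuristic above.
fn divide(n: usize, d: usize) -> usize {
    (n + d - 1) / d // assumed ceiling division, as in the diff's `divide`
}

fn pick_blockdim(name: &str, (b, m, n): (usize, usize, usize), core_count: usize) -> (u16, u16, u16) {
    let actual_groups = divide(m, 48) * divide(n, 48) * b;
    let ideal_grouping = if name == "sgemm" {
        actual_groups <= core_count * 6
    } else {
        actual_groups <= core_count * 9
    };
    if ideal_grouping {
        (32, 32, 32)
    } else if name == "sgemm" {
        (48, 48, 24)
    } else {
        (48, 48, 32)
    }
}

fn main() {
    // A small sgemm on a hypothetical 10-core GPU keeps the default 32^3 block...
    assert_eq!(pick_blockdim("sgemm", (1, 96, 96), 10), (32, 32, 32));
    // ...while a large batched launch switches to the wider 48x48 tiles.
    assert_eq!(pick_blockdim("sgemm", (8, 2048, 2048), 10), (48, 48, 24));
}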

    assert!(rhs_stride.len() >= 2);
    assert!(lhs_stride.len() >= 2);
    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1501,50 +1533,45 @@ pub fn call_gemm(
    let alpha = 1.0f32;
    let beta = 0.0f32;
    let batched = b > 1;
    println!("batched: {batched}");
    let fused_activation = false;
    let fused_bias = false;
    let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
        let m_simd = 8;
        let n_simd = 8;
        let k_simd = 64;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    } else {
        let m_simd = 40;
        let n_simd = 40;
        let k_simd = 32;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    };

    let constants = Some(ConstantValues::new(vec![
        (0, Value::USize(m)),
        (1, Value::USize(n)),
        (2, Value::USize(k)),
        (10, Value::Bool(a_trans)),
        (11, Value::Bool(b_trans)),
        (13, Value::Bool(d_trans)),
        (20, Value::F32(alpha)),
        (21, Value::F32(beta)),
        //(13, Value::Bool(d_trans)),
        //(20, Value::F32(alpha)),
        //(21, Value::F32(beta)),
        (100, Value::Bool(batched)),
        (101, Value::Bool(fused_activation)),
        //(101, Value::Bool(fused_activation)),
        // Garbage
        (102, Value::Bool(false)),
        (103, Value::Bool(false)),
        (113, Value::Bool(false)),
        (50_000, Value::Bool(false)),
        // End garbage
        (200, Value::U16(m_simd)),
        (201, Value::U16(n_simd)),
        (202, Value::U16(k_simd)),
        (210, Value::U16(m_splits)),
        (211, Value::U16(n_splits)),
        (50_001, Value::Bool(fused_bias)),
        //(200, Value::U16(blockdim.0)),
        //(201, Value::U16(blockdim.1)),
        //(202, Value::U16(blockdim.2)),
        (206, Value::Bool(prefer_async_copy)),
        (207, Value::Bool(ideal_grouping)),
        //(210, Value::U16(m_splits)),
        //(211, Value::U16(n_splits)),
        //(50_001, Value::Bool(fused_bias)),
    ]));
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
    let m_group = m_simd * m_splits;
    let n_group = n_simd * n_splits;
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Candle, name, constants)?;

    let m_group: u16 = 32;
    let n_group: u16 = 32;
    let m_splits: u16 = 2;
    let n_splits: u16 = 2;
    let k_simd: u16 = 32;
    let m_simd = m_group / m_splits;
    let n_simd = n_group / n_splits;

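The pipeline is now specialized from Source::Candle instead of Source::Mfa, and the tile shape is fixed rather than derived from the m == 1 branch above: a 32x32 threadgroup output tile split across a 2x2 grid of simdgroups, each accumulating a 16x16 sub-tile. A tiny Rust sketch of that arithmetic (not part of the diff):

// Tile hierarchy implied by the fixed grouping: threadgroup tile -> simdgroup sub-tiles.
fn main() {
    let (m_group, n_group): (u16, u16) = (32, 32);
    let (m_splits, n_splits): (u16, u16) = (2, 2);
    let (m_simd, n_simd) = (m_group / m_splits, n_group / n_splits);
    assert_eq!((m_simd, n_simd), (16, 16));
    // Each simdgroup therefore covers 16 * 16 = 256 output elements.
    assert_eq!(m_simd as u32 * n_simd as u32, 256);
}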
    let a_block_length = m_group * k_simd;
    let b_block_length = k_simd * n_group;
@ -1554,6 +1581,7 @@ pub fn call_gemm(
        let c_block_length = m_group * n_group;
        block_elements = std::cmp::max(c_block_length, block_elements)
    }
    /*
    if fused_bias {
        if d_trans {
            block_elements = std::cmp::max(block_elements, m_group);
@ -1561,6 +1589,7 @@ pub fn call_gemm(
            block_elements = std::cmp::max(block_elements, n_group);
        }
    }
    */
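The threadgroup allocation is then the larger of the A+B staging tiles and, in the branch shown above, the C tile, scaled by the element size of the kernel precision. A Rust sketch of that sizing under the fixed 32x32x32 grouping (needs_c_block is a hypothetical stand-in for the condition hidden by the hunk boundary, and the fused-bias term stays commented out, as in the diff):

// Sketch: threadgroup bytes = max(A tile + B tile, C tile) * element size.
fn block_bytes(name: &str, m_group: u16, n_group: u16, k_simd: u16, needs_c_block: bool) -> u16 {
    let a_block = m_group * k_simd;
    let b_block = k_simd * n_group;
    let mut block_elements = a_block + b_block;
    if needs_c_block {
        block_elements = block_elements.max(m_group * n_group);
    }
    let bytes: u16 = match name {
        "sgemm" => 4,
        "hgemm" => 2,
        other => panic!("unexpected kernel {other}"),
    };
    block_elements * bytes
}

fn main() {
    // 32x32 tiles with k_simd = 32 and f32 elements: (1024 + 1024) * 4 bytes.
    assert_eq!(block_bytes("sgemm", 32, 32, 32, false), 8192);
}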
    let bytes = match name {
        "sgemm" => 4,
        "hgemm" => 2,
@ -1574,7 +1603,7 @@ pub fn call_gemm(

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    encoder.set_threadgroup_memory_length(0, block_bytes.into());
    encoder.set_threadgroup_memory_length(0, block_bytes as NSUInteger);
    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
    encoder.set_buffer(2, Some(output), 0);
@ -1588,7 +1617,7 @@ pub fn call_gemm(
    // TODO byte_stride_d
    let byte_stride_d = 0;

    let buffer: Vec<u64> = vec![
    let buffer: [u64; 4] = [
        byte_stride_a as _,
        byte_stride_b as _,
        byte_stride_c as _,
BIN
candle-metal-kernels/src/libraries/candle.metallib
Normal file
Binary file not shown.
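The gemm tests below drive run_gemm with row-major inputs and compare against fixed expected values. A naive strided reference matmul (a sketch, not the test suite's own helper) is one way to derive such expectations; the 2x2x2 case at the bottom can be checked by hand:

// Naive reference: out[b, i, j] = sum_p lhs[b, i, p] * rhs[b, p, j],
// with explicit [batch, row, col] strides for both inputs.
fn reference_gemm(
    (b, m, n, k): (usize, usize, usize, usize),
    lhs: &[f32],
    lhs_stride: &[usize],
    rhs: &[f32],
    rhs_stride: &[usize],
) -> Vec<f32> {
    let mut out = vec![0f32; b * m * n];
    for bi in 0..b {
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0f32;
                for p in 0..k {
                    let l = bi * lhs_stride[0] + i * lhs_stride[1] + p * lhs_stride[2];
                    let r = bi * rhs_stride[0] + p * rhs_stride[1] + j * rhs_stride[2];
                    acc += lhs[l] * rhs[r];
                }
                out[(bi * m + i) * n + j] = acc;
            }
        }
    }
    out
}

fn main() {
    let (b, m, n, k) = (1, 2, 2, 2);
    let lhs: Vec<f32> = (0..(b * m * k)).map(|f| f as f32).collect();
    let rhs: Vec<f32> = (0..(b * n * k)).map(|f| f as f32).collect();
    let out = reference_gemm((b, m, n, k), &lhs, &[m * k, k, 1], &rhs, &[n * k, n, 1]);
    assert_eq!(out, vec![2.0, 3.0, 6.0, 11.0]);
}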
@ -1100,6 +1100,11 @@ fn gemm() {
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    println!("lhs: {lhs:?}");
    println!("lhs_stride: {lhs_stride:?}");
    println!("rhs: {rhs:?}");
    println!("rhs_stride: {rhs_stride:?}");

    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
    assert_eq!(
        approx(results, 4),
@ -1111,6 +1116,11 @@ fn gemm() {
    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
    let rhs_stride = vec![n * k, n, 1];
    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
    println!("lhs: {lhs:?}");
    println!("lhs_stride: {lhs_stride:?}");
    println!("rhs: {rhs:?}");
    println!("rhs_stride: {rhs_stride:?}");

    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
    assert_eq!(
        approx(results, 4),