UG metal integration. (#2580)

This commit is contained in:
Laurent Mazare
2024-10-27 15:20:37 +01:00
committed by GitHub
parent 594d984f9c
commit 0e2c8c17fb
8 changed files with 92 additions and 18 deletions

View File

@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
mod utils;
pub mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};

View File

@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}
pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {