mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
UG metal integration. (#2580)
This commit is contained in:
@ -6,7 +6,7 @@ use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
mod utils;
|
||||
pub mod utils;
|
||||
pub use utils::BufferOffset;
|
||||
use utils::{get_block_dims, linear_split, EncoderProvider};
|
||||
|
||||
|
@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
|
||||
}
|
||||
|
||||
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
||||
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||
let mut pows0 = 0u64;
|
||||
let mut pows1 = 0u64;
|
||||
let mut pows2 = 0u64;
|
||||
@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_param<P: EncoderParam>(
|
||||
encoder: &ComputeCommandEncoderRef,
|
||||
position: u64,
|
||||
data: P,
|
||||
) {
|
||||
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
||||
<P as EncoderParam>::set_param(encoder, position, data)
|
||||
}
|
||||
|
||||
/// Helper functions to create the various objects on the compute command encoder
|
||||
/// on a single line.
|
||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||
pub(crate) trait EncoderParam {
|
||||
pub trait EncoderParam {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||
}
|
||||
macro_rules! primitive {
|
||||
|
Reference in New Issue
Block a user