mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Adding a bunch of docs !
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
@ -15,6 +15,10 @@ const CAST: &str = include_str!("cast.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
/// actual total buffer length).
|
||||
/// Then kernels can just do their op on their single point in the buffer.
|
||||
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||
let size = length as u64;
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
||||
@ -36,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
|
||||
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
||||
<P as EncoderParam>::set_param(encoder, position, data)
|
||||
}
|
||||
|
||||
/// Helper functions to create the various objects on the compute command encoder
|
||||
/// on a single line.
|
||||
/// Prevents passing the wrong number of arguments and mixing up length and size in bytes.
|
||||
trait EncoderParam {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||
}
|
||||
@ -220,6 +228,9 @@ impl Kernels {
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the given library from its [`source`].
|
||||
/// If this has been previously loaded it will just fetch it from cache.
|
||||
pub fn load_library(
|
||||
&self,
|
||||
device: &Device,
|
||||
@ -262,6 +273,9 @@ impl Kernels {
|
||||
Ok(func)
|
||||
}
|
||||
|
||||
/// Load the given pipeline
|
||||
/// loads the library from source, then gets the function [`name`] from
|
||||
/// that source
|
||||
fn load_pipeline_with_constants(
|
||||
&self,
|
||||
device: &Device,
|
||||
@ -290,6 +304,9 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the given pipeline
|
||||
/// loads the library from source, then gets the function [`name`] from
|
||||
/// that source (without constants)
|
||||
pub fn load_pipeline(
|
||||
&self,
|
||||
device: &Device,
|
||||
|
Reference in New Issue
Block a user