Adding a bunch of docs !

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-12-15 11:02:41 +01:00
parent cf27868b57
commit 243e83f2b9
2 changed files with 128 additions and 59 deletions

View File

@ -15,6 +15,10 @@ const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
@ -36,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
@ -220,6 +228,9 @@ impl Kernels {
Source::Mfa => panic!("Invalid lib"),
}
}
/// Load the give library from its [`source`].
/// If this has been previously loaded it will just fetch it from cache.
pub fn load_library(
&self,
device: &Device,
@ -262,6 +273,9 @@ impl Kernels {
Ok(func)
}
/// Load the give pipeline
/// loads the library from source, then gets the function [`name`] from
/// that source
fn load_pipeline_with_constants(
&self,
device: &Device,
@ -290,6 +304,9 @@ impl Kernels {
}
}
/// Load the give pipeline
/// loads the library from source, then gets the function [`name`] from
/// that source (without constants)
pub fn load_pipeline(
&self,
device: &Device,