From da02b595165227765b1e068b747159580f1ab0b3 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 27 Jan 2025 22:40:12 +0100 Subject: [PATCH] Allow using composed strings as metal kernel names. (#2747) --- candle-metal-kernels/src/lib.rs | 58 +++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 2e001a0f..eeb9a975 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -177,8 +177,54 @@ impl From> for MetalKernelError { } } +#[derive(Debug, Clone)] +pub enum KernelName { + Ref(&'static str), + Value(String), +} + +impl AsRef for KernelName { + fn as_ref(&self) -> &str { + match self { + Self::Ref(r) => r, + Self::Value(v) => v.as_str(), + } + } +} + +impl std::hash::Hash for KernelName { + fn hash(&self, state: &mut H) { + match self { + Self::Ref(r) => r.hash(state), + Self::Value(v) => v.hash(state), + } + } +} + +impl PartialEq for KernelName { + fn eq(&self, other: &Self) -> bool { + let v1: &str = self.as_ref(); + let v2: &str = other.as_ref(); + v1 == v2 + } +} + +impl Eq for KernelName {} + +impl From<&'static str> for KernelName { + fn from(value: &'static str) -> Self { + Self::Ref(value) + } +} + +impl From for KernelName { + fn from(value: String) -> Self { + Self::Value(value) + } +} + type Libraries = HashMap; -type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; +type Pipelines = HashMap<(KernelName, Option), ComputePipelineState>; #[derive(Debug)] pub struct Kernels { @@ -247,7 +293,7 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + name: &str, constants: Option, ) -> Result { let func = self @@ -264,11 +310,11 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + name: impl Into, constants: Option, ) -> Result { let mut pipelines = self.pipelines.write()?; - let key = (name, constants); + let key = (name.into(), constants); if let Some(pipeline) = pipelines.get(&key) { Ok(pipeline.clone()) } else { @@ -276,7 +322,7 @@ impl Kernels { let func = self.load_function( device, source, - name, + name.as_ref(), constants.as_ref().map(|c| c.function_constant_values()), )?; let pipeline = device @@ -295,7 +341,7 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + name: impl Into, ) -> Result { self.load_pipeline_with_constants(device, source, name, None) }