Allow using composed strings as metal kernel names. (#2747)

This commit is contained in:
Laurent Mazare
2025-01-27 22:40:12 +01:00
committed by GitHub
parent 27996a1a9e
commit da02b59516

View File

@ -177,8 +177,54 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
} }
} }
#[derive(Debug, Clone)]
pub enum KernelName {
Ref(&'static str),
Value(String),
}
impl AsRef<str> for KernelName {
fn as_ref(&self) -> &str {
match self {
Self::Ref(r) => r,
Self::Value(v) => v.as_str(),
}
}
}
impl std::hash::Hash for KernelName {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Self::Ref(r) => r.hash(state),
Self::Value(v) => v.hash(state),
}
}
}
impl PartialEq for KernelName {
fn eq(&self, other: &Self) -> bool {
let v1: &str = self.as_ref();
let v2: &str = other.as_ref();
v1 == v2
}
}
impl Eq for KernelName {}
impl From<&'static str> for KernelName {
fn from(value: &'static str) -> Self {
Self::Ref(value)
}
}
impl From<String> for KernelName {
fn from(value: String) -> Self {
Self::Value(value)
}
}
type Libraries = HashMap<Source, Library>; type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>; type Pipelines = HashMap<(KernelName, Option<ConstantValues>), ComputePipelineState>;
#[derive(Debug)] #[derive(Debug)]
pub struct Kernels { pub struct Kernels {
@ -247,7 +293,7 @@ impl Kernels {
&self, &self,
device: &Device, device: &Device,
source: Source, source: Source,
name: &'static str, name: &str,
constants: Option<FunctionConstantValues>, constants: Option<FunctionConstantValues>,
) -> Result<Function, MetalKernelError> { ) -> Result<Function, MetalKernelError> {
let func = self let func = self
@ -264,11 +310,11 @@ impl Kernels {
&self, &self,
device: &Device, device: &Device,
source: Source, source: Source,
name: &'static str, name: impl Into<KernelName>,
constants: Option<ConstantValues>, constants: Option<ConstantValues>,
) -> Result<ComputePipelineState, MetalKernelError> { ) -> Result<ComputePipelineState, MetalKernelError> {
let mut pipelines = self.pipelines.write()?; let mut pipelines = self.pipelines.write()?;
let key = (name, constants); let key = (name.into(), constants);
if let Some(pipeline) = pipelines.get(&key) { if let Some(pipeline) = pipelines.get(&key) {
Ok(pipeline.clone()) Ok(pipeline.clone())
} else { } else {
@ -276,7 +322,7 @@ impl Kernels {
let func = self.load_function( let func = self.load_function(
device, device,
source, source,
name, name.as_ref(),
constants.as_ref().map(|c| c.function_constant_values()), constants.as_ref().map(|c| c.function_constant_values()),
)?; )?;
let pipeline = device let pipeline = device
@ -295,7 +341,7 @@ impl Kernels {
&self, &self,
device: &Device, device: &Device,
source: Source, source: Source,
name: &'static str, name: impl Into<KernelName>,
) -> Result<ComputePipelineState, MetalKernelError> { ) -> Result<ComputePipelineState, MetalKernelError> {
self.load_pipeline_with_constants(device, source, name, None) self.load_pipeline_with_constants(device, source, name, None)
} }