Allow using composed strings as metal kernel names. (#2747)

This commit is contained in:
Laurent Mazare
2025-01-27 22:40:12 +01:00
committed by GitHub
parent 27996a1a9e
commit da02b59516

View File

@ -177,8 +177,54 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
}
}
#[derive(Debug, Clone)]
pub enum KernelName {
Ref(&'static str),
Value(String),
}
impl AsRef<str> for KernelName {
fn as_ref(&self) -> &str {
match self {
Self::Ref(r) => r,
Self::Value(v) => v.as_str(),
}
}
}
impl std::hash::Hash for KernelName {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Self::Ref(r) => r.hash(state),
Self::Value(v) => v.hash(state),
}
}
}
impl PartialEq for KernelName {
fn eq(&self, other: &Self) -> bool {
let v1: &str = self.as_ref();
let v2: &str = other.as_ref();
v1 == v2
}
}
impl Eq for KernelName {}
impl From<&'static str> for KernelName {
fn from(value: &'static str) -> Self {
Self::Ref(value)
}
}
impl From<String> for KernelName {
fn from(value: String) -> Self {
Self::Value(value)
}
}
type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
type Pipelines = HashMap<(KernelName, Option<ConstantValues>), ComputePipelineState>;
#[derive(Debug)]
pub struct Kernels {
@ -247,7 +293,7 @@ impl Kernels {
&self,
device: &Device,
source: Source,
name: &'static str,
name: &str,
constants: Option<FunctionConstantValues>,
) -> Result<Function, MetalKernelError> {
let func = self
@ -264,11 +310,11 @@ impl Kernels {
&self,
device: &Device,
source: Source,
name: &'static str,
name: impl Into<KernelName>,
constants: Option<ConstantValues>,
) -> Result<ComputePipelineState, MetalKernelError> {
let mut pipelines = self.pipelines.write()?;
let key = (name, constants);
let key = (name.into(), constants);
if let Some(pipeline) = pipelines.get(&key) {
Ok(pipeline.clone())
} else {
@ -276,7 +322,7 @@ impl Kernels {
let func = self.load_function(
device,
source,
name,
name.as_ref(),
constants.as_ref().map(|c| c.function_constant_values()),
)?;
let pipeline = device
@ -295,7 +341,7 @@ impl Kernels {
&self,
device: &Device,
source: Source,
name: &'static str,
name: impl Into<KernelName>,
) -> Result<ComputePipelineState, MetalKernelError> {
self.load_pipeline_with_constants(device, source, name, None)
}