mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Allow using composed strings as metal kernel names. (#2747)
This commit is contained in:
@ -177,8 +177,54 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KernelName {
|
||||
Ref(&'static str),
|
||||
Value(String),
|
||||
}
|
||||
|
||||
impl AsRef<str> for KernelName {
|
||||
fn as_ref(&self) -> &str {
|
||||
match self {
|
||||
Self::Ref(r) => r,
|
||||
Self::Value(v) => v.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::hash::Hash for KernelName {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
Self::Ref(r) => r.hash(state),
|
||||
Self::Value(v) => v.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for KernelName {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let v1: &str = self.as_ref();
|
||||
let v2: &str = other.as_ref();
|
||||
v1 == v2
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for KernelName {}
|
||||
|
||||
impl From<&'static str> for KernelName {
|
||||
fn from(value: &'static str) -> Self {
|
||||
Self::Ref(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for KernelName {
|
||||
fn from(value: String) -> Self {
|
||||
Self::Value(value)
|
||||
}
|
||||
}
|
||||
|
||||
type Libraries = HashMap<Source, Library>;
|
||||
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
|
||||
type Pipelines = HashMap<(KernelName, Option<ConstantValues>), ComputePipelineState>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Kernels {
|
||||
@ -247,7 +293,7 @@ impl Kernels {
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
name: &str,
|
||||
constants: Option<FunctionConstantValues>,
|
||||
) -> Result<Function, MetalKernelError> {
|
||||
let func = self
|
||||
@ -264,11 +310,11 @@ impl Kernels {
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
name: impl Into<KernelName>,
|
||||
constants: Option<ConstantValues>,
|
||||
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||
let mut pipelines = self.pipelines.write()?;
|
||||
let key = (name, constants);
|
||||
let key = (name.into(), constants);
|
||||
if let Some(pipeline) = pipelines.get(&key) {
|
||||
Ok(pipeline.clone())
|
||||
} else {
|
||||
@ -276,7 +322,7 @@ impl Kernels {
|
||||
let func = self.load_function(
|
||||
device,
|
||||
source,
|
||||
name,
|
||||
name.as_ref(),
|
||||
constants.as_ref().map(|c| c.function_constant_values()),
|
||||
)?;
|
||||
let pipeline = device
|
||||
@ -295,7 +341,7 @@ impl Kernels {
|
||||
&self,
|
||||
device: &Device,
|
||||
source: Source,
|
||||
name: &'static str,
|
||||
name: impl Into<KernelName>,
|
||||
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||
self.load_pipeline_with_constants(device, source, name, None)
|
||||
}
|
||||
|
Reference in New Issue
Block a user