mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal. * Remove tracing.
This commit is contained in:
@ -71,11 +71,13 @@ impl PyDType {
|
||||
}
|
||||
|
||||
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum PyDevice {
|
||||
Cpu,
|
||||
Cuda,
|
||||
Metal,
|
||||
}
|
||||
|
||||
impl PyDevice {
|
||||
@ -83,6 +85,7 @@ impl PyDevice {
|
||||
match device {
|
||||
Device::Cpu => Self::Cpu,
|
||||
Device::Cuda(_) => Self::Cuda,
|
||||
Device::Metal(_) => Self::Metal,
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,6 +101,15 @@ impl PyDevice {
|
||||
*device = Some(d.clone());
|
||||
Ok(d)
|
||||
}
|
||||
Self::Metal => {
|
||||
let mut device = METAL_DEVICE.lock().unwrap();
|
||||
if let Some(device) = device.as_ref() {
|
||||
return Ok(device.clone());
|
||||
};
|
||||
let d = Device::new_metal(0).map_err(wrap_err)?;
|
||||
*device = Some(d.clone());
|
||||
Ok(d)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -119,6 +131,7 @@ impl ToPyObject for PyDevice {
|
||||
let str = match self {
|
||||
PyDevice::Cpu => "cpu",
|
||||
PyDevice::Cuda => "cuda",
|
||||
PyDevice::Metal => "metal",
|
||||
};
|
||||
str.to_object(py)
|
||||
}
|
||||
|
Reference in New Issue
Block a user