Metal part 1 - Scaffolding for metal. (#1308)

* Metal part 1 - Scaffolding for metal.

* Remove tracing.
This commit is contained in:
Nicolas Patry
2023-11-10 08:35:48 +01:00
committed by GitHub
parent 18d30005c5
commit 26c4e5bf1d
13 changed files with 473 additions and 16 deletions

View File

@ -71,11 +71,13 @@ impl PyDType {
}
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
Cpu,
Cuda,
Metal,
}
impl PyDevice {
@ -83,6 +85,7 @@ impl PyDevice {
match device {
Device::Cpu => Self::Cpu,
Device::Cuda(_) => Self::Cuda,
Device::Metal(_) => Self::Metal,
}
}
@ -98,6 +101,15 @@ impl PyDevice {
*device = Some(d.clone());
Ok(d)
}
Self::Metal => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).map_err(wrap_err)?;
*device = Some(d.clone());
Ok(d)
}
}
}
}
@ -119,6 +131,7 @@ impl ToPyObject for PyDevice {
let str = match self {
PyDevice::Cpu => "cpu",
PyDevice::Cuda => "cuda",
PyDevice::Metal => "metal",
};
str.to_object(py)
}