Metal part 1 - Scaffolding for metal. (#1308)

* Metal part 1 - Scaffolding for metal.

* Remove tracing.
This commit is contained in:
Nicolas Patry
2023-11-10 08:35:48 +01:00
committed by GitHub
parent 18d30005c5
commit 26c4e5bf1d
13 changed files with 473 additions and 16 deletions

View File

@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal,
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
pub trait NdArray {
@ -128,10 +130,15 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
@ -147,21 +154,20 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
matches!(self, Self::Cpu)
}
pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -194,6 +200,11 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
}
}
}
@ -228,6 +239,10 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
}
}
}
@ -250,6 +265,10 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -263,6 +282,10 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -274,6 +297,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
@ -285,6 +313,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}