Add some skeleton code for GPU support.

This commit is contained in:
laurent
2023-06-21 09:13:57 +01:00
parent f319583530
commit 8cde0c5478
3 changed files with 28 additions and 0 deletions

View File

@ -6,6 +6,7 @@ use crate::{
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
Cpu,
Cuda { gpu_id: usize },
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
@ -72,6 +73,9 @@ impl Device {
};
Storage::Cpu(storage)
}
Device::Cuda { gpu_id: _ } => {
todo!()
}
}
}
@ -91,12 +95,18 @@ impl Device {
};
Storage::Cpu(storage)
}
Device::Cuda { gpu_id: _ } => {
todo!()
}
}
}
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
match self {
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
Device::Cuda { gpu_id: _ } => {
todo!()
}
}
}
}