Move the data between the host and the device.

This commit is contained in:
laurent
2023-06-21 19:43:25 +01:00
parent c654ecdb16
commit 71735c7a02
5 changed files with 100 additions and 20 deletions

View File

@ -11,7 +11,7 @@ pub enum DeviceLocation {
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(std::sync::Arc<cudarc::driver::CudaDevice>),
Cuda(crate::CudaDevice),
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
@ -63,8 +63,7 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
Ok(Self::Cuda(device))
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn location(&self) -> DeviceLocation {
@ -85,7 +84,8 @@ impl Device {
Device::Cuda(device) => {
// TODO: Instead of allocating memory on the host and transfering it,
// allocate some zeros on the device and use a shader to set them to 1.
let storage = device.htod_copy(vec![1f32; shape.elem_count()])?;
let storage = CpuStorage::ones_impl(shape, dtype);
let storage = device.cuda_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
@ -98,7 +98,7 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_zeros::<f32>(shape.elem_count())?;
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
}
@ -108,16 +108,9 @@ impl Device {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
// TODO: Avoid making a copy through the cpu.
match array.to_cpu_storage() {
CpuStorage::F64(_) => {
todo!()
}
CpuStorage::F32(data) => {
let storage = device.htod_copy(data)?;
Ok(Storage::Cuda(storage))
}
}
let storage = array.to_cpu_storage();
let storage = device.cuda_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
}