Get cuda to work on pyo3.

This commit is contained in:
laurent
2023-07-02 21:04:11 +01:00
parent fbfe74caab
commit 5b0ee2e0ba

View File

@ -41,6 +41,8 @@ impl ToPyObject for PyDType {
}
}
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
Cpu,
@ -54,6 +56,21 @@ impl PyDevice {
Device::Cuda(_) => Self::Cuda,
}
}
fn as_device(&self) -> PyResult<Device> {
match self {
Self::Cpu => Ok(Device::Cpu),
Self::Cuda => {
let mut device = CUDA_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_cuda(0).map_err(wrap_err)?;
*device = Some(d.clone());
Ok(d)
}
}
}
}
impl<'source> FromPyObject<'source> for PyDevice {
@ -335,6 +352,11 @@ impl PyTensor {
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
fn to_device(&self, device: PyDevice) -> PyResult<Self> {
let device = device.as_device()?;
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
}
}
/// Concatenate the tensors across one axis.