diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 62eb21e8..3fc5fffa 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -41,6 +41,10 @@ impl ToPyObject for PyDType {
     }
 }
 
+/// Process-wide cache for the CUDA device handle, filled in lazily by
+/// `PyDevice::as_device` the first time a CUDA device is requested.
+static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 enum PyDevice {
     Cpu,
@@ -54,6 +58,30 @@ impl PyDevice {
             Device::Cuda(_) => Self::Cuda,
         }
     }
+
+    /// Converts this device selector into a candle `Device`.
+    ///
+    /// `Cpu` maps directly to `Device::Cpu`. For `Cuda`, the device with ordinal 0 is
+    /// created on first use and stored in `CUDA_DEVICE`; later calls return a clone of
+    /// the stored handle.
+    fn as_device(&self) -> PyResult<Device> {
+        match self {
+            Self::Cpu => Ok(Device::Cpu),
+            Self::Cuda => {
+                // The cache is only ever `None` or a fully built device, so it is still
+                // usable after another thread panicked while holding the lock.
+                let mut cached = CUDA_DEVICE
+                    .lock()
+                    .unwrap_or_else(std::sync::PoisonError::into_inner);
+                if let Some(device) = cached.as_ref() {
+                    return Ok(device.clone());
+                }
+                let device = Device::new_cuda(0).map_err(wrap_err)?;
+                *cached = Some(device.clone());
+                Ok(device)
+            }
+        }
+    }
 }
 
 impl<'source> FromPyObject<'source> for PyDevice {
@@ -335,6 +363,12 @@ impl PyTensor {
     fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
         Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
     }
+
+    /// Returns the tensor on the given device.
+    fn to_device(&self, device: PyDevice) -> PyResult<Self> {
+        let device = device.as_device()?;
+        Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
+    }
 }
 
 /// Concatenate the tensors across one axis.