mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Get cuda to work on pyo3.
This commit is contained in:
@ -41,6 +41,8 @@ impl ToPyObject for PyDType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared cache for the CUDA `Device` handle. It starts out as `None` and is filled in by
/// `PyDevice::as_device` the first time a CUDA device is requested, so later requests can
/// clone the stored handle instead of calling `Device::new_cuda` again.
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum PyDevice {
|
||||
Cpu,
|
||||
@ -54,6 +56,21 @@ impl PyDevice {
|
||||
Device::Cuda(_) => Self::Cuda,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_device(&self) -> PyResult<Device> {
|
||||
match self {
|
||||
Self::Cpu => Ok(Device::Cpu),
|
||||
Self::Cuda => {
|
||||
let mut device = CUDA_DEVICE.lock().unwrap();
|
||||
if let Some(device) = device.as_ref() {
|
||||
return Ok(device.clone());
|
||||
};
|
||||
let d = Device::new_cuda(0).map_err(wrap_err)?;
|
||||
*device = Some(d.clone());
|
||||
Ok(d)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'source> FromPyObject<'source> for PyDevice {
|
||||
@ -335,6 +352,11 @@ impl PyTensor {
|
||||
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn to_device(&self, device: PyDevice) -> PyResult<Self> {
|
||||
let device = device.as_device()?;
|
||||
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
|
||||
}
|
||||
}
|
||||
|
||||
/// Concatenate the tensors across one axis.
|
||||
|
Reference in New Issue
Block a user