mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Get cuda to work on pyo3.
This commit is contained in:
@ -41,6 +41,8 @@ impl ToPyObject for PyDType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
enum PyDevice {
|
enum PyDevice {
|
||||||
Cpu,
|
Cpu,
|
||||||
@ -54,6 +56,21 @@ impl PyDevice {
|
|||||||
Device::Cuda(_) => Self::Cuda,
|
Device::Cuda(_) => Self::Cuda,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn as_device(&self) -> PyResult<Device> {
|
||||||
|
match self {
|
||||||
|
Self::Cpu => Ok(Device::Cpu),
|
||||||
|
Self::Cuda => {
|
||||||
|
let mut device = CUDA_DEVICE.lock().unwrap();
|
||||||
|
if let Some(device) = device.as_ref() {
|
||||||
|
return Ok(device.clone());
|
||||||
|
};
|
||||||
|
let d = Device::new_cuda(0).map_err(wrap_err)?;
|
||||||
|
*device = Some(d.clone());
|
||||||
|
Ok(d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'source> FromPyObject<'source> for PyDevice {
|
impl<'source> FromPyObject<'source> for PyDevice {
|
||||||
@ -335,6 +352,11 @@ impl PyTensor {
|
|||||||
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
|
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
|
||||||
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_device(&self, device: PyDevice) -> PyResult<Self> {
|
||||||
|
let device = device.as_device()?;
|
||||||
|
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenate the tensors across one axis.
|
/// Concatenate the tensors across one axis.
|
||||||
|
Reference in New Issue
Block a user