Merge pull request #55 from LaurentMazare/pyo3-device

Cuda support for the pyo3 bindings
Author: Laurent Mazare (committed by GitHub)
Date:   2023-07-02 21:04:58 +01:00
3 changed files with 74 additions and 3 deletions

candle-pyo3/Cargo.toml

@@ -18,3 +18,7 @@ crate-type = ["cdylib"]
 candle = { path = "../candle-core", default-features=false }
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
 half = { version = "2.3.1", features = ["num-traits"] }
+
+[features]
+default = ["cuda"]
+cuda = ["candle/cuda"]
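These two lines wire the bindings' `cuda` feature through to the core crate: a default build pulls in `candle/cuda`, while building with `--no-default-features` yields a CPU-only extension module. Downstream Rust code can then gate GPU-only paths at compile time; a minimal sketch, assuming only the feature name declared above (the helper function itself is hypothetical, not part of this PR):

// Hypothetical helper: reports whether this build of the bindings
// was compiled with the `cuda` feature enabled.
pub fn cuda_is_available() -> bool {
    cfg!(feature = "cuda")
}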

candle-pyo3/src/lib.rs

@@ -4,7 +4,7 @@ use pyo3::types::PyTuple;
 use half::{bf16, f16};
-use ::candle::{DType, Device::Cpu, Tensor, WithDType};
+use ::candle::{DType, Device, Tensor, WithDType};

 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
@@ -30,7 +30,7 @@ impl<'source> FromPyObject<'source> for PyDType {
         use std::str::FromStr;
         let dtype: &str = ob.extract()?;
         let dtype = DType::from_str(dtype)
-            .map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?;
+            .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
         Ok(Self(dtype))
     }
 }
@@ -41,6 +41,60 @@ impl ToPyObject for PyDType {
     }
 }

+static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+enum PyDevice {
+    Cpu,
+    Cuda,
+}
+
+impl PyDevice {
+    fn from_device(device: &Device) -> Self {
+        match device {
+            Device::Cpu => Self::Cpu,
+            Device::Cuda(_) => Self::Cuda,
+        }
+    }
+
+    fn as_device(&self) -> PyResult<Device> {
+        match self {
+            Self::Cpu => Ok(Device::Cpu),
+            Self::Cuda => {
+                let mut device = CUDA_DEVICE.lock().unwrap();
+                if let Some(device) = device.as_ref() {
+                    return Ok(device.clone());
+                };
+                let d = Device::new_cuda(0).map_err(wrap_err)?;
+                *device = Some(d.clone());
+                Ok(d)
+            }
+        }
+    }
+}
+
+impl<'source> FromPyObject<'source> for PyDevice {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        let device: &str = ob.extract()?;
+        let device = match device {
+            "cpu" => PyDevice::Cpu,
+            "cuda" => PyDevice::Cuda,
+            _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
+        };
+        Ok(device)
+    }
+}
+
+impl ToPyObject for PyDevice {
+    fn to_object(&self, py: Python<'_>) -> PyObject {
+        let str = match self {
+            PyDevice::Cpu => "cpu",
+            PyDevice::Cuda => "cuda",
+        };
+        str.to_object(py)
+    }
+}
+
 trait PyWithDType: WithDType {
     fn to_py(&self, py: Python<'_>) -> PyObject;
 }
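The `CUDA_DEVICE` static above gives the bindings a process-wide, lazily initialized CUDA handle: the first request for "cuda" creates the device, and every later request clones the cached copy, so GPU 0 is never instantiated twice. A standalone sketch of the same pattern, with a hypothetical `Handle` type standing in for candle's `Device` (requires Rust 1.63+, where `Mutex::new` is const):

use std::sync::Mutex;

// Hypothetical stand-in for candle's Device: cheap to clone, like a handle.
#[derive(Clone, Debug)]
struct Handle(u32);

static HANDLE: Mutex<Option<Handle>> = Mutex::new(None);

fn get_handle() -> Handle {
    let mut guard = HANDLE.lock().unwrap();
    if let Some(h) = guard.as_ref() {
        return h.clone(); // already initialized: reuse the cached handle
    }
    let h = Handle(0); // expensive device initialization would go here
    *guard = Some(h.clone());
    h
}

fn main() {
    // Both calls return a clone of the same lazily created handle.
    println!("{:?} {:?}", get_handle(), get_handle());
}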
@@ -83,6 +137,7 @@ impl PyTensor {
     #[new]
     // TODO: Handle arbitrary input dtype and shape.
     fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+        use Device::Cpu;
         let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
             Tensor::new(vs, &Cpu).map_err(wrap_err)?
         } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
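The constructor above dispatches on the Python value's runtime type by attempting a sequence of `extract` calls and taking the first that succeeds. A minimal sketch of that extract-and-fall-through pattern, assuming pyo3 0.19 as pinned in the Cargo.toml above (the `classify` function is illustrative, not part of this PR):

use pyo3::prelude::*;

// Illustrative: try extractions from most to least specific and report
// which one succeeded, mirroring how PyTensor::new dispatches on its input.
fn classify(py: Python<'_>, vs: PyObject) -> String {
    if let Ok(v) = vs.extract::<u32>(py) {
        format!("scalar u32: {v}")
    } else if let Ok(v) = vs.extract::<Vec<u32>>(py) {
        format!("Vec<u32> of length {}", v.len())
    } else {
        "unsupported input type".to_string()
    }
}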
@@ -155,6 +210,11 @@ impl PyTensor {
         PyDType(self.0.dtype()).to_object(py)
     }

+    #[getter]
+    fn device(&self, py: Python<'_>) -> PyObject {
+        PyDevice::from_device(self.0.device()).to_object(py)
+    }
+
     #[getter]
     fn rank(&self) -> usize {
         self.0.rank()
@@ -292,6 +352,11 @@ impl PyTensor {
     fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
         Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
     }
+
+    fn to_device(&self, device: PyDevice) -> PyResult<Self> {
+        let device = device.as_device()?;
+        Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
+    }
 }

 /// Concatenate the tensors across one axis.
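`to_device` is a thin wrapper: it resolves the `PyDevice` to a concrete candle `Device` (creating or reusing the cached CUDA device) and defers to candle's own `Tensor::to_device`. A sketch of the equivalent call on the Rust side, assuming a CUDA-enabled build with the crate imported as `candle`, as in the bindings above (this fails at runtime if no GPU is present):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Build a tensor on the CPU, then move it to the first CUDA GPU.
    let t = Tensor::new(&[3f32, 1., 4., 1.], &Device::Cpu)?;
    let cuda = Device::new_cuda(0)?;
    let t = t.to_device(&cuda)?;
    println!("{:?}", t.device());
    Ok(())
}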

candle-pyo3/test.py

@@ -2,12 +2,14 @@ import candle
 
 t = candle.Tensor(42.0)
 print(t)
-print("shape", t.shape, t.rank)
+print(t.shape, t.rank, t.device)
 print(t + t)
 
 t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
 print(t)
 print(t+t)
 t = t.reshape([2, 4])
 print(t.matmul(t.t()))
 print(t.to_dtype("u8"))