mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Preliminary pyo3 support for device.
This commit is contained in:
@ -18,3 +18,7 @@ crate-type = ["cdylib"]
|
|||||||
candle = { path = "../candle-core", default-features=false }
|
candle = { path = "../candle-core", default-features=false }
|
||||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["cuda"]
|
||||||
|
cuda = ["candle/cuda"]
|
||||||
|
@ -4,7 +4,7 @@ use pyo3::types::PyTuple;
|
|||||||
|
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
use ::candle::{DType, Device::Cpu, Tensor, WithDType};
|
use ::candle::{DType, Device, Tensor, WithDType};
|
||||||
|
|
||||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||||
@ -30,7 +30,7 @@ impl<'source> FromPyObject<'source> for PyDType {
|
|||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
let dtype: &str = ob.extract()?;
|
let dtype: &str = ob.extract()?;
|
||||||
let dtype = DType::from_str(dtype)
|
let dtype = DType::from_str(dtype)
|
||||||
.map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?;
|
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
|
||||||
Ok(Self(dtype))
|
Ok(Self(dtype))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -41,6 +41,43 @@ impl ToPyObject for PyDType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
enum PyDevice {
|
||||||
|
Cpu,
|
||||||
|
Cuda,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PyDevice {
|
||||||
|
fn from_device(device: Device) -> Self {
|
||||||
|
match device {
|
||||||
|
Device::Cpu => Self::Cpu,
|
||||||
|
Device::Cuda(_) => Self::Cuda,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'source> FromPyObject<'source> for PyDevice {
|
||||||
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
|
let device: &str = ob.extract()?;
|
||||||
|
let device = match device {
|
||||||
|
"cpu" => PyDevice::Cpu,
|
||||||
|
"cuda" => PyDevice::Cuda,
|
||||||
|
_ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
|
||||||
|
};
|
||||||
|
Ok(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToPyObject for PyDevice {
|
||||||
|
fn to_object(&self, py: Python<'_>) -> PyObject {
|
||||||
|
let str = match self {
|
||||||
|
PyDevice::Cpu => "cpu",
|
||||||
|
PyDevice::Cuda => "cuda",
|
||||||
|
};
|
||||||
|
str.to_object(py)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
trait PyWithDType: WithDType {
|
trait PyWithDType: WithDType {
|
||||||
fn to_py(&self, py: Python<'_>) -> PyObject;
|
fn to_py(&self, py: Python<'_>) -> PyObject;
|
||||||
}
|
}
|
||||||
@ -83,6 +120,7 @@ impl PyTensor {
|
|||||||
#[new]
|
#[new]
|
||||||
// TODO: Handle arbitrary input dtype and shape.
|
// TODO: Handle arbitrary input dtype and shape.
|
||||||
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
|
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
|
||||||
|
use Device::Cpu;
|
||||||
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
|
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||||
@ -155,6 +193,11 @@ impl PyTensor {
|
|||||||
PyDType(self.0.dtype()).to_object(py)
|
PyDType(self.0.dtype()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[getter]
|
||||||
|
fn device(&self, py: Python<'_>) -> PyObject {
|
||||||
|
PyDevice::from_device(self.0.device()).to_object(py)
|
||||||
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn rank(&self) -> usize {
|
fn rank(&self) -> usize {
|
||||||
self.0.rank()
|
self.0.rank()
|
||||||
|
@ -2,12 +2,14 @@ import candle
|
|||||||
|
|
||||||
t = candle.Tensor(42.0)
|
t = candle.Tensor(42.0)
|
||||||
print(t)
|
print(t)
|
||||||
print("shape", t.shape, t.rank)
|
print(t.shape, t.rank, t.device)
|
||||||
print(t + t)
|
print(t + t)
|
||||||
|
|
||||||
t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
|
t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
|
||||||
print(t)
|
print(t)
|
||||||
print(t+t)
|
print(t+t)
|
||||||
|
|
||||||
t = t.reshape([2, 4])
|
t = t.reshape([2, 4])
|
||||||
print(t.matmul(t.t()))
|
print(t.matmul(t.t()))
|
||||||
|
|
||||||
print(t.to_dtype("u8"))
|
print(t.to_dtype("u8"))
|
||||||
|
Reference in New Issue
Block a user