mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Pyo3 dtype (#327)
* Better handling of dtypes in pyo3. * More pyo3 dtype.
This commit is contained in:
3
Makefile
3
Makefile
@ -11,7 +11,6 @@ test:
|
|||||||
|
|
||||||
pyo3-test:
|
pyo3-test:
|
||||||
cargo build --profile=release-with-debug --package candle-pyo3
|
cargo build --profile=release-with-debug --package candle-pyo3
|
||||||
ln -f -s ./target/release-with-debug/libcandle.so candle.so
|
python3 candle-pyo3/test.py
|
||||||
PYTHONPATH=. python3 candle-pyo3/test.py
|
|
||||||
|
|
||||||
all: test
|
all: test
|
||||||
|
@ -16,8 +16,8 @@ doc = false
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
|
||||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
|
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
pyo3-build-config = "0.19"
|
pyo3-build-config = "0.19"
|
||||||
|
@ -40,21 +40,30 @@ impl std::ops::Deref for PyTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
|
#[pyclass(name = "DType")]
|
||||||
struct PyDType(DType);
|
struct PyDType(DType);
|
||||||
|
|
||||||
impl<'source> FromPyObject<'source> for PyDType {
|
#[pymethods]
|
||||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
impl PyDType {
|
||||||
use std::str::FromStr;
|
fn __repr__(&self) -> String {
|
||||||
let dtype: &str = ob.extract()?;
|
format!("{:?}", self.0)
|
||||||
let dtype = DType::from_str(dtype)
|
}
|
||||||
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
|
|
||||||
Ok(Self(dtype))
|
fn __str__(&self) -> String {
|
||||||
|
self.__repr__()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToPyObject for PyDType {
|
impl PyDType {
|
||||||
fn to_object(&self, py: Python<'_>) -> PyObject {
|
fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
|
||||||
self.0.as_str().to_object(py)
|
use std::str::FromStr;
|
||||||
|
if let Ok(dtype) = ob.extract::<&str>(py) {
|
||||||
|
let dtype = DType::from_str(dtype)
|
||||||
|
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
|
||||||
|
Ok(Self(dtype))
|
||||||
|
} else {
|
||||||
|
ob.extract(py)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -223,8 +232,8 @@ impl PyTensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
fn dtype(&self, py: Python<'_>) -> PyObject {
|
fn dtype(&self) -> PyDType {
|
||||||
PyDType(self.0.dtype()).to_object(py)
|
PyDType(self.0.dtype())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
@ -367,7 +376,8 @@ impl PyTensor {
|
|||||||
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
|
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
|
fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
|
||||||
|
let dtype = PyDType::from_pyobject(dtype, py)?;
|
||||||
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,12 +426,15 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
|
|||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||||
fn ones(
|
fn ones(
|
||||||
_py: Python<'_>,
|
py: Python<'_>,
|
||||||
shape: PyShape,
|
shape: PyShape,
|
||||||
dtype: Option<PyDType>,
|
dtype: Option<PyObject>,
|
||||||
device: Option<PyDevice>,
|
device: Option<PyDevice>,
|
||||||
) -> PyResult<PyTensor> {
|
) -> PyResult<PyTensor> {
|
||||||
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
|
let dtype = match dtype {
|
||||||
|
None => DType::F32,
|
||||||
|
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
|
||||||
|
};
|
||||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
|
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||||
Ok(PyTensor(tensor))
|
Ok(PyTensor(tensor))
|
||||||
@ -430,12 +443,15 @@ fn ones(
|
|||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||||
fn zeros(
|
fn zeros(
|
||||||
_py: Python<'_>,
|
py: Python<'_>,
|
||||||
shape: PyShape,
|
shape: PyShape,
|
||||||
dtype: Option<PyDType>,
|
dtype: Option<PyObject>,
|
||||||
device: Option<PyDevice>,
|
device: Option<PyDevice>,
|
||||||
) -> PyResult<PyTensor> {
|
) -> PyResult<PyTensor> {
|
||||||
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
|
let dtype = match dtype {
|
||||||
|
None => DType::F32,
|
||||||
|
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
|
||||||
|
};
|
||||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
|
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||||
Ok(PyTensor(tensor))
|
Ok(PyTensor(tensor))
|
||||||
@ -444,6 +460,13 @@ fn zeros(
|
|||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<PyTensor>()?;
|
m.add_class::<PyTensor>()?;
|
||||||
|
m.add_class::<PyDType>()?;
|
||||||
|
m.add("u8", PyDType(DType::U8))?;
|
||||||
|
m.add("u32", PyDType(DType::U32))?;
|
||||||
|
m.add("bf16", PyDType(DType::BF16))?;
|
||||||
|
m.add("f16", PyDType(DType::F16))?;
|
||||||
|
m.add("f32", PyDType(DType::F32))?;
|
||||||
|
m.add("f64", PyDType(DType::F64))?;
|
||||||
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||||
|
@ -1,3 +1,18 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# The "import candle" statement below works if there is a "candle.so" file in sys.path.
|
||||||
|
# Here we check for shared libraries that can be used in the build directory.
|
||||||
|
BUILD_DIR = "./target/release-with-debug"
|
||||||
|
so_file = BUILD_DIR + "/candle.so"
|
||||||
|
if os.path.islink(so_file): os.remove(so_file)
|
||||||
|
for lib_file in ["libcandle.dylib", "libcandle.so"]:
|
||||||
|
lib_file_ = BUILD_DIR + "/" + lib_file
|
||||||
|
if os.path.isfile(lib_file_):
|
||||||
|
os.symlink(lib_file, so_file)
|
||||||
|
sys.path.insert(0, BUILD_DIR)
|
||||||
|
break
|
||||||
|
|
||||||
import candle
|
import candle
|
||||||
|
|
||||||
t = candle.Tensor(42.0)
|
t = candle.Tensor(42.0)
|
||||||
@ -12,7 +27,9 @@ print(t+t)
|
|||||||
t = t.reshape([2, 4])
|
t = t.reshape([2, 4])
|
||||||
print(t.matmul(t.t()))
|
print(t.matmul(t.t()))
|
||||||
|
|
||||||
|
print(t.to_dtype(candle.u8))
|
||||||
print(t.to_dtype("u8"))
|
print(t.to_dtype("u8"))
|
||||||
|
|
||||||
t = candle.randn((5, 3))
|
t = candle.randn((5, 3))
|
||||||
print(t)
|
print(t)
|
||||||
|
print(t.dtype)
|
||||||
|
Reference in New Issue
Block a user