Add some tensor creation functions to the pyo3 bindings. (#326)

This commit is contained in:
Laurent Mazare
2023-08-06 06:50:33 +01:00
committed by GitHub
parent b278834267
commit 88bd3b604a
3 changed files with 77 additions and 5 deletions

View File

@ -9,4 +9,9 @@ clean:
test:
cargo test
pyo3-test:
cargo build --profile=release-with-debug --package candle-pyo3
ln -f -s ./target/release-with-debug/libcandle.so candle.so
PYTHONPATH=. python3 candle-pyo3/test.py
all: test

View File

@ -1,3 +1,4 @@
// TODO: Handle negative dimension indexes.
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone)]
#[derive(Clone, Debug)]
struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
Ok(PyShape(dims))
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);
@ -279,16 +296,15 @@ impl PyTensor {
Ok(Self(tensor))
}
// TODO: Add a PyShape type?
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
}
fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
@ -381,11 +397,59 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
PyTensor::new(py, vs)
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None))]
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None))]
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
fn ones(
_py: Python<'_>,
shape: PyShape,
dtype: Option<PyDType>,
device: Option<PyDevice>,
) -> PyResult<PyTensor> {
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
fn zeros(
_py: Python<'_>,
shape: PyShape,
dtype: Option<PyDType>,
device: Option<PyDevice>,
) -> PyResult<PyTensor> {
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?;
m.add_function(wrap_pyfunction!(randn, m)?)?;
m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
m.add_function(wrap_pyfunction!(zeros, m)?)?;
Ok(())
}

View File

@ -13,3 +13,6 @@ t = t.reshape([2, 4])
print(t.matmul(t.t()))
print(t.to_dtype("u8"))
t = candle.randn((5, 3))
print(t)