mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add some tensor creation functions to the pyo3 bindings. (#326)
This commit is contained in:
5
Makefile
5
Makefile
@ -9,4 +9,9 @@ clean:
|
|||||||
test:
|
test:
|
||||||
cargo test
|
cargo test
|
||||||
|
|
||||||
|
pyo3-test:
|
||||||
|
cargo build --profile=release-with-debug --package candle-pyo3
|
||||||
|
ln -f -s ./target/release-with-debug/libcandle.so candle.so
|
||||||
|
PYTHONPATH=. python3 candle-pyo3/test.py
|
||||||
|
|
||||||
all: test
|
all: test
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
// TODO: Handle negative dimension indexes.
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::PyTuple;
|
use pyo3::types::PyTuple;
|
||||||
@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
|||||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Debug)]
|
||||||
|
struct PyShape(Vec<usize>);
|
||||||
|
|
||||||
|
impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||||
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
|
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
|
||||||
|
Ok(PyShape(dims))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PyShape> for ::candle::Shape {
|
||||||
|
fn from(val: PyShape) -> Self {
|
||||||
|
val.0.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
#[pyclass(name = "Tensor")]
|
#[pyclass(name = "Tensor")]
|
||||||
struct PyTensor(Tensor);
|
struct PyTensor(Tensor);
|
||||||
|
|
||||||
@ -279,16 +296,15 @@ impl PyTensor {
|
|||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add a PyShape type?
|
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
|
||||||
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
|
|
||||||
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
|
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
|
||||||
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
|
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
|
||||||
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -381,11 +397,59 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
|
|||||||
PyTensor::new(py, vs)
|
PyTensor::new(py, vs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (shape, *, device=None))]
|
||||||
|
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||||
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
|
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (shape, *, device=None))]
|
||||||
|
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||||
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
|
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||||
|
fn ones(
|
||||||
|
_py: Python<'_>,
|
||||||
|
shape: PyShape,
|
||||||
|
dtype: Option<PyDType>,
|
||||||
|
device: Option<PyDevice>,
|
||||||
|
) -> PyResult<PyTensor> {
|
||||||
|
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
|
||||||
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
|
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyfunction]
|
||||||
|
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||||
|
fn zeros(
|
||||||
|
_py: Python<'_>,
|
||||||
|
shape: PyShape,
|
||||||
|
dtype: Option<PyDType>,
|
||||||
|
device: Option<PyDevice>,
|
||||||
|
) -> PyResult<PyTensor> {
|
||||||
|
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
|
||||||
|
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||||
|
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<PyTensor>()?;
|
m.add_class::<PyTensor>()?;
|
||||||
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -13,3 +13,6 @@ t = t.reshape([2, 4])
|
|||||||
print(t.matmul(t.t()))
|
print(t.matmul(t.t()))
|
||||||
|
|
||||||
print(t.to_dtype("u8"))
|
print(t.to_dtype("u8"))
|
||||||
|
|
||||||
|
t = candle.randn((5, 3))
|
||||||
|
print(t)
|
||||||
|
Reference in New Issue
Block a user