mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add some tensor creation functions to the pyo3 bindings. (#326)
This commit is contained in:
5
Makefile
5
Makefile
@ -9,4 +9,9 @@ clean:
|
||||
test:
|
||||
cargo test
|
||||
|
||||
pyo3-test:
|
||||
cargo build --profile=release-with-debug --package candle-pyo3
|
||||
ln -f -s ./target/release-with-debug/libcandle.so candle.so
|
||||
PYTHONPATH=. python3 candle-pyo3/test.py
|
||||
|
||||
all: test
|
||||
|
@ -1,3 +1,4 @@
|
||||
// TODO: Handle negative dimension indexes.
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyTuple;
|
||||
@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
struct PyShape(Vec<usize>);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for PyShape {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
|
||||
Ok(PyShape(dims))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyShape> for ::candle::Shape {
|
||||
fn from(val: PyShape) -> Self {
|
||||
val.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
struct PyTensor(Tensor);
|
||||
|
||||
@ -279,16 +296,15 @@ impl PyTensor {
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
|
||||
// TODO: Add a PyShape type?
|
||||
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
|
||||
}
|
||||
|
||||
@ -381,11 +397,59 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
|
||||
PyTensor::new(py, vs)
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None))]
|
||||
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, device=None))]
|
||||
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||
fn ones(
|
||||
_py: Python<'_>,
|
||||
shape: PyShape,
|
||||
dtype: Option<PyDType>,
|
||||
device: Option<PyDevice>,
|
||||
) -> PyResult<PyTensor> {
|
||||
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (shape, *, dtype=None, device=None))]
|
||||
fn zeros(
|
||||
_py: Python<'_>,
|
||||
shape: PyShape,
|
||||
dtype: Option<PyDType>,
|
||||
device: Option<PyDevice>,
|
||||
) -> PyResult<PyTensor> {
|
||||
let dtype = dtype.map_or(DType::F32, |dt| dt.0);
|
||||
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
|
||||
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
|
||||
Ok(PyTensor(tensor))
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(ones, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(rand, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(randn, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(tensor, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(zeros, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -13,3 +13,6 @@ t = t.reshape([2, 4])
|
||||
print(t.matmul(t.t()))
|
||||
|
||||
print(t.to_dtype("u8"))
|
||||
|
||||
t = candle.randn((5, 3))
|
||||
print(t)
|
||||
|
Reference in New Issue
Block a user