Handle more input types to create tensors.

This commit is contained in:
laurent
2023-07-02 07:19:46 +01:00
parent 4a28dcf828
commit dfe197f791
2 changed files with 23 additions and 2 deletions

View File

@ -63,8 +63,19 @@ trait MapDType {
impl PyTensor {
#[new]
// TODO: Handle arbitrary input dtype and shape.
fn new(f: f32) -> PyResult<Self> {
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<f32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
} else {
Err(PyTypeError::new_err("incorrect type for tensor"))?
};
Ok(Self(tensor))
}
/// Gets the tensor data as a Python value/array/array of array/...
@ -167,6 +178,11 @@ impl PyTensor {
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
self.__mul__(rhs)
}
// TODO: Add a PyShape type?
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
}
#[pyfunction]

View File

@ -4,3 +4,8 @@ t = candle.Tensor(42.0)
print(t)
print("shape", t.shape, t.rank)
print(t + t)
t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
print(t)
print(t+t)
print(t.reshape([2, 4]))