mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Handle more input types to create tensors.
This commit is contained in:
@ -63,8 +63,19 @@ trait MapDType {
|
|||||||
impl PyTensor {
|
impl PyTensor {
|
||||||
#[new]
|
#[new]
|
||||||
// TODO: Handle arbitrary input dtype and shape.
|
// TODO: Handle arbitrary input dtype and shape.
|
||||||
fn new(f: f32) -> PyResult<Self> {
|
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
|
||||||
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||||
|
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
||||||
|
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||||
|
} else {
|
||||||
|
Err(PyTypeError::new_err("incorrect type for tensor"))?
|
||||||
|
};
|
||||||
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the tensor data as a Python value/array/array of array/...
|
/// Gets the tensor data as a Python value/array/array of array/...
|
||||||
@ -167,6 +178,11 @@ impl PyTensor {
|
|||||||
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||||
self.__mul__(rhs)
|
self.__mul__(rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Add a PyShape type?
|
||||||
|
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
@ -4,3 +4,8 @@ t = candle.Tensor(42.0)
|
|||||||
print(t)
|
print(t)
|
||||||
print("shape", t.shape, t.rank)
|
print("shape", t.shape, t.rank)
|
||||||
print(t + t)
|
print(t + t)
|
||||||
|
|
||||||
|
t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
|
||||||
|
print(t)
|
||||||
|
print(t+t)
|
||||||
|
print(t.reshape([2, 4]))
|
||||||
|
Reference in New Issue
Block a user