mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Handle more input types to create tensors.
This commit is contained in:
@ -63,8 +63,19 @@ trait MapDType {
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
// TODO: Handle arbitrary input dtype and shape.
|
||||
fn new(f: f32) -> PyResult<Self> {
|
||||
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
|
||||
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
|
||||
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
||||
} else {
|
||||
Err(PyTypeError::new_err("incorrect type for tensor"))?
|
||||
};
|
||||
Ok(Self(tensor))
|
||||
}
|
||||
|
||||
/// Gets the tensor data as a Python value/array/array of array/...
|
||||
@ -167,6 +178,11 @@ impl PyTensor {
|
||||
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
self.__mul__(rhs)
|
||||
}
|
||||
|
||||
// TODO: Add a PyShape type?
|
||||
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
|
@ -4,3 +4,8 @@ t = candle.Tensor(42.0)
|
||||
print(t)
|
||||
print("shape", t.shape, t.rank)
|
||||
print(t + t)
|
||||
|
||||
t = candle.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
print(t)
|
||||
print(t+t)
|
||||
print(t.reshape([2, 4]))
|
||||
|
Reference in New Issue
Block a user