mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
This commit is contained in:
@ -211,6 +211,16 @@ enum Indexer {
|
||||
IndexSelect(Tensor),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TorchTensor(PyObject);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
|
||||
Ok(TorchTensor(numpy_value))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
@ -246,6 +256,8 @@ impl PyTensor {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
|
||||
return PyTensor::new(py, numpy);
|
||||
} else {
|
||||
let ty = data.as_ref(py).get_type();
|
||||
Err(PyTypeError::new_err(format!(
|
||||
@ -299,6 +311,18 @@ impl PyTensor {
|
||||
M(py).map(self)
|
||||
}
|
||||
|
||||
/// Converts candle's tensor to pytorch's tensor
|
||||
/// &RETURNS&: torch.Tensor
|
||||
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
let candle_values = self.values(py)?;
|
||||
let torch_tensor: PyObject = py
|
||||
.import("torch")?
|
||||
.getattr("tensor")?
|
||||
.call1((candle_values,))?
|
||||
.extract()?;
|
||||
Ok(torch_tensor)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Gets the tensor's shape.
|
||||
/// &RETURNS&: Tuple[int]
|
||||
|
Reference in New Issue
Block a user