convert pytorch's tensor in Python API (#1172)

* convert pytorch's tensor

* separate tests for convert pytorch tensor
This commit is contained in:
andrew
2023-10-26 01:39:14 +07:00
committed by GitHub
parent 0acd16751d
commit 6a446d9d73
3 changed files with 43 additions and 0 deletions

View File

@ -211,6 +211,16 @@ enum Indexer {
IndexSelect(Tensor),
}
#[derive(Clone, Debug)]
struct TorchTensor(PyObject);
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
}
#[pymethods]
impl PyTensor {
#[new]
@ -246,6 +256,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
@ -299,6 +311,18 @@ impl PyTensor {
M(py).map(self)
}
/// Converts candle's tensor to pytorch's tensor
/// &RETURNS&: torch.Tensor
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
.import("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
Ok(torch_tensor)
}
#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]