mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
This commit is contained in:
@ -396,6 +396,11 @@ class Tensor:
|
|||||||
Convert the tensor to a new dtype.
|
Convert the tensor to a new dtype.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def to_torch(self) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts candle's tensor to pytorch's tensor
|
||||||
|
"""
|
||||||
|
pass
|
||||||
def transpose(self, dim1: int, dim2: int) -> Tensor:
|
def transpose(self, dim1: int, dim2: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
|
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
|
||||||
|
@ -211,6 +211,16 @@ enum Indexer {
|
|||||||
IndexSelect(Tensor),
|
IndexSelect(Tensor),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct TorchTensor(PyObject);
|
||||||
|
|
||||||
|
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
|
||||||
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
|
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
|
||||||
|
Ok(TorchTensor(numpy_value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl PyTensor {
|
impl PyTensor {
|
||||||
#[new]
|
#[new]
|
||||||
@ -246,6 +256,8 @@ impl PyTensor {
|
|||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
|
||||||
|
return PyTensor::new(py, numpy);
|
||||||
} else {
|
} else {
|
||||||
let ty = data.as_ref(py).get_type();
|
let ty = data.as_ref(py).get_type();
|
||||||
Err(PyTypeError::new_err(format!(
|
Err(PyTypeError::new_err(format!(
|
||||||
@ -299,6 +311,18 @@ impl PyTensor {
|
|||||||
M(py).map(self)
|
M(py).map(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts candle's tensor to pytorch's tensor
|
||||||
|
/// &RETURNS&: torch.Tensor
|
||||||
|
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
|
let candle_values = self.values(py)?;
|
||||||
|
let torch_tensor: PyObject = py
|
||||||
|
.import("torch")?
|
||||||
|
.getattr("tensor")?
|
||||||
|
.call1((candle_values,))?
|
||||||
|
.extract()?;
|
||||||
|
Ok(torch_tensor)
|
||||||
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
/// Gets the tensor's shape.
|
/// Gets the tensor's shape.
|
||||||
/// &RETURNS&: Tuple[int]
|
/// &RETURNS&: Tuple[int]
|
||||||
|
14
candle-pyo3/test_pytorch.py
Normal file
14
candle-pyo3/test_pytorch.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import candle
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# convert from candle tensor to torch tensor
|
||||||
|
t = candle.randn((3, 512, 512))
|
||||||
|
torch_tensor = t.to_torch()
|
||||||
|
print(torch_tensor)
|
||||||
|
print(type(torch_tensor))
|
||||||
|
|
||||||
|
# convert from torch tensor to candle tensor
|
||||||
|
t = torch.randn((3, 512, 512))
|
||||||
|
candle_tensor = candle.Tensor(t)
|
||||||
|
print(candle_tensor)
|
||||||
|
print(type(candle_tensor))
|
Reference in New Issue
Block a user