Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00.
PyO3: Add PyTorch-like .to() operator to candle.Tensor (#1100)

* Add `.to()` operator
* Only allow each value to be provided once via `args` or `kwargs`
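As a usage sketch of the new operator (drawn from the tests added below; the commented-out lines assume a CUDA-enabled build):

import candle
from candle import Tensor

t = Tensor(42.0)                      # defaults to f32 on cpu
t_f64 = t.to(candle.f64)              # cast via positional dtype
t_f64_kw = t.to(dtype=candle.f64)     # same cast via keyword
other = Tensor(1.0).to(candle.f64)
t_like = t.to(other)                  # adopt another tensor's dtype and device
# t_gpu = t.to("cuda")                # device move (requires CUDA)
# t_both = t.to("cuda", candle.f64)   # move and cast in one call (requires CUDA)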
@@ -381,6 +381,11 @@ class Tensor:
        Transposes the tensor.
        """
        pass
    def to(self, *args, **kwargs) -> Tensor:
        """
        Performs Tensor dtype and/or device conversion.
        """
        pass
    def to_device(self, device: Union[str, Device]) -> Tensor:
        """
        Move the tensor to a new device.
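For comparison with the narrower helpers already in the stub, `.to()` subsumes both; a rough equivalence sketch (assuming the existing `to_dtype` and `to_device` methods):

import candle
from candle import Tensor

t = Tensor(42.0)
# These two calls should produce tensors with the same dtype:
assert str(t.to(candle.f64).dtype) == str(t.to_dtype(candle.f64).dtype)
# Likewise, t.to(device="cuda") corresponds to t.to_device("cuda") on a CUDA build.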
@@ -772,6 +772,112 @@ impl PyTensor {
        Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
    }

    #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
    /// Performs Tensor dtype and/or device conversion.
    /// &RETURNS&: Tensor
    fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
        let mut device: Option<PyDevice> = None;
        let mut dtype: Option<PyDType> = None;
        let mut other: Option<PyTensor> = None;

        // Record an extracted value, rejecting a second occurrence of the same kind.
        fn handle_duplicates<T>(
            opt: &mut Option<T>,
            extraction_result: PyResult<T>,
            err_msg: &'static str,
        ) -> PyResult<()> {
            if let Ok(successful_extraction) = extraction_result {
                if opt.is_some() {
                    return Err(PyValueError::new_err(err_msg));
                }
                *opt = Some(successful_extraction);
            }
            Ok(())
        }

        // Handle positional args
        for arg in args.iter() {
            if arg.extract::<PyDevice>().is_ok() {
                handle_duplicates(
                    &mut device,
                    arg.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            } else if arg.extract::<PyDType>().is_ok() {
                handle_duplicates(
                    &mut dtype,
                    arg.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            } else if arg.extract::<PyTensor>().is_ok() {
                handle_duplicates(
                    &mut other,
                    arg.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            } else {
                return Err(PyTypeError::new_err(format!(
                    "unsupported argument type `{:#?}`",
                    arg.get_type().name()
                )));
            }
        }

        // Handle kwargs
        if let Some(kwargs) = kwargs {
            if let Some(any) = kwargs.get_item("dtype") {
                handle_duplicates(
                    &mut dtype,
                    any.extract::<PyDType>(),
                    "cannot specify multiple dtypes",
                )?;
            }
            if let Some(any) = kwargs.get_item("device") {
                handle_duplicates(
                    &mut device,
                    any.extract::<PyDevice>(),
                    "cannot specify multiple devices",
                )?;
            }
            if let Some(any) = kwargs.get_item("other") {
                handle_duplicates(
                    &mut other,
                    any.extract::<PyTensor>(),
                    "cannot specify multiple output tensors",
                )?;
            }
        }

        // An `other` tensor supplies both target dtype and device, so it is
        // mutually exclusive with explicit dtype/device arguments.
        if let Some(other) = other {
            if device.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a device",
                ));
            }
            if dtype.is_some() {
                return Err(PyValueError::new_err(
                    "cannot specify both an output tensor and a dtype",
                ));
            }
            dtype = Some(other.dtype());
            device = Some(PyDevice::from_device(other.0.device()));
        }

        let result = match (device, dtype) {
            (Some(device), Some(dtype)) => self
                .0
                .to_device(&device.as_device()?)
                .map_err(wrap_err)?
                .to_dtype(dtype.0)
                .map_err(wrap_err)?,
            (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
            (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
            (None, None) => {
                return Err(PyTypeError::new_err("no valid dtype or device specified"))
            }
        };

        Ok(PyTensor(result))
    }

    #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
    /// Convert the tensor to a new dtype.
    /// &RETURNS&: Tensor
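The `handle_duplicates` helper and the `(None, None)` arm produce the error behavior exercised by the tests; a minimal sketch of what gets rejected (mirroring the dtype test below):

import pytest
import candle
from candle import Tensor

t = Tensor(42.0)
pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64))  # dtype given twice
pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16))        # also twice, both positional
pytest.raises(TypeError, lambda: t.to("not a dtype"))                  # unsupported argument type
pytest.raises(TypeError, lambda: t.to())                               # no dtype or device at all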
@@ -1,5 +1,6 @@
import candle
from candle import Tensor
from candle.utils import cuda_is_available
import pytest
@@ -75,6 +76,70 @@ def test_tensor_can_be_scliced_3d():
    assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]


def test_tensor_can_be_cast_via_to():
    t = Tensor(42.0)
    assert str(t.dtype) == str(candle.f32)
    t_new_args = t.to(candle.f64)
    assert str(t_new_args.dtype) == str(candle.f64)
    t_new_kwargs = t.to(dtype=candle.f64)
    assert str(t_new_kwargs.dtype) == str(candle.f64)
    pytest.raises(TypeError, lambda: t.to("not a dtype"))
    pytest.raises(TypeError, lambda: t.to(dtype="not a dtype"))
    pytest.raises(TypeError, lambda: t.to(candle.f64, "not a dtype"))
    pytest.raises(TypeError, lambda: t.to())
    pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64))
    pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16))

    other = Tensor(42.0).to(candle.f64)
    t_new_other_args = t.to(other)
    assert str(t_new_other_args.dtype) == str(candle.f64)
    t_new_other_kwargs = t.to(other=other)
    assert str(t_new_other_kwargs.dtype) == str(candle.f64)


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_tensor_can_be_moved_via_to():
    t = Tensor(42.0)
    assert t.device == "cpu"
    t_new_args = t.to("cuda")
    assert t_new_args.device == "cuda"
    t_new_kwargs = t.to(device="cuda")
    assert t_new_kwargs.device == "cuda"
    pytest.raises(TypeError, lambda: t.to("not a device"))
    pytest.raises(TypeError, lambda: t.to(device="not a device"))
    pytest.raises(TypeError, lambda: t.to("cuda", "not a device"))
    pytest.raises(TypeError, lambda: t.to())
    pytest.raises(ValueError, lambda: t.to("cuda", device="cpu"))
    pytest.raises(ValueError, lambda: t.to("cuda", "cuda"))

    other = Tensor(42.0).to("cuda")
    t_new_other_args = t.to(other)
    assert t_new_other_args.device == "cuda"
    t_new_other_kwargs = t.to(other=other)
    assert t_new_other_kwargs.device == "cuda"


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_tensor_can_be_moved_and_cast_via_to():
    t = Tensor(42.0)
    assert t.device == "cpu"
    assert str(t.dtype) == str(candle.f32)
    t_new_args = t.to("cuda", candle.f64)
    assert t_new_args.device == "cuda"
    assert str(t_new_args.dtype) == str(candle.f64)
    t_new_kwargs = t.to(device="cuda", dtype=candle.f64)
    assert t_new_kwargs.device == "cuda"
    assert str(t_new_kwargs.dtype) == str(candle.f64)

    other = Tensor(42.0).to("cuda").to(candle.f64)
    t_new_other_args = t.to(other)
    assert t_new_other_args.device == "cuda"
    assert str(t_new_other_args.dtype) == str(candle.f64)
    t_new_other_kwargs = t.to(other=other)
    assert t_new_other_kwargs.device == "cuda"
    assert str(t_new_other_kwargs.dtype) == str(candle.f64)


def test_tensor_can_be_added():
    t = Tensor(42.0)
    result = t + t