PyO3: Add pytorch like .to() operator to candle.Tensor (#1100)

* add `.to()` operator
* Only allow each value to be provided once via `args` or `kwargs`
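The tests added below exercise the new operator; as a minimal usage sketch distilled from them (the device moves assume a CUDA device is available):

```python
import candle
from candle import Tensor

t = Tensor(42.0)                    # defaults to f32 on the CPU
t_f64 = t.to(candle.f64)            # cast via positional dtype, or t.to(dtype=candle.f64)
t_cuda = t.to("cuda")               # move via device string, or t.to(device="cuda")
t_both = t.to("cuda", candle.f64)   # move and cast in one call
t_like = t.to(other=t_f64)          # adopt another tensor's dtype and device
```

Passing the same value both positionally and as a keyword raises `ValueError`; unrecognized arguments raise `TypeError`.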
@@ -1,5 +1,6 @@
import candle
from candle import Tensor
from candle.utils import cuda_is_available
import pytest

@@ -75,6 +76,70 @@ def test_tensor_can_be_scliced_3d():
    assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]


def test_tensor_can_be_cast_via_to():
    t = Tensor(42.0)
    assert str(t.dtype) == str(candle.f32)
    t_new_args = t.to(candle.f64)
    assert str(t_new_args.dtype) == str(candle.f64)
    t_new_kwargs = t.to(dtype=candle.f64)
    assert str(t_new_kwargs.dtype) == str(candle.f64)
    pytest.raises(TypeError, lambda: t.to("not a dtype"))
    pytest.raises(TypeError, lambda: t.to(dtype="not a dtype"))
    pytest.raises(TypeError, lambda: t.to(candle.f64, "not a dtype"))
    pytest.raises(TypeError, lambda: t.to())
    pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64))
    pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16))

    other = Tensor(42.0).to(candle.f64)
    t_new_other_args = t.to(other)
    assert str(t_new_other_args.dtype) == str(candle.f64)
    t_new_other_kwargs = t.to(other=other)
    assert str(t_new_other_kwargs.dtype) == str(candle.f64)


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_tensor_can_be_moved_via_to():
    t = Tensor(42.0)
    assert t.device == "cpu"
    t_new_args = t.to("cuda")
    assert t_new_args.device == "cuda"
    t_new_kwargs = t.to(device="cuda")
    assert t_new_kwargs.device == "cuda"
    pytest.raises(TypeError, lambda: t.to("not a device"))
    pytest.raises(TypeError, lambda: t.to(device="not a device"))
    pytest.raises(TypeError, lambda: t.to("cuda", "not a device"))
    pytest.raises(TypeError, lambda: t.to())
    pytest.raises(ValueError, lambda: t.to("cuda", device="cpu"))
    pytest.raises(ValueError, lambda: t.to("cuda", "cuda"))

    other = Tensor(42.0).to("cuda")
    t_new_other_args = t.to(other)
    assert t_new_other_args.device == "cuda"
    t_new_other_kwargs = t.to(other=other)
    assert t_new_other_kwargs.device == "cuda"


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_tensor_can_be_moved_and_cast_via_to():
    t = Tensor(42.0)
    assert t.device == "cpu"
    assert str(t.dtype) == str(candle.f32)
    t_new_args = t.to("cuda", candle.f64)
    assert t_new_args.device == "cuda"
    assert str(t_new_args.dtype) == str(candle.f64)
    t_new_kwargs = t.to(device="cuda", dtype=candle.f64)
    assert t_new_kwargs.device == "cuda"
    assert str(t_new_kwargs.dtype) == str(candle.f64)

    other = Tensor(42.0).to("cuda").to(candle.f64)
    t_new_other_args = t.to(other)
    assert t_new_other_args.device == "cuda"
    assert str(t_new_other_args.dtype) == str(candle.f64)
    t_new_other_kwargs = t.to(other=other)
    assert t_new_other_kwargs.device == "cuda"
    assert str(t_new_other_kwargs.dtype) == str(candle.f64)


def test_tensor_can_be_added():
    t = Tensor(42.0)
    result = t + t