mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
PyO3: Add None
and Tensor
indexing to candle.Tensor
(#1098)
* Add proper `None` and `tensor` indexing * Allow indexing via lists + allow tensor/list indexing outside of first dimension
This commit is contained in:
@ -55,6 +55,7 @@ def test_tensor_can_be_sliced():
|
||||
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
|
||||
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
|
||||
assert t[-4:-2].values() == [5.0, 9.0]
|
||||
assert t[...].values() == t.values()
|
||||
|
||||
|
||||
def test_tensor_can_be_sliced_2d():
|
||||
@ -76,6 +77,43 @@ def test_tensor_can_be_scliced_3d():
|
||||
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
||||
|
||||
|
||||
def test_tensor_can_be_expanded_with_none():
|
||||
t = candle.rand((12, 12))
|
||||
|
||||
b = t[None]
|
||||
assert b.shape == (1, 12, 12)
|
||||
c = t[:, None, None, :]
|
||||
assert c.shape == (12, 1, 1, 12)
|
||||
d = t[None, :, None, :]
|
||||
assert d.shape == (1, 12, 1, 12)
|
||||
e = t[None, None, :, :]
|
||||
assert e.shape == (1, 1, 12, 12)
|
||||
f = t[:, :, None]
|
||||
assert f.shape == (12, 12, 1)
|
||||
|
||||
|
||||
def test_tensor_can_be_index_via_tensor():
|
||||
t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
|
||||
indexed = t[candle.Tensor([0, 2])]
|
||||
assert indexed.shape == (2, 4)
|
||||
assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
|
||||
|
||||
indexed = t[:, candle.Tensor([0, 2])]
|
||||
assert indexed.shape == (3, 2)
|
||||
assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
|
||||
|
||||
|
||||
def test_tensor_can_be_index_via_list():
|
||||
t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
|
||||
indexed = t[[0, 2]]
|
||||
assert indexed.shape == (2, 4)
|
||||
assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
|
||||
|
||||
indexed = t[:, [0, 2]]
|
||||
assert indexed.shape == (3, 2)
|
||||
assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
|
||||
|
||||
|
||||
def test_tensor_can_be_cast_via_to():
|
||||
t = Tensor(42.0)
|
||||
assert str(t.dtype) == str(candle.f32)
|
||||
|
Reference in New Issue
Block a user