mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
PyO3: Add `None` and `Tensor` indexing to `candle.Tensor` (#1098)
* Add proper `None` and `tensor` indexing
* Allow indexing via lists + allow tensor/list indexing outside of first dimension
This commit is contained in:
@ -201,6 +201,8 @@ enum Indexer {
|
|||||||
Index(usize),
|
Index(usize),
|
||||||
Slice(usize, usize),
|
Slice(usize, usize),
|
||||||
Elipsis,
|
Elipsis,
|
||||||
|
Expand,
|
||||||
|
IndexSelect(Tensor),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
@ -450,7 +452,7 @@ impl PyTensor {
|
|||||||
let mut indexers: Vec<Indexer> = vec![];
|
let mut indexers: Vec<Indexer> = vec![];
|
||||||
let dims = self.0.shape().dims();
|
let dims = self.0.shape().dims();
|
||||||
|
|
||||||
let to_absolute_index = |index: isize, current_dim: usize| {
|
fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult<usize> {
|
||||||
// Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[len - 1]
|
// Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[len - 1]
|
||||||
let actual_index = if index < 0 {
|
let actual_index = if index < 0 {
|
||||||
dims[current_dim] as isize + index
|
dims[current_dim] as isize + index
|
||||||
@ -460,48 +462,92 @@ impl PyTensor {
|
|||||||
|
|
||||||
// Check that the index is in range
|
// Check that the index is in range
|
||||||
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
|
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
|
||||||
return Err(PyTypeError::new_err(format!(
|
return Err(PyValueError::new_err(format!(
|
||||||
"index out of range for dimension '{i}' with indexer '{value}'",
|
"index out of range for dimension '{i}' with indexer '{value}'",
|
||||||
i = current_dim,
|
i = current_dim,
|
||||||
value = index
|
value = index
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
Ok(actual_index as usize)
|
Ok(actual_index as usize)
|
||||||
};
|
}
|
||||||
if let Ok(index) = idx.extract(py) {
|
|
||||||
|
fn extract_indexer(
|
||||||
|
py_indexer: &PyAny,
|
||||||
|
current_dim: usize,
|
||||||
|
dims: &[usize],
|
||||||
|
index_argument_count: usize,
|
||||||
|
) -> PyResult<(Indexer, usize)> {
|
||||||
|
if let Ok(index) = py_indexer.extract() {
|
||||||
// Handle a single index e.g. tensor[0] or tensor[-1]
|
// Handle a single index e.g. tensor[0] or tensor[-1]
|
||||||
indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
|
Ok((
|
||||||
} else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
|
Indexer::Index(to_absolute_index(index, current_dim, dims)?),
|
||||||
|
current_dim + 1,
|
||||||
|
))
|
||||||
|
} else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
|
||||||
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
|
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
|
||||||
let index = slice.indices(dims[0] as c_long)?;
|
let index = slice.indices(dims[current_dim] as c_long)?;
|
||||||
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
|
Ok((
|
||||||
} else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
|
Indexer::Slice(index.start as usize, index.stop as usize),
|
||||||
// Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
|
current_dim + 1,
|
||||||
|
))
|
||||||
if tuple.len() > dims.len() {
|
} else if let Ok(tensor) = py_indexer.extract::<PyTensor>() {
|
||||||
return Err(PyTypeError::new_err("provided too many indices"));
|
// Handle a tensor as indices e.g. tensor[tensor([0,1])]
|
||||||
|
let t = tensor.0;
|
||||||
|
if t.rank() != 1 {
|
||||||
|
return Err(PyTypeError::new_err(
|
||||||
|
"multi-dimensional tensor indexing is not supported",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
Ok((Indexer::IndexSelect(t), current_dim + 1))
|
||||||
for (i, item) in tuple.iter().enumerate() {
|
} else if let Ok(list) = py_indexer.downcast::<pyo3::types::PyList>() {
|
||||||
if item.is_ellipsis() {
|
// Handle a list of indices e.g. tensor[[0,1]]
|
||||||
|
let mut indexes = vec![];
|
||||||
|
for item in list.iter() {
|
||||||
|
let index = item.extract::<i64>()?;
|
||||||
|
indexes.push(index);
|
||||||
|
}
|
||||||
|
Ok((
|
||||||
|
Indexer::IndexSelect(
|
||||||
|
Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?,
|
||||||
|
),
|
||||||
|
current_dim + 1,
|
||||||
|
))
|
||||||
|
} else if py_indexer.is_ellipsis() {
|
||||||
// Handle '...' e.g. tensor[..., 0]
|
// Handle '...' e.g. tensor[..., 0]
|
||||||
|
if current_dim > 0 {
|
||||||
|
return Err(PyTypeError::new_err(
|
||||||
|
"Ellipsis ('...') can only be used at the start of an indexing operation",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok((Indexer::Elipsis, dims.len() - (index_argument_count - 1)))
|
||||||
|
} else if py_indexer.is_none() {
|
||||||
|
// Handle None e.g. tensor[None, 0]
|
||||||
|
Ok((Indexer::Expand, current_dim))
|
||||||
|
} else {
|
||||||
|
Err(PyTypeError::new_err(format!(
|
||||||
|
"unsupported indexer {}",
|
||||||
|
py_indexer
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if i > 0 {
|
if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
|
||||||
return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
|
let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
|
||||||
}
|
|
||||||
indexers.push(Indexer::Elipsis);
|
if not_none_count > dims.len() {
|
||||||
} else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
|
return Err(PyValueError::new_err("provided too many indices"));
|
||||||
// Handle slice
|
|
||||||
let index = slice.indices(dims[i] as c_long)?;
|
|
||||||
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
|
|
||||||
} else if let Ok(index) = item.extract::<isize>() {
|
|
||||||
indexers.push(Indexer::Index(to_absolute_index(index, i)?));
|
|
||||||
} else {
|
|
||||||
return Err(PyTypeError::new_err("unsupported index"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut current_dim = 0;
|
||||||
|
for item in tuple.iter() {
|
||||||
|
let (indexer, new_current_dim) =
|
||||||
|
extract_indexer(item, current_dim, dims, not_none_count)?;
|
||||||
|
current_dim = new_current_dim;
|
||||||
|
indexers.push(indexer);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(PyTypeError::new_err("unsupported index"));
|
let (indexer, _) = extract_indexer(idx.downcast::<PyAny>(py)?, 0, dims, 1)?;
|
||||||
|
indexers.push(indexer);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut x = self.0.clone();
|
let mut x = self.0.clone();
|
||||||
@ -526,6 +572,22 @@ impl PyTensor {
|
|||||||
current_dim += dims.len() - (indexers.len() - 1);
|
current_dim += dims.len() - (indexers.len() - 1);
|
||||||
x
|
x
|
||||||
}
|
}
|
||||||
|
Indexer::Expand => {
|
||||||
|
// Expand is a special case, it means that a new dimension should be added => unsqueeze and advance the current_dim
|
||||||
|
let out = x.unsqueeze(current_dim).map_err(wrap_err)?;
|
||||||
|
current_dim += 1;
|
||||||
|
out
|
||||||
|
}
|
||||||
|
Indexer::IndexSelect(indexes) => {
|
||||||
|
let out = x
|
||||||
|
.index_select(
|
||||||
|
&indexes.to_device(x.device()).map_err(wrap_err)?,
|
||||||
|
current_dim,
|
||||||
|
)
|
||||||
|
.map_err(wrap_err)?;
|
||||||
|
current_dim += 1;
|
||||||
|
out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,6 +55,7 @@ def test_tensor_can_be_sliced():
|
|||||||
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
|
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
|
||||||
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
|
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
|
||||||
assert t[-4:-2].values() == [5.0, 9.0]
|
assert t[-4:-2].values() == [5.0, 9.0]
|
||||||
|
assert t[...].values() == t.values()
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_can_be_sliced_2d():
|
def test_tensor_can_be_sliced_2d():
|
||||||
@ -76,6 +77,43 @@ def test_tensor_can_be_scliced_3d():
|
|||||||
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_can_be_expanded_with_none():
|
||||||
|
t = candle.rand((12, 12))
|
||||||
|
|
||||||
|
b = t[None]
|
||||||
|
assert b.shape == (1, 12, 12)
|
||||||
|
c = t[:, None, None, :]
|
||||||
|
assert c.shape == (12, 1, 1, 12)
|
||||||
|
d = t[None, :, None, :]
|
||||||
|
assert d.shape == (1, 12, 1, 12)
|
||||||
|
e = t[None, None, :, :]
|
||||||
|
assert e.shape == (1, 1, 12, 12)
|
||||||
|
f = t[:, :, None]
|
||||||
|
assert f.shape == (12, 12, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_can_be_index_via_tensor():
|
||||||
|
t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
|
||||||
|
indexed = t[candle.Tensor([0, 2])]
|
||||||
|
assert indexed.shape == (2, 4)
|
||||||
|
assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
|
||||||
|
|
||||||
|
indexed = t[:, candle.Tensor([0, 2])]
|
||||||
|
assert indexed.shape == (3, 2)
|
||||||
|
assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_can_be_index_via_list():
|
||||||
|
t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
|
||||||
|
indexed = t[[0, 2]]
|
||||||
|
assert indexed.shape == (2, 4)
|
||||||
|
assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
|
||||||
|
|
||||||
|
indexed = t[:, [0, 2]]
|
||||||
|
assert indexed.shape == (3, 2)
|
||||||
|
assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
|
||||||
|
|
||||||
|
|
||||||
def test_tensor_can_be_cast_via_to():
|
def test_tensor_can_be_cast_via_to():
|
||||||
t = Tensor(42.0)
|
t = Tensor(42.0)
|
||||||
assert str(t.dtype) == str(candle.f32)
|
assert str(t.dtype) == str(candle.f32)
|
||||||
|
Reference in New Issue
Block a user