Test the index op + bugfix. (#148)

This commit is contained in:
Laurent Mazare
2023-07-12 15:42:36 +01:00
committed by GitHub
parent ba35d895e7
commit 8aab787384
2 changed files with 84 additions and 16 deletions

View File

@ -30,7 +30,7 @@ impl Tensor {
let mut current_dim = 0;
for (i, indexer) in indexers.iter().enumerate() {
x = match indexer {
TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
TensorIndexer::Narrow(left_bound, right_bound) => {
let start = match left_bound {
Bound::Included(n) => *n,
@ -55,10 +55,10 @@ impl Tensor {
#[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elemnts for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
// IndexSelect(Tensor),
}
impl From<usize> for TensorIndexer {
@ -67,20 +67,6 @@ impl From<usize> for TensorIndexer {
}
}
// impl From<&[usize]> for TensorIndexer {
// fn from(index: &[usize]) -> Self {
// let tensor = index.into();
// TensorIndexer::IndexSelect(tensor)
// }
// }
//
// impl From<Vec<usize>> for TensorIndexer {
// fn from(index: Vec<usize>) -> Self {
// let tensor = Tensor::of_slice(&index);
// TensorIndexer::IndexSelect(tensor)
// }
// }
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {