mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Test the index op + bugfix. (#148)
This commit is contained in:
@ -30,7 +30,7 @@ impl Tensor {
|
||||
let mut current_dim = 0;
|
||||
for (i, indexer) in indexers.iter().enumerate() {
|
||||
x = match indexer {
|
||||
TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
|
||||
TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
|
||||
TensorIndexer::Narrow(left_bound, right_bound) => {
|
||||
let start = match left_bound {
|
||||
Bound::Included(n) => *n,
|
||||
@ -55,10 +55,10 @@ impl Tensor {
|
||||
#[derive(Debug, Clone)]
|
||||
/// Generic structure used to index a slice of the tensor
|
||||
pub enum TensorIndexer {
|
||||
/// This selects the elemnts for which an index has some specific value.
|
||||
Select(usize),
|
||||
/// This is a regular slice, purely indexing a chunk of the tensor
|
||||
Narrow(Bound<usize>, Bound<usize>),
|
||||
// IndexSelect(Tensor),
|
||||
}
|
||||
|
||||
impl From<usize> for TensorIndexer {
|
||||
@ -67,20 +67,6 @@ impl From<usize> for TensorIndexer {
|
||||
}
|
||||
}
|
||||
|
||||
// impl From<&[usize]> for TensorIndexer {
|
||||
// fn from(index: &[usize]) -> Self {
|
||||
// let tensor = index.into();
|
||||
// TensorIndexer::IndexSelect(tensor)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// impl From<Vec<usize>> for TensorIndexer {
|
||||
// fn from(index: Vec<usize>) -> Self {
|
||||
// let tensor = Tensor::of_slice(&index);
|
||||
// TensorIndexer::IndexSelect(tensor)
|
||||
// }
|
||||
// }
|
||||
|
||||
macro_rules! impl_from_range {
|
||||
($range_type:ty) => {
|
||||
impl From<$range_type> for TensorIndexer {
|
||||
|
Reference in New Issue
Block a user