Add the get method.

This commit is contained in:
laurent
2023-06-27 17:39:58 +01:00
parent c44e5346f4
commit 684f66326d

View File

@ -677,6 +677,15 @@ impl Tensor {
self.flatten(None, None)
}
pub fn get(&self, i: usize) -> Result<Tensor> {
let dims = self.dims();
if dims.is_empty() {
Ok(self.clone())
} else {
self.narrow(0, i, 1)?.reshape(&dims[1..])
}
}
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped.
pub fn t(&self) -> Result<Tensor> {