Add some display tests + bugfixes.

This commit is contained in:
laurent
2023-06-27 21:37:28 +01:00
parent 8c81a70170
commit b0f5f2d22d
3 changed files with 102 additions and 14 deletions

View File

@ -652,18 +652,22 @@ impl Tensor {
}
pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
let start_dim = start_dim.unwrap_or(0);
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
if start_dim < end_dim {
let dims = self.dims();
let mut dst_dims = dims[..start_dim].to_vec();
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
if end_dim + 1 < dims.len() {
dst_dims.extend(&dims[end_dim + 1..]);
}
self.reshape(dst_dims)
if self.rank() == 0 {
self.reshape(1)
} else {
Ok(self.clone())
let start_dim = start_dim.unwrap_or(0);
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
if start_dim < end_dim {
let dims = self.dims();
let mut dst_dims = dims[..start_dim].to_vec();
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
if end_dim + 1 < dims.len() {
dst_dims.extend(&dims[end_dim + 1..]);
}
self.reshape(dst_dims)
} else {
Ok(self.clone())
}
}
}