mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add some display tests + bugfixes.
This commit is contained in:
@ -652,18 +652,22 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
|
||||
let start_dim = start_dim.unwrap_or(0);
|
||||
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||
if start_dim < end_dim {
|
||||
let dims = self.dims();
|
||||
let mut dst_dims = dims[..start_dim].to_vec();
|
||||
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
|
||||
if end_dim + 1 < dims.len() {
|
||||
dst_dims.extend(&dims[end_dim + 1..]);
|
||||
}
|
||||
self.reshape(dst_dims)
|
||||
if self.rank() == 0 {
|
||||
self.reshape(1)
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
let start_dim = start_dim.unwrap_or(0);
|
||||
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||
if start_dim < end_dim {
|
||||
let dims = self.dims();
|
||||
let mut dst_dims = dims[..start_dim].to_vec();
|
||||
dst_dims.push(dims[start_dim..end_dim + 1].iter().product::<usize>());
|
||||
if end_dim + 1 < dims.len() {
|
||||
dst_dims.extend(&dims[end_dim + 1..]);
|
||||
}
|
||||
self.reshape(dst_dims)
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user