Remove some unnecessary calls to contiguous. (#1968)

* Remove some unnecessary calls to contiguous.

* Slightly improved kv cache concatenation.
This commit is contained in:
Laurent Mazare
2024-03-30 13:22:00 +01:00
committed by GitHub
parent efe4a0c84b
commit b190fd8592
2 changed files with 20 additions and 16 deletions

View File

@ -58,20 +58,18 @@ impl Tensor {
}
}
}
if dim == 0 {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else if dim == 0 {
Self::cat0(args)
} else {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}