Llama batch (#144)

* Add a batch dimension to llama.

* Bugfixes.
This commit is contained in:
Laurent Mazare
2023-07-12 11:38:19 +01:00
committed by GitHub
parent bcf96e3cf3
commit b3b39cca92
3 changed files with 32 additions and 52 deletions

View File

@ -30,7 +30,7 @@ impl Tensor {
let mut current_dim = 0;
for (i, indexer) in indexers.iter().enumerate() {
x = match indexer {
TensorIndexer::Select(n) => x.get(*n)?,
TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
TensorIndexer::Narrow(left_bound, right_bound) => {
let start = match left_bound {
Bound::Included(n) => *n,