Fix Batcher iterator break when return_last_incomplete_batch and items.is_empty (#2654) (#2655)

This commit is contained in:
hhllhhyyds
2024-12-24 15:41:26 +08:00
committed by GitHub
parent 1be6b090c7
commit 11aa30be10

View File

@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
match self.inner.inner.next() { match self.inner.inner.next() {
Some(item) => items.push(item), Some(item) => items.push(item),
None => { None => {
if self.return_last_incomplete_batch { if self.return_last_incomplete_batch && !items.is_empty() {
break; break;
} }
return None; return None;
@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
ys.push(y) ys.push(y)
} }
None => { None => {
if self.return_last_incomplete_batch { if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break; break;
} }
return None; return None;
@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
match self.inner.inner.next() { match self.inner.inner.next() {
Some(item) => items.push(item), Some(item) => items.push(item),
None => { None => {
if self.return_last_incomplete_batch { if self.return_last_incomplete_batch && !items.is_empty() {
break; break;
} }
return None; return None;
@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu
} }
Some(Err(err)) => errs.push(err), Some(Err(err)) => errs.push(err),
None => { None => {
if self.return_last_incomplete_batch { if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() {
break; break;
} }
return None; return None;