Add training for the llama2.c example (#296)

* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator.

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
Author: Laurent Mazare
Date: 2023-08-01 17:23:07 +01:00
Committed by: GitHub
Parent: babee9f011
Commit: a27239f3d9
6 changed files with 227 additions and 9 deletions
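
Most of the new code lands in the llama2.c example's training module. The "Random dataset iterator" bullet above refers to the usual llama2.c-style sampling scheme: the pretokenized corpus is treated as one long stream of token ids, and each training batch is built from fixed-length windows starting at random offsets, with targets shifted by one position for next-token prediction. The sketch below illustrates that idea in plain Rust; the type and field names are illustrative assumptions, not the code added by this commit.

```rust
use rand::Rng;

/// Illustrative sketch of a random batch iterator over a pretokenized corpus.
/// `tokens` is the whole dataset as one flat sequence of token ids.
struct RandomBatches<'a> {
    tokens: &'a [u32],
    seq_len: usize,
    batch_size: usize,
}

impl<'a> RandomBatches<'a> {
    /// Draws `batch_size` windows at random offsets; each target window is the
    /// input window shifted by one token (next-token prediction).
    fn next_batch(&self, rng: &mut impl Rng) -> (Vec<Vec<u32>>, Vec<Vec<u32>>) {
        assert!(self.tokens.len() > self.seq_len + 1, "dataset shorter than one window");
        let mut inputs = Vec::with_capacity(self.batch_size);
        let mut targets = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            // Leave room for the one-token shift at the end of the window.
            let start = rng.gen_range(0..self.tokens.len() - self.seq_len - 1);
            inputs.push(self.tokens[start..start + self.seq_len].to_vec());
            targets.push(self.tokens[start + 1..start + self.seq_len + 1].to_vec());
        }
        (inputs, targets)
    }
}
```

The evaluation loss mentioned in the last bullets follows from the same iterator pointed at the validation split: average the cross-entropy over a fixed number of randomly drawn batches, without applying any optimizer step.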


@@ -228,3 +228,11 @@ macro_rules! bail {
         return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
     };
 }
+
+pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
+    match (r1, r2) {
+        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
+        (Err(e), _) => Err(e),
+        (_, Err(e)) => Err(e),
+    }
+}
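
The `zip` helper added above is a small convenience for error handling in the new training code: it turns two `Result`s into a single `Result` of a pair, propagating the first error encountered. A minimal test-style sketch of its behaviour, assuming it sits next to `zip` in the same module (the test itself is not part of this diff):

```rust
#[cfg(test)]
mod zip_tests {
    use super::*;

    #[test]
    fn zip_pairs_values_and_short_circuits_on_error() {
        // Both Ok: the two values are paired up.
        assert_eq!(zip(Ok(1i64), Ok("two")).unwrap(), (1, "two"));
        // Either Err: the error is propagated and the other value is dropped.
        let bad: Result<i64> = Err(Error::Msg("boom".into()).bt());
        assert!(zip(bad, Ok(3i64)).is_err());
    }
}
```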


@@ -44,7 +44,7 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
-mod error;
+pub mod error;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "mkl")]
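
The one-line change above makes the error module itself public rather than crate-private. Together with the new `zip` helper, this lets code outside the core crate, such as the llama2.c training example, reach items under `error::` directly instead of relying only on what is re-exported at the crate root. A hedged sketch of downstream usage; the `candle` alias for the core crate and the helper function are assumptions for illustration:

```rust
// Sketch only: assumes the core crate is imported under the alias `candle`.
use candle::error::{Error, Result};

/// Hypothetical helper in the example crate, returning the core crate's Result.
fn require_positive(n: i64) -> Result<i64> {
    if n > 0 {
        Ok(n)
    } else {
        Err(Error::Msg(format!("expected a positive value, got {n}").into()).bt())
    }
}
```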