Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 03:54:56 +00:00
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator (see the sketch after this list).
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
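The random dataset iterator mentioned above samples training windows at random offsets from the tokenized corpus, pairing each input with the same window shifted by one token for next-token prediction. Below is a minimal sketch of that idea in plain Rust with the `rand` crate; the `RandomBatchIter` type and all of its names are illustrative assumptions, not code from this commit.

```rust
use rand::Rng;

/// Hypothetical iterator yielding (input, target) windows sampled at
/// random offsets from a flat token stream; the target is the input
/// shifted by one token, as in next-token prediction.
struct RandomBatchIter<'a> {
    tokens: &'a [u32],
    seq_len: usize,
    rng: rand::rngs::ThreadRng,
}

impl<'a> Iterator for RandomBatchIter<'a> {
    type Item = (&'a [u32], &'a [u32]);

    fn next(&mut self) -> Option<Self::Item> {
        // Need seq_len + 1 tokens so the target can be shifted by one.
        if self.tokens.len() < self.seq_len + 1 {
            return None;
        }
        let start = self.rng.gen_range(0..=self.tokens.len() - self.seq_len - 1);
        let input = &self.tokens[start..start + self.seq_len];
        let target = &self.tokens[start + 1..start + self.seq_len + 1];
        Some((input, target))
    }
}

fn main() {
    let tokens: Vec<u32> = (0..64).collect();
    let it = RandomBatchIter { tokens: &tokens, seq_len: 8, rng: rand::thread_rng() };
    // Take a few random windows; each target is its input shifted by one.
    for (input, target) in it.take(3) {
        println!("{input:?} -> {target:?}");
    }
}
```

Sampling with replacement like this keeps the iterator infinite and cheap, which suits an evaluation loop that simply averages the loss over a fixed number of random batches.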
```diff
@@ -228,3 +228,11 @@ macro_rules! bail {
         return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
     };
 }
+
+pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
+    match (r1, r2) {
+        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
+        (Err(e), _) => Err(e),
+        (_, Err(e)) => Err(e),
+    }
+}
```
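The `zip` helper added above merges two fallible results into one, failing with the first error when either side fails. A self-contained sketch of its behavior, with `String` standing in for candle's own error type:

```rust
// Minimal, self-contained sketch of the new `zip` semantics; the
// `String` error type is a stand-in for candle's own `Error`.
type Result<T> = std::result::Result<T, String>;

fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
    match (r1, r2) {
        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
        (Err(e), _) => Err(e),
        (_, Err(e)) => Err(e),
    }
}

fn main() {
    // Both Ok: the values are paired up.
    assert_eq!(zip(Ok(1), Ok(2)), Ok((1, 2)));
    // The helper is left-biased: when both fail, the first error wins.
    let r: Result<(i32, i32)> = zip(Err("left".into()), Err("right".into()));
    assert_eq!(r, Err("left".to_string()));
}
```

The left bias means that when both computations fail, only the first error is reported, which keeps a `match` or `?` at the call site simple.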
```diff
@@ -44,7 +44,7 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
-mod error;
+pub mod error;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "mkl")]
```
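This one-line visibility change makes the `error` module, and with it the new `zip` helper, nameable from outside the crate. A usage sketch; the `candle_core` crate path is an assumption about the library's import name:

```rust
// Hypothetical downstream import, assuming the crate is used as
// `candle_core`; before this change only re-exported items were
// reachable, not the `error` module itself.
use candle_core::error::zip;
```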