Add a flag to save the trained weights. (#279)

This commit is contained in:
Laurent Mazare
2023-07-30 15:41:42 +01:00
committed by GitHub
parent ba2254556c
commit 38ff693af0
3 changed files with 42 additions and 2 deletions

View File

@ -52,6 +52,27 @@ impl st::View for Tensor {
}
}
impl st::View for &Tensor {
fn dtype(&self) -> st::Dtype {
(*self).dtype().into()
}
fn shape(&self) -> &[usize] {
self.dims()
}
fn data(&self) -> Cow<[u8]> {
// This copies data from GPU to CPU.
// TODO: Avoid the unwrap here.
Cow::Owned(convert_back(self).unwrap())
}
fn data_len(&self) -> usize {
let n: usize = self.dims().iter().product();
let bytes_per_element = (*self).dtype().size_in_bytes();
n * bytes_per_element
}
}
impl Tensor {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,