mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add a flag to save the trained weights. (#279)
This commit is contained in:
@ -52,6 +52,27 @@ impl st::View for Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl st::View for &Tensor {
|
||||
fn dtype(&self) -> st::Dtype {
|
||||
(*self).dtype().into()
|
||||
}
|
||||
fn shape(&self) -> &[usize] {
|
||||
self.dims()
|
||||
}
|
||||
|
||||
fn data(&self) -> Cow<[u8]> {
|
||||
// This copies data from GPU to CPU.
|
||||
// TODO: Avoid the unwrap here.
|
||||
Cow::Owned(convert_back(self).unwrap())
|
||||
}
|
||||
|
||||
fn data_len(&self) -> usize {
|
||||
let n: usize = self.dims().iter().product();
|
||||
let bytes_per_element = (*self).dtype().size_in_bytes();
|
||||
n * bytes_per_element
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
pub fn save_safetensors<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
|
Reference in New Issue
Block a user