Self-contained safetensor wrappers (#946)

* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
This commit is contained in:
Laurent Mazare
2023-09-23 20:39:52 +01:00
committed by GitHub
parent 5dbe46b389
commit 890d069092
3 changed files with 61 additions and 30 deletions

View File

@ -321,6 +321,18 @@ impl MmapedSafetensors {
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
@ -333,15 +345,7 @@ impl MmapedSafetensors {
*index
}
};
self.safetensors[index].get().0.tensor(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
Ok(self.safetensors[index].get().0.tensor(name)?)
}
}