mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers. * Use the new safetensor container in varbuilders.
This commit is contained in:
@ -321,6 +321,18 @@ impl MmapedSafetensors {
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.get(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
let mut tensors = vec![];
|
||||
for safetensors in self.safetensors.iter() {
|
||||
tensors.push(safetensors.get().0.tensors())
|
||||
}
|
||||
tensors.into_iter().flatten().collect()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
let index = match &self.routing {
|
||||
None => 0,
|
||||
Some(routing) => {
|
||||
@ -333,15 +345,7 @@ impl MmapedSafetensors {
|
||||
*index
|
||||
}
|
||||
};
|
||||
self.safetensors[index].get().0.tensor(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
let mut tensors = vec![];
|
||||
for safetensors in self.safetensors.iter() {
|
||||
tensors.push(safetensors.get().0.tensors())
|
||||
}
|
||||
tensors.into_iter().flatten().collect()
|
||||
Ok(self.safetensors[index].get().0.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user