Support safetensors weights in llama2.c inference. (#317)

This commit is contained in:
Laurent Mazare
2023-08-03 11:10:58 +01:00
committed by GitHub
parent 74845a4dcd
commit a79286885c
2 changed files with 18 additions and 7 deletions

View File

@ -104,7 +104,7 @@ impl TransformerWeights {
})
}
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);