Merge pull request #19 from LaurentMazare/llama_safetensors

Llama safetensors
Nicolas Patry
2023-06-29 12:49:26 +02:00
committed by GitHub
6 changed files with 344 additions and 100 deletions

View File

@@ -27,6 +27,9 @@ anyhow = "1"
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
+tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
+candle-hub = { path = "../candle-hub" }
+memmap2 = "0.7.1"

 [features]
 default = ["cuda"]

View File

@@ -15,11 +15,12 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 use candle::{DType, Device, Tensor};
+use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

 mod var_store;
-use var_store::VarBuilder;
+mod weights;

 const CONTEXT_SIZE: usize = 512;
 const START_PROMPT: &str = r"
@@ -131,9 +132,8 @@ struct Embedding {
 }

 impl Embedding {
-    fn new(mut vb: VarBuilder, vocab_size: usize, n_embd: usize) -> Result<Self> {
-        let embeddings = vb.var("weight", (vocab_size, n_embd))?;
-        Ok(Self { embeddings })
+    fn new(embeddings: Tensor) -> Self {
+        Self { embeddings }
     }

     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
@@ -145,42 +145,27 @@ impl Embedding {
 }

 struct Linear {
-    ws: Tensor,
-    bs: Option<Tensor>,
+    weight: Tensor,
 }

 impl Linear {
-    #[allow(dead_code)]
-    fn new(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
-        let ws = vb.var("weight", (in_size, out_size))?;
-        let bs = vb.var("bias", out_size)?;
-        Ok(Self { ws, bs: Some(bs) })
-    }
-
-    fn new_no_bias(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
-        let ws = vb.var("weight", (in_size, out_size))?;
-        Ok(Self { ws, bs: None })
+    fn new(weight: Tensor) -> Self {
+        Self { weight }
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
-        let y = match &self.bs {
-            None => x,
-            Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
-        };
-        Ok(y)
+        let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
+        Ok(x)
     }
 }

 struct RmsNorm {
     scale: Tensor,
-    size: usize,
 }

 impl RmsNorm {
-    fn new(mut vb: VarBuilder, size: usize) -> Result<Self> {
-        let scale = vb.var("scale", &[size])?;
-        Ok(Self { scale, size })
+    fn new(scale: Tensor) -> Self {
+        Self { scale }
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
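The notable change in Linear::forward is the added transpose: weights now stay in the PyTorch (out_features, in_features) layout used by safetensors checkpoints, so the projection multiplies by Wᵀ (the npy loader below transposes at load time to match). A minimal sketch of the shape contract, with a hypothetical function name:

use candle::{Result, Tensor};

// x: (seq_len, in_features), weight: (out_features, in_features)
// -> y: (seq_len, out_features); hence the `.t()?` before the matmul.
fn linear_forward(x: &Tensor, weight: &Tensor) -> Result<Tensor> {
    x.matmul(&weight.t()?)
}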
@@ -188,10 +173,11 @@ impl RmsNorm {
         let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, size))?;
+            .broadcast_as((seq_len, size))?;
         Ok((scale * x_normed)?)
     }
 }
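For reference, the computation above is RMS normalization: each row is scaled by the reciprocal of its root mean square, then by a learned gain. A plain-Rust sketch of the same formula on a single row (illustrative names, not code from this commit):

fn rms_norm_row(x: &[f32], scale: &[f32], eps: f32) -> Vec<f32> {
    // inv_rms = 1 / sqrt(mean(x_i^2) + eps)
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(scale).map(|(v, s)| v * inv_rms * s).collect()
}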
@@ -207,17 +193,12 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }

 impl Mlp {
-    fn new(vb: VarBuilder, n_embd: usize) -> Result<Self> {
-        let n_hidden = 8 * n_embd / 3;
-        let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
-        let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
-        let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
-        let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
-        Ok(Self {
+    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
+        Self {
             c_fc1,
             c_fc2,
             c_proj,
-        })
+        }
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -256,10 +237,6 @@ impl Cache {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
                 .collect();
-            // Once lower_triangle is available, use the following:
-            //let mask = Tensor::new(1u32, &device)?
-            //    .broadcast_as(&[t, t])?
-            //    .lower_triangle()?
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
             Ok(mask)
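The mask marks strictly-future positions with 1 (u32::from(j > i)), so they can later be replaced by -inf before the softmax. As a worked example, t = 4 reshapes to:

0 1 1 1
0 0 1 1
0 0 0 1
0 0 0 0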
@@ -271,21 +248,18 @@ struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
-    n_embd: usize,
+    // n_embd: usize,
     cache: Cache,
 }

 impl CausalSelfAttention {
-    fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
-        let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
-        let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
-        Ok(Self {
+    fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
+        Self {
             c_attn,
             c_proj,
             n_head,
-            n_embd,
             cache: cache.clone(),
-        })
+        }
     }

     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -313,7 +287,7 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
-        let n_embd = self.n_embd;
+        let n_embd = c;
         let q = qkv.narrow(1, 0, n_embd)?;
         let k = qkv.narrow(1, n_embd, n_embd)?;
         let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
@@ -344,17 +318,13 @@ struct Block {
 }

 impl Block {
-    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
-        let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
-        let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
-        let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
-        let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
-        Ok(Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+        Self {
             rms_1,
             attn,
             rms_2,
             mlp,
-        })
+        }
     }

     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -372,23 +342,13 @@ struct Llama {
 }

 impl Llama {
-    fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
-        let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
-        let wte = Embedding::new(
-            &vb / "transformer" / "wte",
-            config.vocab_size,
-            config.n_embd,
-        )?;
-        let blocks = (0..config.n_layer)
-            .map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
-            .collect::<Result<Vec<_>>>()?;
-        let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
-        Ok(Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+        Self {
             wte,
             blocks,
             ln_f,
             lm_head,
-        })
+        }
     }

     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
@@ -434,6 +394,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,

+    /// Use npy instead of safetensors
+    #[arg(long)]
+    npy: bool,
+
     /// The temperature used to generate samples.
     #[arg(long, default_value_t = 1.0)]
     temperature: f64,
@@ -443,8 +407,9 @@ struct Args {
     sample_len: usize,
 }

-fn main() -> Result<()> {
-    use rand::prelude::*;
+#[tokio::main]
+async fn main() -> Result<()> {
+    //use rand::prelude::*;
     use tokenizers::Tokenizer;

     let args = Args::parse();
@@ -453,38 +418,44 @@ fn main() -> Result<()> {
     } else {
         Device::new_cuda(0)?
     };
-    println!("loading tokenizer config");
-    let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
+    let api = Api::new()?;
+    let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+    let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+    println!("Filename {tokenizer_filename:?}");
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let mut tokens = tokenizer
         .encode(START_PROMPT, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let weight_path = std::path::Path::new("llama.npz");
-    let weights = if weight_path.exists() {
-        println!("loading weights from {weight_path:?}");
-        let start_load = std::time::Instant::now();
-        let tensors = Tensor::read_npz(weight_path)?;
-        println!("loaded weights in {:?}", start_load.elapsed());
-        let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-        Some(tensors)
-    } else {
-        println!("cannot find {weight_path:?}, using zero weights");
-        None
-    };
-    let vb = VarBuilder::new::<f32>(&device, weights);
+    let mut filenames = vec![];
+    for rfilename in [
+        "model-00001-of-00002.safetensors",
+        "model-00002-of-00002.safetensors",
+    ] {
+        let filename = api.get(&repo, rfilename).await?;
+        filenames.push(filename);
+    }

     println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
-    let llama = Llama::new(vb, &cache, &config)?;
+    let start = std::time::Instant::now();
+    let llama = if args.npy {
+        println!("building the model (NPY)");
+        Llama::load_npy(&device, &filenames, &cache, &config)?
+    } else {
+        println!("building the model (SF)");
+        Llama::load(&device, &filenames, &cache, &config)?
+    };
+    println!("Loaded in {:?}", start.elapsed());

     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = thread_rng();
+    //let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
@@ -493,8 +464,20 @@ fn main() -> Result<()> {
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
-        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-        let next_token = distr.sample(&mut rng) as u32;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+                if &val_max > val {
+                    (idx_max, val_max)
+                } else {
+                    (idx, *val)
+                }
+            })
+            .0 as u32;
+        // let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+        // let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
         println!("> {:?}", start_gen.elapsed());
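This hunk swaps the WeightedIndex multinomial draw for greedy decoding: the fold simply tracks the running argmax of the probabilities. An equivalent, more compact sketch using only std (total_cmp needs Rust 1.62 or later; the function name is illustrative):

fn argmax(probs: &[f32]) -> u32 {
    // total_cmp gives a NaN-safe total order over f32.
    probs
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}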

View File

@@ -1,5 +1,7 @@
+use super::*;
 use candle::{DType, Device, Result, Shape, Tensor, WithDType};
 use std::collections::HashMap;
+use std::path::PathBuf;
 use std::sync::Arc;

 #[allow(dead_code)]
@@ -40,22 +42,15 @@ impl VarBuilder {
         self.vars.borrow().len()
     }

-    pub fn var<S: Into<Shape>>(&mut self, s: &str, shape: S) -> Result<Tensor> {
-        let shape = shape.into();
+    pub fn var(&self, s: &str) -> Result<Tensor> {
         let path = format!("{}.{s}", self.path.join("."));
-        let mut vars = self.vars.borrow_mut();
         let parameter = match self.tensors.as_ref() {
-            None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
+            None => panic!("Cannot find tensors"),
             Some(tensors) => match tensors.get(&path) {
                 Some(tensor) => tensor.to_device(&self.default_device)?,
                 None => panic!("cannot find tensor for {path}"),
             },
         };
-        vars.push(NamedVar {
-            path,
-            dtype: self.default_dtype,
-            shape,
-        });
         Ok(parameter)
     }
@@ -90,3 +85,88 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
         &self / rhs
     }
 }
impl Embedding {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let embeddings = vb.var("weight")?;
Ok(Self { embeddings })
}
}
impl Linear {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let weight = vb.var("weight")?.t()?;
Ok(Self { weight })
}
}
impl RmsNorm {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let scale = vb.var("scale")?;
Ok(Self::new(scale))
}
}
impl CausalSelfAttention {
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
let c_attn = Linear::load_npy(&vb / "c_attn")?;
let c_proj = Linear::load_npy(&vb / "c_proj")?;
Ok(Self::new(c_attn, c_proj, config.n_head, cache))
}
}
impl Mlp {
fn load_npy(vb: VarBuilder) -> Result<Self> {
let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
let c_proj = Linear::load_npy(&vb / "c_proj")?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
impl Block {
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
let mlp = Mlp::load_npy(&vb / "mlp")?;
let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
impl Llama {
pub fn load_npy(
device: &Device,
_filenames: &[PathBuf],
cache: &Cache,
config: &Config,
) -> Result<Self> {
let weight_path = std::path::Path::new("/data/llama.npz");
let weights = if weight_path.exists() {
println!("loading weights from {weight_path:?}");
let start_load = std::time::Instant::now();
let tensors = Tensor::read_npz(weight_path)?;
println!("loaded weights in {:?}", start_load.elapsed());
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
Some(tensors)
} else {
println!("cannot find {weight_path:?}, using zero weights");
None
};
let vb = VarBuilder::new::<f32>(device, weights);
let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
let lm_head = Linear::load_npy(&vb / "lm_head")?;
let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
let blocks: Vec<_> = (0..config.n_layer)
.map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
.collect();
Ok(Self::new(wte, blocks, norm, lm_head))
}
}

View File

@@ -0,0 +1,164 @@
use super::*;
use candle::{Device, Result, Tensor};
use half::f16;
use memmap2::MmapOptions;
use safetensors::{
tensor::{Dtype, TensorView},
SafeTensors,
};
use std::fs::File;
use std::path::PathBuf;
fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
Dtype::F16 => {
let v = view.data();
if (v.as_ptr() as usize) % 2 == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[f16] =
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
Tensor::from_slice(data, view.shape(), device)
} else {
let mut c = Vec::with_capacity(v.len() / 2);
let mut i = 0;
while i < v.len() {
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
i += 2;
}
Tensor::from_slice(&c, view.shape(), device)
}
}
dt => todo!("Unhandled dtype {dt:?}"),
}
}
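// Aside (a sketch, not part of this diff): a tiny check of the fallback
// decode path above. 0x3C00 is 1.0 in IEEE 754 half precision, stored
// little-endian as [0x00, 0x3C].
#[cfg(test)]
mod f16_decode_sketch {
    use half::f16;

    #[test]
    fn from_le_bytes_decodes_one() {
        assert_eq!(f16::from_le_bytes([0x00, 0x3C]).to_f32(), 1.0);
    }
}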
pub struct VarBuilder<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
device: Device,
}
impl<'a> VarBuilder<'a> {
pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
Self {
safetensors,
device,
routing,
}
}
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
let tensor = convert(view, &self.device)?;
Ok(tensor)
}
}
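// Usage sketch (tensor name assumed from the Llama checkpoint layout used
// below): `routing` records which shard holds each name, so a tensor living
// in the second file still resolves through the same call:
//     let weight = vb.get("model.layers.0.self_attn.q_proj.weight")?;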
impl Linear {
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(&format!("{prefix}.weight"))?;
Ok(Self::new(weight))
}
fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self> {
let weights: Vec<_> = prefixes
.iter()
.map(|p| vb.get(&format!("{p}.weight")).unwrap())
.collect();
let weight = Tensor::cat(&weights, 0)?;
Ok(Self::new(weight))
}
}
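// Shape note on `load_multi` (a sketch of the bookkeeping, not extra code in
// this commit): each of the q/k/v weights is (n_embd, n_embd) in (out, in)
// layout, so Tensor::cat(.., 0) builds one (3 * n_embd, n_embd) weight. A
// single matmul then yields (seq, 3 * n_embd), which the attention forward
// splits back out with narrow(1, 0, n_embd), narrow(1, n_embd, n_embd) and
// narrow(1, 2 * n_embd, n_embd).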
impl RmsNorm {
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
let scale = vb.get(&format!("{prefix}.weight"))?;
Ok(Self::new(scale))
}
}
impl CausalSelfAttention {
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
let c_attn = Linear::load_multi(
&[
&format!("{prefix}.q_proj"),
&format!("{prefix}.k_proj"),
&format!("{prefix}.v_proj"),
],
vb,
)?;
let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
Ok(Self::new(c_attn, o_proj, config.n_head, cache))
}
}
impl Mlp {
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
impl Block {
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?;
let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
let post_attention_layernorm =
RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
impl Llama {
pub fn load(
device: &Device,
filenames: &[PathBuf],
cache: &Cache,
config: &Config,
) -> Result<Self> {
let handles: Vec<_> = filenames
.iter()
.map(|f| {
let file = File::open(f).unwrap();
unsafe { MmapOptions::new().map(&file).unwrap() }
})
.collect();
let tensors: Vec<_> = handles
.iter()
.map(|h| {
let tensors = SafeTensors::deserialize(h).unwrap();
tensors
})
.collect();
let vb = VarBuilder::new(tensors, device.clone());
let embedding = vb.get("model.embed_tokens.weight")?;
let wte = Embedding::new(embedding);
let lm_head = Linear::load("lm_head", &vb)?;
let norm = RmsNorm::load("model.norm", &vb)?;
let blocks: Vec<_> = (0..config.n_layer)
.map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap())
.collect();
Ok(Self::new(wte, blocks, norm, lm_head))
}
}

View File

@@ -378,6 +378,12 @@ impl Api {
         let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
         let filename = temp_filename();

+        // Create the file and set everything properly
+        tokio::fs::File::create(&filename)
+            .await?
+            .set_len(length as u64)
+            .await?;
+
         let chunk_size = self.chunk_size;
         for start in (0..length).step_by(chunk_size) {
             let url = url.to_string();
@@ -440,7 +446,6 @@ impl Api {
         let range = format!("bytes={start}-{stop}");
         let mut file = tokio::fs::OpenOptions::new()
             .write(true)
-            .create(true)
             .open(filename)
             .await?;
         file.seek(SeekFrom::Start(start as u64)).await?;
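Preallocating the file with set_len up front is what lets the create(true) flag go away: each chunk task opens the already-sized file and writes at its own offset, with no race on creation or truncation. A self-contained sketch of that write-at-offset pattern (function name and arguments are illustrative, not this crate's API):

use std::io::SeekFrom;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

async fn write_chunk(path: &str, offset: u64, data: &[u8]) -> std::io::Result<()> {
    // The file already exists at full length, so no create/truncate here.
    let mut file = tokio::fs::OpenOptions::new().write(true).open(path).await?;
    file.seek(SeekFrom::Start(offset)).await?;
    file.write_all(data).await?;
    Ok(())
}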

View File

@@ -53,7 +53,11 @@ impl Cache {
         let commit_hash = std::fs::read_to_string(commit_path).ok()?;
         let mut pointer_path = self.pointer_path(repo, &commit_hash);
         pointer_path.push(filename);
-        Some(pointer_path)
+        if pointer_path.exists() {
+            Some(pointer_path)
+        } else {
+            None
+        }
     }

     /// Creates a reference in the cache directory that points branches to the correct
@@ -146,7 +150,12 @@ impl Repo {
     /// The normalized folder name of the repo within the cache directory
     pub fn folder_name(&self) -> String {
-        self.repo_id.replace('/', "--")
+        let prefix = match self.repo_type {
+            RepoType::Model => "models",
+            RepoType::Dataset => "datasets",
+            RepoType::Space => "spaces",
+        };
+        format!("{prefix}--{}", self.repo_id).replace('/', "--")
     }

     /// The revision
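With the prefix in place, cache folders are namespaced by repo type. For instance, the model repo used in the example above now maps as follows:

// Repo::new("Narsil/amall-7b".to_string(), RepoType::Model).folder_name()
// == "models--Narsil--amall-7b"   (previously "Narsil--amall-7b")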