Fix the revision used in starcoder to use the safetensors PR. (#269)

Commit a0e47aba98 (parent fb84ead8f7)
Author: Laurent Mazare
Date: 2023-07-28 14:02:31 +01:00
Committed by: GitHub

2 changed files with 6 additions and 9 deletions
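
For context: on the Hugging Face Hub, refs/pr/1 is the revision under which a
repo's first pull request lives. For bigcode/starcoderbase-1b that is the
automated safetensors conversion PR, while main presumably still carried only
the original PyTorch .bin weights, hence this fix. A minimal sketch of pinning
a revision with the hf-hub crate's sync API (whether this exact commit used
hf-hub itself or an in-tree predecessor is an assumption, but the Repo/get
shape matches the calls in the diff):

use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> anyhow::Result<()> {
    // Build a synchronous Hub client; downloads land in the standard HF cache.
    let api = Api::new()?;
    // Pin the repo to refs/pr/1 (the safetensors conversion PR) instead of
    // the default main branch.
    let repo = api.repo(Repo::with_revision(
        "bigcode/starcoderbase-1b".to_string(),
        RepoType::Model,
        "refs/pr/1".to_string(),
    ));
    // Fetch (or reuse from cache) the weights stored at that revision.
    let weights = repo.get("model.safetensors")?;
    println!("weights cached at {weights:?}");
    Ok(())
}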


@@ -106,7 +106,7 @@ struct Args {
     #[arg(long, default_value = "bigcode/starcoderbase-1b")]
     model_id: String,

-    #[arg(long, default_value = "main")]
+    #[arg(long, default_value = "refs/pr/1")]
     revision: String,

     #[arg(long)]
@@ -126,13 +126,10 @@ fn main() -> Result<()> {
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let filenames = match args.weight_file {
         Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => {
-            let repo_filenames: Vec<String> = vec![];
-            repo_filenames
-                .iter()
-                .map(|f| repo.get(f))
-                .collect::<std::result::Result<Vec<_>, _>>()?
-        }
+        None => ["model.safetensors"]
+            .iter()
+            .map(|f| repo.get(f))
+            .collect::<std::result::Result<Vec<_>, _>>()?,
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
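
Note that the old None arm iterated over an always-empty repo_filenames
vector, so no weight files were ever downloaded; the new arm asks the hub for
model.safetensors directly. Passing --weight-file still short-circuits the
download with a local path, and --revision can point the lookup at any other
branch, tag, or PR ref.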


@@ -316,7 +316,7 @@ impl GPTBigCode {
             .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
             .collect::<Result<Vec<_>>>()?;
         let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
-        let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
+        let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
         let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
         Ok(Self {
             wte,
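
The model-side hunk fixes the language-model head: GPT-BigCode ties its output
projection to the token embedding, so the weight has to be read from wte under
the transformer scope (vb_t) rather than from a separate lm_head tensor, which
the converted safetensors checkpoint would not store, and which the old code
additionally looked up through the root builder vb. A minimal sketch of such
weight tying with candle_nn building blocks (the helper below and its
signature are illustrative assumptions, not the example's own linear helper):

use candle_core::Result;
use candle_nn::{Embedding, Linear, VarBuilder};

// Tied embeddings: the input embedding and the output projection share the
// single checkpoint tensor stored as "wte.weight" in the transformer scope.
fn load_tied_head(
    vocab_size: usize,
    hidden_size: usize,
    vb_t: VarBuilder, // builder already scoped to "transformer"
) -> Result<(Embedding, Linear)> {
    let weight = vb_t.get((vocab_size, hidden_size), "wte.weight")?;
    let wte = Embedding::new(weight.clone(), hidden_size);
    // No bias: the head reuses the (vocab_size, hidden_size) embedding
    // matrix, which Linear applies as x @ weight^T.
    let lm_head = Linear::new(weight, None);
    Ok((wte, lm_head))
}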