mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Fix the revision used in starcoder to use the safetensors PR. (#269)
This commit is contained in:
@@ -106,7 +106,7 @@ struct Args {
|
||||
#[arg(long, default_value = "bigcode/starcoderbase-1b")]
|
||||
model_id: String,
|
||||
|
||||
- #[arg(long, default_value = "main")]
|
||||
+ #[arg(long, default_value = "refs/pr/1")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
@@ -126,13 +126,10 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
- None => {
|
||||
- let repo_filenames: Vec<String> = vec![];
|
||||
- repo_filenames
|
||||
+ None => ["model.safetensors"]
|
||||
.iter()
|
||||
.map(|f| repo.get(f))
|
||||
- .collect::<std::result::Result<Vec<_>, _>>()?
|
||||
- }
|
||||
+ .collect::<std::result::Result<Vec<_>, _>>()?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
@@ -316,7 +316,7 @@ impl GPTBigCode {
|
||||
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
|
||||
- let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
|
||||
+ let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
|
||||
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
|
||||
Ok(Self {
|
||||
wte,
|
||||
|
Reference in New Issue
Block a user