mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Fix the revision used in starcoder to use the safetensors PR. (#269)
This commit is contained in:
@ -106,7 +106,7 @@ struct Args {
|
|||||||
#[arg(long, default_value = "bigcode/starcoderbase-1b")]
|
#[arg(long, default_value = "bigcode/starcoderbase-1b")]
|
||||||
model_id: String,
|
model_id: String,
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
#[arg(long, default_value = "refs/pr/1")]
|
||||||
revision: String,
|
revision: String,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -126,13 +126,10 @@ fn main() -> Result<()> {
|
|||||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||||
let filenames = match args.weight_file {
|
let filenames = match args.weight_file {
|
||||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||||
None => {
|
None => ["model.safetensors"]
|
||||||
let repo_filenames: Vec<String> = vec![];
|
.iter()
|
||||||
repo_filenames
|
.map(|f| repo.get(f))
|
||||||
.iter()
|
.collect::<std::result::Result<Vec<_>, _>>()?,
|
||||||
.map(|f| repo.get(f))
|
|
||||||
.collect::<std::result::Result<Vec<_>, _>>()?
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
@ -316,7 +316,7 @@ impl GPTBigCode {
|
|||||||
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
|
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
|
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
|
||||||
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
|
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
|
||||||
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
|
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
wte,
|
wte,
|
||||||
|
Reference in New Issue
Block a user