mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Bugfixes. (#97)
This commit is contained in:
@ -36,9 +36,8 @@ struct Args {
|
||||
revision: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
use candle_hub::{api::Api, Repo, RepoType};
|
||||
fn main() -> Result<()> {
|
||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
@ -51,13 +50,13 @@ async fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(&repo, rfilename).await?;
|
||||
let filename = api.get(&repo, rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
@ -444,7 +444,7 @@ impl FalconAttention {
|
||||
.reshape((b_sz, self.num_heads, q_len, head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.num_heads * head_dim))?;
|
||||
let attn_output = self.attn_output.forward(&attn_output)?;
|
||||
let attn_output = self.dense.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user