Bugfixes. (#97)

This commit is contained in:
Laurent Mazare
2023-07-06 23:26:11 +01:00
committed by GitHub
parent a3f3b93d16
commit 2b8e8c9f14
2 changed files with 5 additions and 6 deletions

View File

@ -36,9 +36,8 @@ struct Args {
revision: String,
}
#[tokio::main]
async fn main() -> Result<()> {
use candle_hub::{api::Api, Repo, RepoType};
fn main() -> Result<()> {
use candle_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
let args = Args::parse();
@ -51,13 +50,13 @@ async fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(&repo, rfilename).await?;
let filename = api.get(&repo, rfilename)?;
filenames.push(filename);
}
println!("retrieved the files in {:?}", start.elapsed());

View File

@ -444,7 +444,7 @@ impl FalconAttention {
.reshape((b_sz, self.num_heads, q_len, head_dim))?
.transpose(1, 2)?
.reshape((b_sz, q_len, self.num_heads * head_dim))?;
let attn_output = self.attn_output.forward(&attn_output)?;
let attn_output = self.dense.forward(&attn_output)?;
Ok(attn_output)
}
}