mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Sketch the Falcon model. (#93)
* Sketch the Falcon model. * Add more substance to the falcon example. * Falcon (wip). * Falcon (wip again). * Falcon inference. * Get the weights from the api and properly generate the model. * Use the proper model. * Fix the file/revision names. * Fix bias handling. * Recompute the rot embeddings. * Fix the input shape. * Add the release-with-debug profile. * Silly bugfix. * More bugfixes. * Stricter shape checking in matmul.
This commit is contained in:
88
candle-examples/examples/falcon/main.rs
Normal file
88
candle-examples/examples/falcon/main.rs
Normal file
@ -0,0 +1,88 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use clap::Parser;
|
||||
|
||||
mod model;
|
||||
use model::{Config, Falcon, VarBuilder};
|
||||
|
||||
const DTYPE: DType = DType::F16;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
#[arg(long, default_value = "tiiuae/falcon-7b")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "refs/pr/43")]
|
||||
revision: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
use candle_hub::{api::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(&repo, rfilename).await?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &device);
|
||||
let config = Config::falcon7b();
|
||||
config.validate()?;
|
||||
let mut model = Falcon::load(&vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&tokens)?;
|
||||
println!("{}", logits);
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user