Add a prompt and support more models in llama2-c. (#285)

* Support more models in llama2-c.
* Add a prompt.
@@ -193,6 +193,15 @@ struct Args {
     #[arg(long, default_value = "karpathy/tinyllamas")]
     model_id: String,
 
+    /// The model to be used when getting it from the hub. Possible
+    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
+    /// https://huggingface.co/karpathy/tinyllamas/tree/main
+    #[arg(long, default_value = "stories15M.bin")]
+    which_model: String,
+
+    #[arg(long, default_value = "")]
+    prompt: String,
 }
 
 fn main() -> anyhow::Result<()> {
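For reference, the new flags go through clap's derive macros like the rest of Args. Below is a minimal, self-contained sketch of just the added fields (all other Args fields are elided, and the binary name in the comment is illustrative):

    use clap::Parser;

    /// Sketch of the new CLI surface; every other Args field is elided.
    #[derive(Parser, Debug)]
    struct Args {
        /// Weight file to fetch from the hub repo, e.g. stories15M.bin.
        #[arg(long, default_value = "stories15M.bin")]
        which_model: String,

        /// Prompt used to seed generation; empty by default.
        #[arg(long, default_value = "")]
        prompt: String,
    }

    fn main() {
        // e.g. invoked as: llama2-c --which-model stories42M.bin --prompt "Once upon a time"
        let args = Args::parse();
        println!("{} / {:?}", args.which_model, args.prompt);
    }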
@@ -206,7 +215,7 @@ fn main() -> anyhow::Result<()> {
             let api = hf_hub::api::sync::Api::new()?;
             println!("loading the model weights from {}", args.model_id);
             let api = api.model(args.model_id);
-            api.get("stories15M.bin")?
+            api.get(&args.which_model)?
         }
     };
     let mut file = std::fs::File::open(&config_path)?;
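In isolation, the hf-hub fetch pattern this hunk parameterizes looks roughly like the sketch below; the repo id and file name are just the defaults from Args, and fetch_weights is a hypothetical helper name:

    use hf_hub::api::sync::Api;
    use std::path::PathBuf;

    // Hypothetical helper: resolve a model repo on the Hugging Face hub and
    // download (or reuse a locally cached copy of) one weight file from it.
    fn fetch_weights(model_id: &str, which_model: &str) -> anyhow::Result<PathBuf> {
        let api = Api::new()?;
        let repo = api.model(model_id.to_string());
        Ok(repo.get(which_model)?)
    }

    fn main() -> anyhow::Result<()> {
        let path = fetch_weights("karpathy/tinyllamas", "stories15M.bin")?;
        println!("weights at {path:?}");
        Ok(())
    }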
@@ -226,15 +235,24 @@ fn main() -> anyhow::Result<()> {
         }
     };
     println!("{tokenizer_path:?}");
-    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
     let mut index_pos = 0;
-    let mut tokens = vec![1u32];
+    print!("{}", args.prompt);
+    let mut tokens = tokenizer
+        .encode(args.prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
 
     let start_gen = std::time::Instant::now();
-    for index in 0..config.seq_len - 10 {
+    for index in 0.. {
+        if tokens.len() >= config.seq_len {
+            break;
+        }
         let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
             1
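Instead of hard-coding the BOS token (vec![1u32]), the prompt is now encoded with the tokenizers crate, and the loop runs until the context window fills rather than for a fixed seq_len - 10 steps. A standalone sketch of the encoding step, mirroring the call chain above (encode_prompt is a hypothetical helper):

    use tokenizers::Tokenizer;

    // Hypothetical helper: turn a prompt string into token ids. Passing
    // `true` to encode() asks the tokenizer to add special tokens (BOS),
    // which subsumes the old hard-coded `vec![1u32]` starting sequence.
    fn encode_prompt(tokenizer: &Tokenizer, prompt: &str) -> anyhow::Result<Vec<u32>> {
        let ids = tokenizer
            .encode(prompt, true)
            .map_err(anyhow::Error::msg)?
            .get_ids()
            .to_vec();
        Ok(ids)
    }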
@@ -112,8 +112,10 @@ struct CausalSelfAttention {
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+        let cos = cos.unsqueeze(1)?;
+        let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
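Besides the unsqueeze(1) calls that add a broadcast dimension, this hunk swaps Tensor::narrow for the IndexOp range-indexing sugar; the two select the same slice. A small equivalence sketch, assuming the candle_core crate name from recent releases (shapes are illustrative):

    use candle_core::{Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::arange(0f32, 12., &Device::Cpu)?.reshape((6, 2))?;
        // narrow(dim, start, len) and .i(start..start + len) both take rows
        // [2, 5) along dimension 0 of `t`.
        let a = t.narrow(0, 2, 3)?;
        let b = t.i(2..5)?;
        assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
        Ok(())
    }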