mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Infer the config for llama2-c. (#1208)
This commit is contained in:
@ -17,7 +17,20 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn tiny() -> Self {
|
||||
pub fn tiny_260k() -> Self {
|
||||
Self {
|
||||
dim: 64,
|
||||
hidden_dim: 768,
|
||||
n_layers: 5,
|
||||
n_heads: 8,
|
||||
n_kv_heads: 4,
|
||||
vocab_size: 32000,
|
||||
seq_len: 512,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny_15m() -> Self {
|
||||
Self {
|
||||
dim: 288,
|
||||
hidden_dim: 768,
|
||||
@ -29,6 +42,32 @@ impl Config {
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny_42m() -> Self {
|
||||
Self {
|
||||
dim: 512,
|
||||
hidden_dim: 768,
|
||||
n_layers: 8,
|
||||
n_heads: 8,
|
||||
n_kv_heads: 8,
|
||||
vocab_size: 32000,
|
||||
seq_len: 1024,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tiny_110m() -> Self {
|
||||
Self {
|
||||
dim: 768,
|
||||
hidden_dim: 768,
|
||||
n_layers: 12,
|
||||
n_heads: 12,
|
||||
n_kv_heads: 12,
|
||||
vocab_size: 32000,
|
||||
seq_len: 1024,
|
||||
norm_eps: 1e-5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -77,6 +77,16 @@ impl VarBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
|
||||
let path = self.path(name);
|
||||
match self.data.get(&path) {
|
||||
None => {
|
||||
candle::bail!("cannot find tensor {name}")
|
||||
}
|
||||
Some(qtensor) => Ok(qtensor.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
|
Reference in New Issue
Block a user