Infer the config for llama2-c. (#1208)

This commit is contained in:
Laurent Mazare
2023-10-28 20:00:39 +02:00
committed by GitHub
parent 95a857cf57
commit 012ae0090e
4 changed files with 63 additions and 4 deletions

View File

@@ -17,7 +17,20 @@ pub struct Config {
}
impl Config {
pub fn tiny() -> Self {
/// Returns the preset `Config` for the variant named "260k".
///
/// Values: `dim` 64, `hidden_dim` 768, 5 layers, 8 heads with 4 kv heads,
/// `vocab_size` 32000, `seq_len` 512 and `norm_eps` 1e-5.
pub fn tiny_260k() -> Self {
    // Bind the width and head settings first; the literal below uses shorthand.
    let (dim, hidden_dim) = (64, 768);
    let (n_heads, n_kv_heads) = (8, 4);
    Self {
        vocab_size: 32000,
        seq_len: 512,
        n_layers: 5,
        n_heads,
        n_kv_heads,
        dim,
        hidden_dim,
        norm_eps: 1e-5,
    }
}
pub fn tiny_15m() -> Self {
Self {
dim: 288,
hidden_dim: 768,
@@ -29,6 +42,32 @@ impl Config {
norm_eps: 1e-5,
}
}
/// Returns the preset `Config` for the variant named "42m".
///
/// Values: `dim` 512, `hidden_dim` 768, 8 layers, 8 heads with 8 kv heads,
/// `vocab_size` 32000, `seq_len` 1024 and `norm_eps` 1e-5.
pub fn tiny_42m() -> Self {
    // Bind the width and head settings first; the literal below uses shorthand.
    let (dim, hidden_dim) = (512, 768);
    let (n_heads, n_kv_heads) = (8, 8);
    Self {
        vocab_size: 32000,
        seq_len: 1024,
        n_layers: 8,
        n_heads,
        n_kv_heads,
        dim,
        hidden_dim,
        norm_eps: 1e-5,
    }
}
/// Returns the preset `Config` for the variant named "110m".
///
/// Values: `dim` 768, `hidden_dim` 768, 12 layers, 12 heads with 12 kv heads,
/// `vocab_size` 32000, `seq_len` 1024 and `norm_eps` 1e-5.
pub fn tiny_110m() -> Self {
    // Bind the width and head settings first; the literal below uses shorthand.
    let (dim, hidden_dim) = (768, 768);
    let (n_heads, n_kv_heads) = (12, 12);
    Self {
        vocab_size: 32000,
        seq_len: 1024,
        n_layers: 12,
        n_heads,
        n_kv_heads,
        dim,
        hidden_dim,
        norm_eps: 1e-5,
    }
}
}
#[derive(Clone)]

View File

@@ -77,6 +77,16 @@ impl VarBuilder {
}
}
/// Fetches the tensor registered under the path computed by `self.path(name)`.
///
/// No shape check is performed here: the stored handle is cloned and
/// returned as-is.
///
/// # Errors
///
/// Returns an error with the message `cannot find tensor {name}` when no
/// entry exists for the computed path.
pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
    let full_path = self.path(name);
    if let Some(found) = self.data.get(&full_path) {
        return Ok(found.clone());
    }
    // Nothing stored under that path: report the name the caller asked for.
    candle::bail!("cannot find tensor {name}")
}
pub fn device(&self) -> &Device {
&self.device
}