Support for the new Qwen2 models. (#2257)

* Support for the new Qwen2 models.

* Add more models.
This commit is contained in:
Laurent Mazare
2024-06-07 10:51:50 +01:00
committed by GitHub
parent b9fac7ec00
commit 54ff971e35
2 changed files with 32 additions and 12 deletions

View File

@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let base_model = Model::new(cfg, vb)?;
let base_model = Model::new(cfg, vb.clone())?;
let lm_head = if vb.contains_tensor("lm_head") {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
} else {
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
};
Ok(Self {
base_model,
lm_head,