mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Co-authored-by: Yi Xu <xuyi@me.com>
This commit is contained in:
@ -361,7 +361,7 @@ pub struct ModelForCausalLM {
|
|||||||
impl ModelForCausalLM {
|
impl ModelForCausalLM {
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let base_model = Model::new(cfg, vb.clone())?;
|
let base_model = Model::new(cfg, vb.clone())?;
|
||||||
let lm_head = if vb.contains_tensor("lm_head") {
|
let lm_head = if vb.contains_tensor("lm_head.weight") {
|
||||||
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
} else {
|
} else {
|
||||||
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
|
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
|
||||||
|
Reference in New Issue
Block a user