From 79e3bec789bd7bdb0e331d628c8fe9d06c519f33 Mon Sep 17 00:00:00 2001 From: Victor-Mihaila <115141117+Victor-Mihaila@users.noreply.github.com> Date: Sat, 13 Apr 2024 11:06:24 +0200 Subject: [PATCH] Change for the encoder-only ProstT5 model (#2045) * This change avoids crashes when running T5 models with F16 tensors on CPU. * This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores its embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point. --- candle-transformers/src/models/t5.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index f4b5b4b0..4544d828 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -709,8 +709,10 @@ impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result { let shared_vb = if vb.contains_tensor("shared.weight") { vb.pp("shared") - } else { + } else if vb.contains_tensor("decoder.embed_tokens") { vb.pp("decoder").pp("embed_tokens") + } else { + vb.pp("encoder").pp("embed_tokens") }; let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?; let shared = Arc::new(shared);