From fb805b8ca2c9413ad9227800328145434a08eaca Mon Sep 17 00:00:00 2001 From: Victor-Mihaila <115141117+Victor-Mihaila@users.noreply.github.com> Date: Sat, 13 Apr 2024 11:07:28 +0200 Subject: [PATCH] Avoid crashes when running T5 models with F16 tensors on CPU (#2047) * This change avoids crashes when running T5 models with F16 tensors on CPU. * This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point. * Revert "This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point." This reverts commit d886d3ce5e3f1504934f4f6f7cf86108b7efd191. --- candle-transformers/src/models/t5.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 4544d828..8a7a8955 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -183,7 +183,7 @@ impl Module for T5LayerNorm { let xs_f32 = xs.to_dtype(DType::F32)?; // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; - let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; + let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; let xs = xs.to_dtype(dtype)?; let xs = xs.broadcast_mul(&self.weight)?; Ok(xs)