Avoid crashes when running T5 models with F16 tensors on CPU (#2047)

* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores its embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.

* Revert "This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores its embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point."

This reverts commit d886d3ce5e.
Victor-Mihaila
2024-04-13 11:07:28 +02:00
committed by GitHub
parent 79e3bec789
commit fb805b8ca2


@@ -183,7 +183,7 @@ impl Module for T5LayerNorm {
         let xs_f32 = xs.to_dtype(DType::F32)?;
         // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
-        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
         let xs = xs.broadcast_mul(&self.weight)?;
         Ok(xs)
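
The one-line change divides the F32 copy of the input (xs_f32) rather than the original-dtype input (xs) by the F32 variance term. The sketch below illustrates why that matters, assuming the crash comes from a binary operation receiving operands of different dtypes. The Tensor and DType types, div_scalar, rms, norm_before and norm_after are hypothetical stand-ins written for this illustration, not candle's API, and the final multiplication by the layer weight is left out.

```rust
// Toy stand-in for a tensor library; none of these types are candle's API.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F16,
    F32,
}

#[derive(Debug, Clone)]
struct Tensor {
    dtype: DType,
    // Values are kept as f32 for simplicity; only the dtype tag differs.
    data: Vec<f32>,
}

impl Tensor {
    // Return a copy of the tensor tagged with the requested dtype.
    fn to_dtype(&self, dtype: DType) -> Tensor {
        Tensor { dtype, data: self.data.clone() }
    }

    // Divide every element by a one-element tensor, refusing mixed dtypes.
    fn div_scalar(&self, rhs: &Tensor) -> Result<Tensor, String> {
        if self.dtype != rhs.dtype {
            return Err(format!("dtype mismatch: {:?} vs {:?}", self.dtype, rhs.dtype));
        }
        let d = rhs.data[0];
        Ok(Tensor { dtype: self.dtype, data: self.data.iter().map(|v| v / d).collect() })
    }
}

// sqrt(mean of squares + eps), always computed and tagged as F32.
fn rms(xs_f32: &Tensor, eps: f32) -> Tensor {
    let mean_sq = xs_f32.data.iter().map(|v| v * v).sum::<f32>() / xs_f32.data.len() as f32;
    Tensor { dtype: DType::F32, data: vec![(mean_sq + eps).sqrt()] }
}

// Before the change: the original-dtype input is divided by an F32 value.
fn norm_before(xs: &Tensor, eps: f32) -> Result<Tensor, String> {
    let xs_f32 = xs.to_dtype(DType::F32);
    let denom = rms(&xs_f32, eps);
    Ok(xs.div_scalar(&denom)?.to_dtype(xs.dtype))
}

// After the change: the F32 copy is divided, then cast back to the input dtype.
fn norm_after(xs: &Tensor, eps: f32) -> Result<Tensor, String> {
    let xs_f32 = xs.to_dtype(DType::F32);
    let denom = rms(&xs_f32, eps);
    Ok(xs_f32.div_scalar(&denom)?.to_dtype(xs.dtype))
}

fn main() {
    let xs = Tensor { dtype: DType::F16, data: vec![3.0, 4.0] };
    println!("before: {:?}", norm_before(&xs, 1e-6));
    println!("after:  {:?}", norm_after(&xs, 1e-6));
}
```

In this toy model, norm_before fails only when the input is not F32, which would be consistent with the crash appearing for F16 weights while F32 models kept working.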