mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores its embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside Candle for now, but would be willing to upstream it to candle-examples at a later point.
This commit is contained in:
@ -183,7 +183,7 @@ impl Module for T5LayerNorm {
|
|||||||
let xs_f32 = xs.to_dtype(DType::F32)?;
|
let xs_f32 = xs.to_dtype(DType::F32)?;
|
||||||
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||||
let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
|
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
|
||||||
let xs = xs.to_dtype(dtype)?;
|
let xs = xs.to_dtype(dtype)?;
|
||||||
let xs = xs.broadcast_mul(&self.weight)?;
|
let xs = xs.broadcast_mul(&self.weight)?;
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
@ -709,8 +709,10 @@ impl T5EncoderModel {
|
|||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let shared_vb = if vb.contains_tensor("shared.weight") {
|
let shared_vb = if vb.contains_tensor("shared.weight") {
|
||||||
vb.pp("shared")
|
vb.pp("shared")
|
||||||
} else {
|
} else if vb.contains_tensor("decoder.embed_tokens") {
|
||||||
vb.pp("decoder").pp("embed_tokens")
|
vb.pp("decoder").pp("embed_tokens")
|
||||||
|
} else {
|
||||||
|
vb.pp("encoder").pp("embed_tokens")
|
||||||
};
|
};
|
||||||
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
|
||||||
let shared = Arc::new(shared);
|
let shared = Arc::new(shared);
|
||||||
|
Reference in New Issue
Block a user