mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Use a tanh activation in the xlm-roberta classification head. (#2968)
This commit is contained in:
@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
|
||||
let hidden_states = self.dense.forward(&cls_states)?;
|
||||
let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
|
||||
let hidden_states = self.out_proj.forward(&hidden_states)?;
|
||||
// The activation used in the classification head is tanh, as per the original
|
||||
// implementation.
|
||||
// https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
|
||||
let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user