Use a tanh activation in the xlm-roberta classification head. (#2968)

Author: Laurent Mazare
Date: 2025-05-26 08:54:31 +02:00
Committed by: GitHub
Parent: 9a62c91643
Commit: 61ddb9535e

@@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
         let hidden_states = self.dense.forward(&cls_states)?;
-        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
-        let hidden_states = self.out_proj.forward(&hidden_states)?;
+        // The activation used in the classification head is tanh, as per the original
+        // implementation.
+        // https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
+        let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
         Ok(hidden_states)
     }
 }
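
For context on why this matters: tanh and gelu_pytorch_tanh are different functions, so the head produced different logits before this fix. The standalone sketch below is not part of the commit; the candle_core/candle_nn imports, the main function, and the sample values are illustrative assumptions, while Tensor::tanh and Activation::GeluPytorchTanh are the same calls that appear in the diff. It prints both activations for the same inputs:

use candle_core::{Device, Result, Tensor};
use candle_nn::Module;

fn main() -> Result<()> {
    // Sample pre-activation values.
    let x = Tensor::new(&[-2.0f32, -0.5, 0.0, 0.5, 2.0], &Device::Cpu)?;
    // tanh: what the upstream transformers classification head uses.
    let tanh = x.tanh()?;
    // gelu_pytorch_tanh: what the candle head applied before this commit.
    let gelu = candle_nn::Activation::GeluPytorchTanh.forward(&x)?;
    println!("tanh:              {:?}", tanh.to_vec1::<f32>()?);
    println!("gelu_pytorch_tanh: {:?}", gelu.to_vec1::<f32>()?);
    Ok(())
}

The two already disagree at moderate magnitudes (e.g. tanh(-2) is about -0.96 while gelu(-2) is about -0.05), which is why classifier checkpoints fine-tuned against the reference transformers implementation only reproduce their expected scores once the head uses tanh.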