Mirror of https://github.com/huggingface/candle.git
Use a tanh activation in the xlm-roberta classification head. (#2968)
@@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
         let hidden_states = self.dense.forward(&cls_states)?;
-        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
-        let hidden_states = self.out_proj.forward(&hidden_states)?;
+        // The activation used in the classification head is tanh, as per the original
+        // implementation.
+        // https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
+        let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
         Ok(hidden_states)
     }
 }
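With this change the classification head computes dense -> tanh -> out_proj over the first ([CLS]) token, matching the reference transformers implementation linked in the diff comment. Below is a minimal sketch of the resulting method in context, for readers who want to see it outside the hunk. The struct fields and their Linear types are assumed from the diff (the real constructor is not shown here), and the candle_core/candle_nn import paths reflect typical external usage rather than the crate-internal aliases candle itself uses.

use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module};

// Sketch only: field types are assumed to be candle_nn::Linear; the actual
// struct definition in candle may differ.
struct XLMRobertaClassificationHead {
    dense: Linear,
    out_proj: Linear,
}

impl XLMRobertaClassificationHead {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Select the hidden state of the first token ([CLS]) along the sequence dim.
        let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
        // dense -> tanh -> out_proj: the activation is now a plain tanh
        // instead of GeluPytorchTanh, as in the original Python implementation.
        let hidden_states = self.dense.forward(&cls_states)?;
        let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
        Ok(hidden_states)
    }
}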