Add the jina-bert embeddings model. (#1187)

* Add the jina-bert model.

* Use alibi (per-head linear attention biases; see the sketch after this list).

* Remove the unused pragma.

* Recompute the alibi embeddings.

* Generate the token type ids.

* Use the module trait.

* Add the jina-bert example.

* DType fix.

* Get the inference to work.
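The two alibi bullets above refer to "attention with linear biases": instead of learned position embeddings, each attention head adds a penalty to its scores that grows linearly with the query/key distance. Below is a minimal sketch of the slope and bias computation in plain Rust, assuming a power-of-two head count; the helper names alibi_slopes and alibi_bias are illustrative and not taken from the candle code.

// Per-head ALiBi slopes: with n heads (assumed to be a power of two),
// head i uses slope 2^(-8 * (i + 1) / n).
fn alibi_slopes(num_heads: usize) -> Vec<f32> {
    let n = num_heads as f32;
    (0..num_heads)
        .map(|i| 2f32.powf(-8.0 * (i as f32 + 1.0) / n))
        .collect()
}

// Bias added to the raw attention scores: bias[h][q][k] = -|q - k| * slope[h],
// so more distant tokens are penalised more strongly, at a per-head rate.
// For a bidirectional encoder the bias is symmetric rather than causal.
fn alibi_bias(num_heads: usize, seq_len: usize) -> Vec<Vec<Vec<f32>>> {
    alibi_slopes(num_heads)
        .iter()
        .map(|&slope| {
            (0..seq_len)
                .map(|q| {
                    (0..seq_len)
                        .map(|k| -((q as isize - k as isize).abs() as f32) * slope)
                        .collect()
                })
                .collect()
        })
        .collect()
}

Since the bias depends only on the head count and the sequence length, rebuilding it for a new sequence length is presumably what the "Recompute the alibi embeddings" step refers to.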
Author: Laurent Mazare
Date:   2023-10-26 16:54:36 +01:00
Committed by: GitHub
Parent: e37b487767
Commit: 5f20697918
4 changed files with 550 additions and 2 deletions


@@ -44,8 +44,10 @@ impl Linear {
         let span = tracing::span!(tracing::Level::TRACE, "linear");
         Self { weight, bias, span }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
@@ -77,8 +79,10 @@ impl LayerNorm {
             span,
         }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -195,7 +199,9 @@ impl Dropout {
     fn new(pr: f64) -> Self {
         Self { pr }
     }
+}
 
+impl Module for Dropout {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         // TODO
         Ok(x.clone())
@@ -316,7 +322,9 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
+}
 
+impl Module for BertSelfAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
@@ -391,7 +399,9 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
+}
 
+impl Module for BertAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let self_outputs = self.self_attention.forward(hidden_states)?;
@@ -416,7 +426,9 @@ impl BertIntermediate {
             span: tracing::span!(tracing::Level::TRACE, "inter"),
         })
     }
+}
 
+impl Module for BertIntermediate {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let hidden_states = self.dense.forward(hidden_states)?;
@@ -478,7 +490,9 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
+}
 
+impl Module for BertLayer {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let attention_output = self.attention.forward(hidden_states)?;
@@ -507,7 +521,9 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
+}
 
+impl Module for BertEncoder {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
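For context, every hunk above makes the same change: an inherent forward method becomes an implementation of candle's Module trait, whose single method maps a &Tensor to a Result<Tensor>. Here is a minimal sketch of the pattern with a toy Scale layer; the layer itself is illustrative, and the import path is an assumption (depending on the candle version, Module is exported from the core crate or re-exported by candle_nn).

use candle::{Module, Result, Tensor};

// A toy layer that just rescales its input.
struct Scale {
    factor: f64,
}

// Implementing Module instead of an inherent `forward` lets the layer be
// passed to any code that is generic over modules, which is why the diff
// does this for Linear, LayerNorm, Dropout, the attention blocks, and the
// encoder.
impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.factor, 0.0)
    }
}

With the trait in place, all of the BERT sub-layers expose the same forward signature and can be driven uniformly, as the encoder loop in the last hunk does.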