Mirror of https://github.com/huggingface/candle.git
Add the jina-bert embeddings model. (#1187)
* Add the jina-bert model.
* Use alibi.
* Remove the unused pragma.
* Recompute the alibi embeddings.
* Generate the token type ids.
* Use the module trait.
* Add the jina-bert example.
* DType fix.
* Get the inference to work.
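Two of these bullets ("Use alibi", "Recompute the alibi embeddings") refer to ALiBi (Attention with Linear Biases), which jina-bert uses in place of learned position embeddings: each attention head adds a fixed, linearly growing penalty on distant positions. As a minimal, self-contained sketch of the usual slope/bias arithmetic (hypothetical helpers in plain Rust, not the commit's code; assumes the head count is a power of two, as in the ALiBi paper):

// Per-head ALiBi slope: head i (0-based) gets 2^(-8 * (i + 1) / n_heads),
// a geometric sequence from 2^(-8 / n_heads) down to 2^-8.
fn alibi_slopes(n_heads: usize) -> Vec<f32> {
    (0..n_heads)
        .map(|i| 2f32.powf(-8.0 * (i + 1) as f32 / n_heads as f32))
        .collect()
}

// Bias added to the raw attention score between query position q and key
// position k; bidirectional encoders like BERT typically penalize the
// absolute distance.
fn alibi_bias(slope: f32, q: usize, k: usize) -> f32 {
    -slope * (q as f32 - k as f32).abs()
}

fn main() {
    // For 8 heads the slopes are 1/2, 1/4, ..., 1/256.
    println!("{:?}", alibi_slopes(8));
    println!("{}", alibi_bias(0.25, 5, 2)); // -0.75
}

Because these biases depend only on positions, they can be precomputed once per maximum sequence length, which is presumably what the "Recompute the alibi embeddings" bullet is adjusting.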
@@ -44,8 +44,10 @@ impl Linear {
         let span = tracing::span!(tracing::Level::TRACE, "linear");
         Self { weight, bias, span }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let _enter = self.span.enter();
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
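The change in this hunk repeats throughout the file: each layer's inherent forward method becomes an implementation of candle's Module trait (the "Use the module trait" bullet above), so every layer exposes the same fn forward(&self, x: &Tensor) -> Result<Tensor> interface and can be composed generically. A minimal sketch of why that matters, using a local stand-in trait over plain floats rather than candle's real Tensor type:

// Stand-in for candle's Module trait; the real signatures are visible in
// the diff above.
trait Module {
    fn forward(&self, x: &f32) -> Result<f32, String>;
}

struct Scale(f32);

impl Module for Scale {
    fn forward(&self, x: &f32) -> Result<f32, String> {
        Ok(x * self.0)
    }
}

// Once every layer implements the same trait, a heterogeneous stack can be
// stored behind trait objects and applied in sequence.
fn run(layers: &[Box<dyn Module>], x: f32) -> Result<f32, String> {
    layers.iter().try_fold(x, |acc, layer| layer.forward(&acc))
}

fn main() {
    let layers: Vec<Box<dyn Module>> = vec![Box::new(Scale(2.0)), Box::new(Scale(3.0))];
    println!("{:?}", run(&layers, 1.0)); // Ok(6.0)
}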
@@ -77,8 +79,10 @@ impl LayerNorm {
             span,
         }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -195,7 +199,9 @@ impl Dropout {
     fn new(pr: f64) -> Self {
         Self { pr }
     }
+}
 
+impl Module for Dropout {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         // TODO
         Ok(x.clone())
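The // TODO in this Dropout hunk is harmless for the example's purpose: at inference time dropout is the identity, which is exactly what Ok(x.clone()) returns. For reference, a hedged sketch of what a train-mode (inverted) dropout would add, in dependency-free Rust with the RNG left abstract:

// Inverted dropout: zero each element with probability `pr`, scale the
// survivors by 1/(1 - pr) so the expected activation is unchanged.
// `rand_unit` stands in for a real RNG (e.g. the rand crate).
fn dropout(xs: &[f32], pr: f64, train: bool, mut rand_unit: impl FnMut() -> f64) -> Vec<f32> {
    if !train || pr == 0.0 {
        return xs.to_vec(); // eval mode: identity, as in the diff above
    }
    let scale = (1.0 / (1.0 - pr)) as f32;
    xs.iter()
        .map(|&x| if rand_unit() < pr { 0.0 } else { x * scale })
        .collect()
}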
@@ -316,7 +322,9 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
+}
 
+impl Module for BertSelfAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
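The reshape(...).transpose(1, 2) at the top of this hunk is the standard multi-head split: a (batch, seq, hidden) activation becomes (batch, heads, seq, head_dim), after which contiguous() materializes the permuted layout. A shape-only sketch of the bookkeeping (hypothetical helper, not the commit's code):

// (batch, seq, hidden) --reshape--> (batch, seq, heads, head_dim)
//                      --transpose(1, 2)--> (batch, heads, seq, head_dim)
fn split_heads_shape(dims: [usize; 3], n_heads: usize) -> [usize; 4] {
    let [batch, seq, hidden] = dims;
    assert_eq!(hidden % n_heads, 0, "hidden size must divide evenly by the head count");
    [batch, n_heads, seq, hidden / n_heads]
}

fn main() {
    // BERT-base-like sizes: hidden 768 over 12 heads gives head_dim 64.
    assert_eq!(split_heads_shape([2, 128, 768], 12), [2, 12, 128, 64]);
}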
@@ -391,7 +399,9 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
+}
 
+impl Module for BertAttention {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let self_outputs = self.self_attention.forward(hidden_states)?;
@@ -416,7 +426,9 @@ impl BertIntermediate {
             span: tracing::span!(tracing::Level::TRACE, "inter"),
         })
     }
+}
 
+impl Module for BertIntermediate {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let hidden_states = self.dense.forward(hidden_states)?;
@@ -478,7 +490,9 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
+}
 
+impl Module for BertLayer {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let attention_output = self.attention.forward(hidden_states)?;
@@ -507,7 +521,9 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
+}
 
+impl Module for BertEncoder {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
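The encoder's forward (truncated by the hunk) threads hidden_states through each layer in turn, which is why it starts by cloning the input into a mutable binding. A minimal sketch of that loop, again with a local stand-in for the Module trait:

trait Module {
    fn forward(&self, x: &[f32]) -> Result<Vec<f32>, String>;
}

// Mirrors the shape of BertEncoder::forward: each layer consumes the
// previous layer's output.
fn encode(layers: &[Box<dyn Module>], input: &[f32]) -> Result<Vec<f32>, String> {
    let mut hidden_states = input.to_vec();
    for layer in layers {
        hidden_states = layer.forward(&hidden_states)?;
    }
    Ok(hidden_states)
}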