Marian MT model (#1210)

* Skeleton files for the marian MT model.

* Marian initialization.

* Implement the attention forward method.

* Forward pass for the encoder side.

* Expose the encoder and decoder.

* Start plugging the decoder.

* Forward pass for the decoder layer.

* Set up the marian example.

* Add some missing backtraces.

* Bugfix.
This commit is contained in:
Laurent Mazare
2023-10-29 16:12:22 +01:00
committed by GitHub
parent c3f2676d49
commit 7bbde55c61
5 changed files with 521 additions and 10 deletions

View File

@ -14,6 +14,13 @@ impl Embedding {
Ok(Self { inner, span })
}
pub fn from_weights(weights: Tensor) -> Result<Self> {
let (_in_size, out_size) = weights.dims2()?;
let inner = candle_nn::Embedding::new(weights, out_size);
let span = tracing::span!(tracing::Level::TRACE, "embedding");
Ok(Self { inner, span })
}
pub fn embeddings(&self) -> &Tensor {
self.inner.embeddings()
}