Bugfixes for marian-mt. (#1219)

* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
This commit is contained in:
Laurent Mazare
2023-10-30 12:44:19 +01:00
committed by GitHub
parent 5fc66bd4ba
commit 969960847a
2 changed files with 21 additions and 13 deletions

View File

@@ -36,8 +36,6 @@ struct Args {
text: String,
}
const SEP_TOKEN_ID: u32 = 102;
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
@@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
model.encoder().forward(&tokens, 0)?
};
let mut token_ids = vec![30522u32];
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
// TODO: Add a kv cache.
let context_size = if index >= 1000 { 1 } else { token_ids.len() };
@@ -72,7 +70,8 @@ pub fn main() -> anyhow::Result<()> {
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
println!("{token}");
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
token_ids.push(token);