mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Bugfixes for marian-mt. (#1219)
* Bugfixes for marian-mt. * Apply the final decoding head. * More fixes.
This commit is contained in:
@ -36,8 +36,6 @@ struct Args {
|
||||
text: String,
|
||||
}
|
||||
|
||||
const SEP_TOKEN_ID: u32 = 102;
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
model.encoder().forward(&tokens, 0)?
|
||||
};
|
||||
|
||||
let mut token_ids = vec![30522u32];
|
||||
let mut token_ids = vec![config.decoder_start_token_id];
|
||||
for index in 0..1000 {
|
||||
// TODO: Add a kv cache.
|
||||
let context_size = if index >= 1000 { 1 } else { token_ids.len() };
|
||||
@ -72,7 +70,8 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
if token == SEP_TOKEN_ID {
|
||||
println!("{token}");
|
||||
if token == config.eos_token_id || token == config.forced_eos_token_id {
|
||||
break;
|
||||
}
|
||||
token_ids.push(token);
|
||||
|
Reference in New Issue
Block a user