Add a cuda kernel for upsampling. (#441)

* Add a cuda kernel for upsampling.

* Update for the latest tokenizers version.
Author: Laurent Mazare
Date: 2023-08-14 13:12:17 +01:00
Committed by: GitHub
Parent: a094dc503d
Commit: c84883ecf2
10 changed files with 119 additions and 26 deletions
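For context on the tokenizers update: every hunk shown below adapts to two API changes in the newer tokenizers crate, with_truncation now returns a Result, and decode takes a token slice rather than an owned Vec. A minimal sketch of the updated call pattern follows; it assumes an anyhow-based error alias E as in the examples, and the tokenizer.json path and the roundtrip helper are illustrative only, not part of this commit.

use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

// Illustrative helper, not from this commit; the tokenizer.json path is a placeholder.
fn roundtrip(prompt: &str) -> Result<String> {
    let mut tokenizer = Tokenizer::from_file("tokenizer.json").map_err(E::msg)?;
    // with_truncation now returns a Result, hence the extra map_err + ?.
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;
    let tokens = tokenizer.encode(prompt, true).map_err(E::msg)?;
    // decode now takes a slice of ids (&[u32]) rather than an owned Vec<u32>.
    let text = tokenizer.decode(tokens.get_ids(), true).map_err(E::msg)?;
    Ok(text)
}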


@@ -111,7 +111,10 @@ fn main() -> Result<()> {
     let device = &model.device;
     if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokenizer = tokenizer
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(E::msg)?;
         let tokens = tokenizer
             .encode(prompt, true)
             .map_err(E::msg)?


@@ -65,10 +65,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            let token = self
-                .tokenizer
-                .decode(vec![next_token], true)
-                .map_err(E::msg)?;
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
             print!("{token}");
             std::io::stdout().flush()?;
         }


@@ -72,16 +72,14 @@ impl TextGeneration {
                 "{} token: {} '{}'",
                 index + 1,
                 next_token,
-                self.tokenizer
-                    .decode(vec![next_token], true)
-                    .map_err(E::msg)?
+                self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
             );
         }
         let dt = start_gen.elapsed();
         println!(
             "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
             sample_len as f64 / dt.as_secs_f64(),
-            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+            self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
         );
         Ok(())
     }


@@ -223,7 +223,7 @@ fn main() -> Result<()> {
             "{} token: {} '{}'",
             index + 1,
             next_token,
-            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+            tokenizer.decode(&[next_token], true).map_err(E::msg)?
         );
     }
     let dt = start_gen.elapsed();
@@ -231,7 +231,7 @@
         "{} tokens generated ({} token/s)\n----\n{}\n----",
         args.sample_len,
         args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        tokenizer.decode(&new_tokens, true).map_err(E::msg)?
     );
     Ok(())
 }


@@ -169,10 +169,7 @@ impl Decoder {
             }
             sum_logprob += prob.ln();
         }
-        let text = self
-            .tokenizer
-            .decode(tokens.clone(), true)
-            .map_err(E::msg)?;
+        let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
         let avg_logprob = sum_logprob / tokens.len() as f64;
         Ok(DecodingResult {