Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00
Add a cuda kernel for upsampling. (#441)

* Add a cuda kernel for upsampling.
* Update for the latest tokenizers version.
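Every tokenizers-related hunk in this diff follows the same pattern: recent releases of the tokenizers crate changed `Tokenizer::decode` to borrow its ids as a `&[u32]` slice instead of taking an owned `Vec<u32>`, so each `decode(vec![...], true)` call site becomes `decode(&[...], true)`. A minimal sketch of the new call shape, assuming a tokenizers version with the slice-taking signature (the `decode_one` helper is hypothetical):

    use anyhow::{Error as E, Result};
    use tokenizers::Tokenizer;

    // Decode a single token id by passing a borrowed one-element slice;
    // no Vec allocation is needed under the newer API.
    fn decode_one(tokenizer: &Tokenizer, next_token: u32) -> Result<String> {
        tokenizer.decode(&[next_token], true).map_err(E::msg)
    }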
@@ -111,7 +111,10 @@ fn main() -> Result<()> {
     let device = &model.device;
 
     if let Some(prompt) = args.prompt {
-        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokenizer = tokenizer
+            .with_padding(None)
+            .with_truncation(None)
+            .map_err(E::msg)?;
         let tokens = tokenizer
             .encode(prompt, true)
             .map_err(E::msg)?
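The extra `.map_err(E::msg)?` in the hunk above exists because `with_truncation` is fallible in the newer tokenizers API: it validates the truncation parameters and returns a `Result<&mut Tokenizer>` rather than `&mut Tokenizer` directly. A small sketch of the fallible builder chain, under that assumption (the `configure` helper is hypothetical):

    use anyhow::{Error as E, Result};
    use tokenizers::Tokenizer;

    // with_padding still returns &mut Tokenizer, but with_truncation can now
    // fail, so the chain must be terminated with error handling.
    fn configure(tokenizer: &mut Tokenizer) -> Result<()> {
        tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        Ok(())
    }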
@@ -65,10 +65,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            let token = self
-                .tokenizer
-                .decode(vec![next_token], true)
-                .map_err(E::msg)?;
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
             print!("{token}");
             std::io::stdout().flush()?;
         }
@@ -72,16 +72,14 @@ impl TextGeneration {
                 "{} token: {} '{}'",
                 index + 1,
                 next_token,
-                self.tokenizer
-                    .decode(vec![next_token], true)
-                    .map_err(E::msg)?
+                self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
             );
         }
         let dt = start_gen.elapsed();
         println!(
             "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
             sample_len as f64 / dt.as_secs_f64(),
-            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+            self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
         );
         Ok(())
     }
@@ -223,7 +223,7 @@ fn main() -> Result<()> {
             "{} token: {} '{}'",
             index + 1,
             next_token,
-            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+            tokenizer.decode(&[next_token], true).map_err(E::msg)?
         );
     }
     let dt = start_gen.elapsed();
@@ -231,7 +231,7 @@ fn main() -> Result<()> {
         "{} tokens generated ({} token/s)\n----\n{}\n----",
         args.sample_len,
         args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        tokenizer.decode(&new_tokens, true).map_err(E::msg)?
     );
     Ok(())
 }
@@ -169,10 +169,7 @@ impl Decoder {
             }
             sum_logprob += prob.ln();
         }
-        let text = self
-            .tokenizer
-            .decode(tokens.clone(), true)
-            .map_err(E::msg)?;
+        let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
         let avg_logprob = sum_logprob / tokens.len() as f64;
 
         Ok(DecodingResult {
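Note that in the `Decoder` hunk the `tokens.clone()` disappears entirely: because `decode` now borrows the ids as a slice, the caller no longer needs to copy its whole token buffer just to render the text, which saves an allocation per decoded segment.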