Adding tons of profiling and removing the metal allocation (still slow).

This commit is contained in:
Nicolas Patry
2023-11-02 17:48:07 +01:00
parent 7161002a34
commit 9a27f11c3f
4 changed files with 100 additions and 61 deletions

View File

@@ -232,7 +232,7 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
let device = candle_examples::device(false)?;
let mut device = candle_examples::device(false)?;
let temperature = if args.temperature == 0. {
None
} else {
@@ -384,17 +384,20 @@ fn main() -> anyhow::Result<()> {
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
if let candle::Device::Metal(device) = &mut device{
device.flush()
}
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
// let logits = if args.repeat_penalty == 1. {
// logits
// } else {
// let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
// candle_transformers::utils::apply_repeat_penalty(
// &logits,
// args.repeat_penalty,
// &all_tokens[start_at..],
// )?
// };
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
next_token = logits_processor.sample(&logits)?;