mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Adding tons of profiling and removing the metal allocation (still slow).
This commit is contained in:
@ -232,7 +232,7 @@ fn main() -> anyhow::Result<()> {
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(false)?;
|
||||
let mut device = candle_examples::device(false)?;
|
||||
let temperature = if args.temperature == 0. {
|
||||
None
|
||||
} else {
|
||||
@ -384,17 +384,20 @@ fn main() -> anyhow::Result<()> {
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
if let candle::Device::Metal(device) = &mut device{
|
||||
device.flush()
|
||||
}
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
// let logits = if args.repeat_penalty == 1. {
|
||||
// logits
|
||||
// } else {
|
||||
// let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
// candle_transformers::utils::apply_repeat_penalty(
|
||||
// &logits,
|
||||
// args.repeat_penalty,
|
||||
// &all_tokens[start_at..],
|
||||
// )?
|
||||
// };
|
||||
// TODO Remove this once implementation is finished.
|
||||
let logits = logits.ones_like()?;
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
|
Reference in New Issue
Block a user