Adding cast + binary kernels.

This commit is contained in:
Nicolas Patry
2023-11-07 23:45:53 +01:00
parent 0c24a885a6
commit 480a3e22e6
7 changed files with 601 additions and 84 deletions

View File

@ -239,15 +239,13 @@ fn main() -> anyhow::Result<()> {
Some(args.temperature)
};
tracing_subscriber::fmt::init();
// let _guard = if args.tracing {
// // let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
// // tracing_subscriber::registry().with(chrome_layer).init();
// tracing_subscriber::fmt::init();
// None
// // Some(guard)
// } else {
// None
// };
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
@ -375,7 +373,8 @@ fn main() -> anyhow::Result<()> {
let logits = logits.squeeze(0)?;
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
logits_processor.sample(&logits)?
// logits_processor.sample(&logits)?
15043
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
@ -399,8 +398,9 @@ fn main() -> anyhow::Result<()> {
)?
};
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
next_token = logits_processor.sample(&logits)?;
// let logits = logits.ones_like()?;
// next_token = logits_processor.sample(&logits)?;
let next_token = 15043;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if next_token == eos_token {