Improve the reshape error messages. (#1096)

* Improve the reshape error messages.

* Add the verbose-prompt flag to the phi example.
This commit is contained in:
Laurent Mazare
2023-10-15 10:43:10 +01:00
committed by GitHub
parent 8f310cc666
commit b73c35cc57
2 changed files with 49 additions and 75 deletions

View File

@ -511,154 +511,119 @@ impl ShapeWithOneHole for ((),) {
} }
} }
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
if prod_d == 0 {
crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
}
if el_count % prod_d != 0 {
crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
}
Ok(el_count / prod_d)
}
impl ShapeWithOneHole for ((), usize) { impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self; let ((), d1) = self;
if el_count % d1 != 0 { Ok((hole_size(el_count, d1, &self)?, d1).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
} }
} }
impl ShapeWithOneHole for (usize, ()) { impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self; let (d1, ()) = self;
if el_count % d1 != 0 { Ok((d1, hole_size(el_count, d1, &self)?).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
} }
} }
impl ShapeWithOneHole for ((), usize, usize) { impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self; let ((), d1, d2) = self;
let d = d1 * d2; Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
} }
} }
impl ShapeWithOneHole for (usize, (), usize) { impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self; let (d1, (), d2) = self;
let d = d1 * d2; Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, ()) { impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self; let (d1, d2, ()) = self;
let d = d1 * d2; Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
} }
} }
impl ShapeWithOneHole for ((), usize, usize, usize) { impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self; let ((), d1, d2, d3) = self;
let d = d1 * d2 * d3; let d = hole_size(el_count, d1 * d2 * d3, &self)?;
if el_count % d != 0 { Ok((d, d1, d2, d3).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
} }
} }
impl ShapeWithOneHole for (usize, (), usize, usize) { impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self; let (d1, (), d2, d3) = self;
let d = d1 * d2 * d3; let d = hole_size(el_count, d1 * d2 * d3, &self)?;
if el_count % d != 0 { Ok((d1, d, d2, d3).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, (), usize) { impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self; let (d1, d2, (), d3) = self;
let d = d1 * d2 * d3; let d = hole_size(el_count, d1 * d2 * d3, &self)?;
if el_count % d != 0 { Ok((d1, d2, d, d3).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, usize, ()) { impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self; let (d1, d2, d3, ()) = self;
let d = d1 * d2 * d3; let d = hole_size(el_count, d1 * d2 * d3, &self)?;
if el_count % d != 0 { Ok((d1, d2, d3, d).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
} }
} }
impl ShapeWithOneHole for ((), usize, usize, usize, usize) { impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self; let ((), d1, d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4; let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
if el_count % d != 0 { Ok((d, d1, d2, d3, d4).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, (), usize, usize, usize) { impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self; let (d1, (), d2, d3, d4) = self;
let d = d1 * d2 * d3 * d4; let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
if el_count % d != 0 { Ok((d1, d, d2, d3, d4).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, (), usize, usize) { impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self; let (d1, d2, (), d3, d4) = self;
let d = d1 * d2 * d3 * d4; let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
if el_count % d != 0 { Ok((d1, d2, d, d3, d4).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, usize, (), usize) { impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self; let (d1, d2, d3, (), d4) = self;
let d = d1 * d2 * d3 * d4; let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
if el_count % d != 0 { Ok((d1, d2, d3, d, d4).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
} }
} }
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> { fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self; let (d1, d2, d3, d4, ()) = self;
let d = d1 * d2 * d3 * d4; let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
if el_count % d != 0 { Ok((d1, d2, d3, d4, d).into())
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
} }
} }

View File

@ -28,6 +28,7 @@ struct TextGeneration {
logits_processor: LogitsProcessor, logits_processor: LogitsProcessor,
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
verbose_prompt: bool,
} }
impl TextGeneration { impl TextGeneration {
@ -40,6 +41,7 @@ impl TextGeneration {
top_p: Option<f64>, top_p: Option<f64>,
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
verbose_prompt: bool,
device: &Device, device: &Device,
) -> Self { ) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p); let logits_processor = LogitsProcessor::new(seed, temp, top_p);
@ -49,6 +51,7 @@ impl TextGeneration {
logits_processor, logits_processor,
repeat_penalty, repeat_penalty,
repeat_last_n, repeat_last_n,
verbose_prompt,
device: device.clone(), device: device.clone(),
} }
} }
@ -58,13 +61,14 @@ impl TextGeneration {
println!("starting the inference loop"); println!("starting the inference loop");
print!("{prompt}"); print!("{prompt}");
std::io::stdout().flush()?; std::io::stdout().flush()?;
let mut tokens = self let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
.tokenizer if self.verbose_prompt {
.encode(prompt, true) for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
.map_err(E::msg)? let token = token.replace('▁', " ").replace("<0x0A>", "\n");
.get_ids() println!("{id:7} -> '{token}'");
.to_vec(); }
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize; let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token, Some(token) => *token,
@ -129,6 +133,10 @@ struct Args {
#[arg(long)] #[arg(long)]
tracing: bool, tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)] #[arg(long)]
prompt: String, prompt: String,
@ -266,6 +274,7 @@ fn main() -> Result<()> {
args.top_p, args.top_p,
args.repeat_penalty, args.repeat_penalty,
args.repeat_last_n, args.repeat_last_n,
args.verbose_prompt,
&device, &device,
); );
pipeline.run(&args.prompt, args.sample_len)?; pipeline.run(&args.prompt, args.sample_len)?;