mirror of https://github.com/huggingface/candle.git
MMLU evaluation for Phi. (#1474)
* MMLU evaluation for Phi.
* Improve the evaluation.
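This patch makes --prompt optional and adds an --mmlu-dir flag to the Phi example: exactly one of the two must be given, and when --mmlu-dir points at a directory of MMLU test CSVs the example scores the model on multiple-choice questions instead of running free-form generation. A hypothetical invocation, assuming the MMLU test split has been downloaded locally (the example name and path are assumptions, not part of the diff):

    cargo run --example phi --release -- --mmlu-dir ./data/mmlu/test
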
@@ -145,7 +145,10 @@ struct Args {
     verbose_prompt: bool,
 
     #[arg(long)]
-    prompt: String,
+    prompt: Option<String>,
+
+    #[arg(long)]
+    mmlu_dir: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
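The directory passed via --mmlu-dir is consumed by the new mmlu function in the second hunk below: each *_test.csv file is parsed as a headerless CSV whose first six columns are the question, the four choices A through D, and the reference answer letter. A hypothetical row, shown only to illustrate that layout (not taken from the dataset):

    What is 2 + 2?,3,4,5,6,B
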
@@ -314,17 +317,105 @@ fn main() -> Result<()> {
     };
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-    );
-    pipeline.run(&args.prompt, args.sample_len)?;
+    match (args.prompt, args.mmlu_dir) {
+        (None, None) | (Some(_), Some(_)) => {
+            anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+        }
+        (Some(prompt), None) => {
+            let mut pipeline = TextGeneration::new(
+                model,
+                tokenizer,
+                args.seed,
+                args.temperature,
+                args.top_p,
+                args.repeat_penalty,
+                args.repeat_last_n,
+                args.verbose_prompt,
+                &device,
+            );
+            pipeline.run(&prompt, args.sample_len)?;
+        }
+        (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+    }
     Ok(())
 }
+
+fn mmlu<P: AsRef<std::path::Path>>(
+    mut model: Model,
+    tokenizer: Tokenizer,
+    device: &Device,
+    mmlu_dir: P,
+) -> anyhow::Result<()> {
+    for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
+        let dir_entry = dir_entry.path();
+        let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
+            None => "".to_string(),
+            Some(v) => match v.strip_suffix("_test") {
+                None => v.replace('_', " "),
+                Some(v) => v.replace('_', " "),
+            },
+        };
+        if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
+            continue;
+        }
+        println!("reading {dir_entry:?}");
+        let dir_entry = std::fs::File::open(dir_entry)?;
+        let mut reader = csv::ReaderBuilder::new()
+            .has_headers(false)
+            .from_reader(dir_entry);
+        let token_a = tokenizer.token_to_id("A").unwrap();
+        let token_b = tokenizer.token_to_id("B").unwrap();
+        let token_c = tokenizer.token_to_id("C").unwrap();
+        let token_d = tokenizer.token_to_id("D").unwrap();
+        for row in reader.records() {
+            let row = match row {
+                Err(_) => continue,
+                Ok(row) => row,
+            };
+            if row.len() < 5 {
+                continue;
+            }
+            let question = row.get(0).unwrap();
+            let answer_a = row.get(1).unwrap();
+            let answer_b = row.get(2).unwrap();
+            let answer_c = row.get(3).unwrap();
+            let answer_d = row.get(4).unwrap();
+            let answer = row.get(5).unwrap();
+            let prompt = format!(
+                "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
+                "The following are multiple choice questions (with answers) about"
+            );
+            let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
+            let tokens = tokens.get_ids().to_vec();
+            let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
+            let logits = match &mut model {
+                Model::MixFormer(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+                Model::Quantized(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+            };
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits_v: Vec<f32> = logits.to_vec1()?;
+            let pr_a = logits_v[token_a as usize];
+            let pr_b = logits_v[token_b as usize];
+            let pr_c = logits_v[token_c as usize];
+            let pr_d = logits_v[token_d as usize];
+            let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
+                "A"
+            } else if pr_b > pr_c && pr_b > pr_d {
+                "B"
+            } else if pr_c > pr_d {
+                "C"
+            } else {
+                "D"
+            };
+
+            println!("{prompt}\n -> {model_answer} vs {answer}");
+        }
+    }
+    Ok(())
+}
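For the hypothetical row above, and assuming it came from a file named elementary_mathematics_test.csv (so theme becomes "elementary mathematics"), the format! call would build a prompt along these lines:

    The following are multiple choice questions (with answers) about elementary mathematics.
    What is 2 + 2?
    A. 3
    B. 4
    C. 5
    D. 6
    Answer:
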
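The predicted letter is chosen by comparing the last-position logits of the four answer tokens with a chain of > comparisons, and the result is only printed next to the reference answer. Below is a minimal sketch, not part of the commit, of the same choice expressed as an argmax; the helper name and the toy logit values are made up for illustration:

    // A sketch only: argmax over the four candidate logits, equivalent to the
    // pr_a > pr_b && ... comparison chain in the diff above.
    fn pick_answer(pr_a: f32, pr_b: f32, pr_c: f32, pr_d: f32) -> &'static str {
        [("A", pr_a), ("B", pr_b), ("C", pr_c), ("D", pr_d)]
            .into_iter()
            // total_cmp imposes a total order on f32, so NaN logits cannot cause a panic here.
            .max_by(|x, y| x.1.total_cmp(&y.1))
            .map(|(letter, _)| letter)
            .unwrap()
    }

    fn main() {
        // Toy values standing in for logits_v[token_a as usize] .. logits_v[token_d as usize].
        assert_eq!(pick_answer(0.1, 2.3, 0.4, 0.0), "B");
        println!("picked {}", pick_answer(1.5, 0.2, 0.9, 0.3));
    }

A natural extension would be to accumulate a correct/total count over the rows so the run ends with an aggregate MMLU accuracy rather than a per-question log.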