MMLU evaluation for Phi. (#1474)
* MMLU evaluation for Phi.
* Improve the evaluation.
candle-examples/Cargo.toml
@@ -28,6 +28,7 @@ safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
+csv = "1.3.0"
 
 [dev-dependencies]
 anyhow = { workspace = true }
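The csv crate is pulled in to parse the MMLU test files. A minimal standalone sketch of the same reader configuration (the inline data row is invented for illustration; MMLU test CSVs are headerless, laid out as question, choices A-D, answer letter):

use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
    // One made-up row in the MMLU test layout: question, four choices, answer letter.
    let data = "What is 2 + 2?,3,4,5,22,B\n";
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(false)
        .from_reader(data.as_bytes());
    for row in reader.records() {
        let row = row?;
        println!("{} -> {}", row.get(0).unwrap(), row.get(5).unwrap());
    }
    Ok(())
}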
candle-examples/examples/phi/main.rs
@@ -145,7 +145,10 @@ struct Args {
     verbose_prompt: bool,
 
     #[arg(long)]
-    prompt: String,
+    prompt: Option<String>,
+
+    #[arg(long)]
+    mmlu_dir: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
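Making both flags Option<String> lets the binary run in exactly one of two modes, with main rejecting the ambiguous cases. A standalone sketch of that pattern (hypothetical toy binary, assuming clap's derive API that this example already uses):

use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Prompt for regular text generation.
    #[arg(long)]
    prompt: Option<String>,

    /// Directory holding the MMLU test CSVs.
    #[arg(long)]
    mmlu_dir: Option<String>,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    match (args.prompt, args.mmlu_dir) {
        // Neither or both flags: refuse to guess.
        (None, None) | (Some(_), Some(_)) => {
            anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
        }
        (Some(prompt), None) => println!("generation mode: {prompt}"),
        (None, Some(dir)) => println!("evaluation mode: {dir}"),
    }
    Ok(())
}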
candle-examples/examples/phi/main.rs
@@ -314,17 +317,105 @@ fn main() -> Result<()> {
     };
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-    );
-    pipeline.run(&args.prompt, args.sample_len)?;
+    match (args.prompt, args.mmlu_dir) {
+        (None, None) | (Some(_), Some(_)) => {
+            anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+        }
+        (Some(prompt), None) => {
+            let mut pipeline = TextGeneration::new(
+                model,
+                tokenizer,
+                args.seed,
+                args.temperature,
+                args.top_p,
+                args.repeat_penalty,
+                args.repeat_last_n,
+                args.verbose_prompt,
+                &device,
+            );
+            pipeline.run(&prompt, args.sample_len)?;
+        }
+        (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+    }
+    Ok(())
+}
+
+fn mmlu<P: AsRef<std::path::Path>>(
+    mut model: Model,
+    tokenizer: Tokenizer,
+    device: &Device,
+    mmlu_dir: P,
+) -> anyhow::Result<()> {
+    for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
+        let dir_entry = dir_entry.path();
+        let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
+            None => "".to_string(),
+            Some(v) => match v.strip_suffix("_test") {
+                None => v.replace('_', " "),
+                Some(v) => v.replace('_', " "),
+            },
+        };
+        if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
+            continue;
+        }
+        println!("reading {dir_entry:?}");
+        let dir_entry = std::fs::File::open(dir_entry)?;
+        let mut reader = csv::ReaderBuilder::new()
+            .has_headers(false)
+            .from_reader(dir_entry);
+        let token_a = tokenizer.token_to_id("A").unwrap();
+        let token_b = tokenizer.token_to_id("B").unwrap();
+        let token_c = tokenizer.token_to_id("C").unwrap();
+        let token_d = tokenizer.token_to_id("D").unwrap();
+        for row in reader.records() {
+            let row = match row {
+                Err(_) => continue,
+                Ok(row) => row,
+            };
+            if row.len() < 5 {
+                continue;
+            }
+            let question = row.get(0).unwrap();
+            let answer_a = row.get(1).unwrap();
+            let answer_b = row.get(2).unwrap();
+            let answer_c = row.get(3).unwrap();
+            let answer_d = row.get(4).unwrap();
+            let answer = row.get(5).unwrap();
+            let prompt = format!(
+                "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
+                "The following are multiple choice questions (with answers) about"
+            );
+            let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
+            let tokens = tokens.get_ids().to_vec();
+            let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
+            let logits = match &mut model {
+                Model::MixFormer(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+                Model::Quantized(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+            };
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits_v: Vec<f32> = logits.to_vec1()?;
+            let pr_a = logits_v[token_a as usize];
+            let pr_b = logits_v[token_b as usize];
+            let pr_c = logits_v[token_c as usize];
+            let pr_d = logits_v[token_d as usize];
+            let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
+                "A"
+            } else if pr_b > pr_c && pr_b > pr_d {
+                "B"
+            } else if pr_c > pr_d {
+                "C"
+            } else {
+                "D"
+            };
+
+            println!("{prompt}\n -> {model_answer} vs {answer}");
+        }
+    }
     Ok(())
 }
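Scoring never samples: the loop takes the final-position logits and picks whichever of the four answer-letter tokens scores highest. A compact equivalent of the chained comparisons above (hypothetical helper, not part of the commit; tie-breaking details aside, f32::total_cmp gives a total order so max_by is safe even with NaNs):

fn pick_answer(scored: [(char, f32); 4]) -> char {
    scored
        .into_iter()
        .max_by(|x, y| x.1.total_cmp(&y.1))
        .map(|(letter, _)| letter)
        .unwrap()
}

fn main() {
    let scored = [('A', 0.1_f32), ('B', 2.3), ('C', -0.5), ('D', 1.0)];
    assert_eq!(pick_answer(scored), 'B');
}

With the MMLU test CSVs downloaded locally, the evaluation would then be launched with something like cargo run --example phi --release -- --mmlu-dir <path-to-mmlu-test-csvs> (path placeholder).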