Refactor the reduce ops in order to introduce argmin/argmax. (#212)

* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
Laurent Mazare
2023-07-21 12:41:08 +02:00
committed by GitHub
parent c60831aad4
commit 410654525f
7 changed files with 241 additions and 110 deletions
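For context, a minimal sketch of how the newly introduced argmax can be used on its own. This is not part of the commit; it assumes the candle_core crate name and relies only on the Tensor::argmax, eq, to_dtype, sum_all and to_scalar calls visible in the diff below.

use candle_core::{DType, Device, Result, Tensor, D};
// Crate name assumed (candle_core); adjust if the crate is published under a different name.

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Two rows of logits; argmax over the last dimension returns the index
    // of the largest value in each row as a u32 tensor of shape [2].
    let logits = Tensor::new(&[[0.1f32, 2.0, -1.0], [3.0, 0.5, 0.2]], &dev)?;
    let preds = logits.argmax(D::Minus1)?;
    println!("{:?}", preds.to_vec1::<u32>()?); // [1, 0]

    // Comparing predictions with u32 labels and summing the 0/1 mask counts
    // correct predictions, mirroring the accuracy computation in the diff.
    let labels = Tensor::new(&[1u32, 0], &dev)?;
    let n_ok = preds.eq(&labels)?.to_dtype(DType::F32)?.sum_all()?;
    println!("accuracy: {}", n_ok.to_scalar::<f32>()? / 2.0);
    Ok(())
}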


@@ -42,7 +42,7 @@ pub fn main() -> Result<()> {
     let bs = Var::zeros(LABELS, DType::F32, &dev)?;
     let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
     let test_images = m.test_images;
-    let test_labels = m.test_labels.to_vec1::<u8>()?;
+    let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
         let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
         let log_sm = log_softmax(&logits, D::Minus1)?;
@@ -52,28 +52,13 @@ pub fn main() -> Result<()> {
         sgd.backward_step(&loss)?;
         let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
-        /* TODO: Add argmax so that the following can be computed within candle.
-        let test_accuracy = test_logits
-            .argmax(Some(-1), false)
-            .eq_tensor(&test_labels)
-            .to_kind(Kind::Float)
-            .mean(Kind::Float)
-            .double_value(&[]);
-        */
-        let test_logits = test_logits.to_vec2::<f32>()?;
         let sum_ok = test_logits
-            .iter()
-            .zip(test_labels.iter())
-            .map(|(logits, label)| {
-                let arg_max = logits
-                    .iter()
-                    .enumerate()
-                    .max_by(|(_, v1), (_, v2)| v1.total_cmp(v2))
-                    .map(|(idx, _)| idx);
-                f64::from(arg_max == Some(*label as usize))
-            })
-            .sum::<f64>();
-        let test_accuracy = sum_ok / test_labels.len() as f64;
+            .argmax(D::Minus1)?
+            .eq(&test_labels)?
+            .to_dtype(DType::F32)?
+            .sum_all()?
+            .to_scalar::<f32>()?;
+        let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
         println!(
             "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
             loss.to_scalar::<f32>()?,