mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax. * Clippy fixes. * Use the newly introduced argmax. * Fix the strided case. * Handle the non-contiguous case.
This commit is contained in:
@ -42,7 +42,7 @@ pub fn main() -> Result<()> {
|
||||
let bs = Var::zeros(LABELS, DType::F32, &dev)?;
|
||||
let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
|
||||
let test_images = m.test_images;
|
||||
let test_labels = m.test_labels.to_vec1::<u8>()?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?;
|
||||
for epoch in 1..200 {
|
||||
let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
|
||||
let log_sm = log_softmax(&logits, D::Minus1)?;
|
||||
@ -52,28 +52,13 @@ pub fn main() -> Result<()> {
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
|
||||
/* TODO: Add argmax so that the following can be computed within candle.
|
||||
let test_accuracy = test_logits
|
||||
.argmax(Some(-1), false)
|
||||
.eq_tensor(&test_labels)
|
||||
.to_kind(Kind::Float)
|
||||
.mean(Kind::Float)
|
||||
.double_value(&[]);
|
||||
*/
|
||||
let test_logits = test_logits.to_vec2::<f32>()?;
|
||||
let sum_ok = test_logits
|
||||
.iter()
|
||||
.zip(test_labels.iter())
|
||||
.map(|(logits, label)| {
|
||||
let arg_max = logits
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, v1), (_, v2)| v1.total_cmp(v2))
|
||||
.map(|(idx, _)| idx);
|
||||
f64::from(arg_max == Some(*label as usize))
|
||||
})
|
||||
.sum::<f64>();
|
||||
let test_accuracy = sum_ok / test_labels.len() as f64;
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
|
Reference in New Issue
Block a user