Make the nll op closer to the pytorch version + add a test. (#286)

This commit is contained in:
Laurent Mazare
2023-07-31 14:14:01 +01:00
committed by GitHub
parent b3ea96b62b
commit ffeafbfc43
3 changed files with 54 additions and 6 deletions

View File

@ -178,10 +178,7 @@ fn training_loop<M: Model>(
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels
.to_dtype(DType::U32)?
.unsqueeze(1)?
.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let mut vs = VarStore::new(DType::F32, dev.clone());
let model = M::new(vs.clone())?;