mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Make the nll op closer to the pytorch version + add a test. (#286)
This commit is contained in:
@ -178,10 +178,7 @@ fn training_loop<M: Model>(
|
||||
|
||||
let train_labels = m.train_labels;
|
||||
let train_images = m.train_images.to_device(&dev)?;
|
||||
let train_labels = train_labels
|
||||
.to_dtype(DType::U32)?
|
||||
.unsqueeze(1)?
|
||||
.to_device(&dev)?;
|
||||
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
|
||||
|
||||
let mut vs = VarStore::new(DType::F32, dev.clone());
|
||||
let model = M::new(vs.clone())?;
|
||||
|
Reference in New Issue
Block a user