mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Only optimize float tensors. (#1069)
This commit is contained in:
@ -41,6 +41,10 @@ impl Optimizer for SGD {
|
||||
type Config = f64;
|
||||
|
||||
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||
let vars = vars
|
||||
.into_iter()
|
||||
.filter(|var| var.dtype().is_float())
|
||||
.collect();
|
||||
Ok(Self {
|
||||
vars,
|
||||
learning_rate,
|
||||
@ -116,6 +120,7 @@ impl Optimizer for AdamW {
|
||||
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||
let vars = vars
|
||||
.into_iter()
|
||||
.filter(|var| var.dtype().is_float())
|
||||
.map(|var| {
|
||||
let dtype = var.dtype();
|
||||
let shape = var.shape();
|
||||
|
Reference in New Issue
Block a user