From 9fea56d28e5f99529da8ed8df1eb508b0f163cc3 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 10 Oct 2023 10:05:41 +0200
Subject: [PATCH] Only optimize float tensors. (#1069)

---
 candle-core/src/dtype.rs | 14 ++++++++++++++
 candle-nn/src/optim.rs   |  5 +++++
 2 files changed, 19 insertions(+)

diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index c7a1567f..94ca57d8 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -67,6 +67,20 @@ impl DType {
             Self::F64 => 8,
         }
     }
+
+    pub fn is_int(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => true,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
+        }
+    }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => false,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
+        }
+    }
 }
 
 pub trait WithDType:
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 4294d75e..7704bb48 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -41,6 +41,10 @@ impl Optimizer for SGD {
     type Config = f64;
 
     fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let vars = vars
+            .into_iter()
+            .filter(|var| var.dtype().is_float())
+            .collect();
         Ok(Self {
             vars,
             learning_rate,
@@ -116,6 +120,7 @@ impl Optimizer for AdamW {
     fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
         let vars = vars
             .into_iter()
+            .filter(|var| var.dtype().is_float())
            .map(|var| {
                 let dtype = var.dtype();
                 let shape = var.shape();
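
A minimal usage sketch (not part of the patch; the variable names and shapes
are illustrative) showing the effect of the filter against the public
candle_core / candle_nn API at this revision: an integer-typed Var handed to
an optimizer is now dropped at construction time instead of triggering a
dtype error during the update step.

use candle_core::{DType, Device, Result, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A trainable float parameter and an integer bookkeeping var.
    let w = Var::zeros((2, 2), DType::F32, &dev)?;
    let step = Var::zeros(1, DType::I64, &dev)?;

    // The new helpers classify the dtypes.
    assert!(w.dtype().is_float());
    assert!(step.dtype().is_int());

    // With this patch, `step` is silently filtered out here; only `w`
    // will be touched by subsequent backward_step calls.
    let _sgd = SGD::new(vec![w, step], 0.01)?;
    Ok(())
}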