mirror of https://github.com/huggingface/candle.git
Only optimize float tensors. (#1069)
@@ -67,6 +67,20 @@ impl DType {
             Self::F64 => 8,
         }
     }
+
+    pub fn is_int(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => true,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
+        }
+    }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => false,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
+        }
+    }
 }
 
 pub trait WithDType:
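The two new predicates partition every `DType` variant, so exactly one of them holds for any dtype (the exhaustive match enforces this at compile time). A minimal sketch of the predicates in use, assuming the `candle_core` crate at a version that includes this commit (the variable names are illustrative):

use candle_core::{DType, Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A trainable float tensor and an integer bookkeeping tensor.
    let weights = Var::zeros((4, 4), DType::F32, &dev)?;
    let counter = Var::zeros(1, DType::I64, &dev)?;
    // Exactly one predicate is true for each dtype.
    assert!(weights.dtype().is_float() && !weights.dtype().is_int());
    assert!(counter.dtype().is_int() && !counter.dtype().is_float());
    Ok(())
}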
@@ -41,6 +41,10 @@ impl Optimizer for SGD {
     type Config = f64;
 
     fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let vars = vars
+            .into_iter()
+            .filter(|var| var.dtype().is_float())
+            .collect();
         Ok(Self {
             vars,
             learning_rate,
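With this filter, constructing an `SGD` over a mix of dtypes silently drops the non-float `Var`s instead of failing later when a float update is applied to an integer tensor. A hedged usage sketch under the same crate assumptions as above:

use candle_core::{DType, Device, Result, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Var::zeros((2, 2), DType::F32, &dev)?; // kept: float dtype
    let step = Var::zeros(1, DType::U32, &dev)?;   // dropped by the new filter
    // Only `w` is retained; `step` is never touched by the optimizer.
    let _sgd = SGD::new(vec![w, step], 0.01)?;
    Ok(())
}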
@@ -116,6 +120,7 @@ impl Optimizer for AdamW {
     fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
         let vars = vars
             .into_iter()
+            .filter(|var| var.dtype().is_float())
             .map(|var| {
                 let dtype = var.dtype();
                 let shape = var.shape();
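`AdamW` applies the same filter before the `.map` that builds its per-variable state, so (presumably) no moment buffers are allocated for the skipped integer tensors. A sketch mirroring the `SGD` example, with `ParamsAdamW` left at its defaults except for the learning rate:

use candle_core::{DType, Device, Result, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Var::zeros((8, 8), DType::F32, &dev)?;
    let ids = Var::zeros(8, DType::I64, &dev)?; // integer Var, filtered out
    let params = ParamsAdamW { lr: 1e-3, ..Default::default() };
    let _opt = AdamW::new(vec![w, ids], params)?;
    Ok(())
}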