mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Only optimize float tensors. (#1069)
This commit is contained in:
@ -67,6 +67,20 @@ impl DType {
|
|||||||
Self::F64 => 8,
|
Self::F64 => 8,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_int(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::U8 | Self::U32 | Self::I64 => true,
|
||||||
|
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_float(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::U8 | Self::U32 | Self::I64 => false,
|
||||||
|
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait WithDType:
|
pub trait WithDType:
|
||||||
|
@ -41,6 +41,10 @@ impl Optimizer for SGD {
|
|||||||
type Config = f64;
|
type Config = f64;
|
||||||
|
|
||||||
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
|
||||||
|
let vars = vars
|
||||||
|
.into_iter()
|
||||||
|
.filter(|var| var.dtype().is_float())
|
||||||
|
.collect();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
vars,
|
vars,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
@ -116,6 +120,7 @@ impl Optimizer for AdamW {
|
|||||||
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
|
||||||
let vars = vars
|
let vars = vars
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
.filter(|var| var.dtype().is_float())
|
||||||
.map(|var| {
|
.map(|var| {
|
||||||
let dtype = var.dtype();
|
let dtype = var.dtype();
|
||||||
let shape = var.shape();
|
let shape = var.shape();
|
||||||
|
Reference in New Issue
Block a user