Relax the requirements on CustomOp. (#486)

* Relax the requirements on CustomOp.

* Simplify the custom-ops when no backward is required.
This commit is contained in:
Laurent Mazare
2023-08-17 11:12:05 +01:00
committed by GitHub
parent d32e8199cd
commit 03be33eea4
8 changed files with 81 additions and 31 deletions

View File

@@ -178,7 +178,7 @@ pub fn flash_attn(
softmax_scale,
causal,
};
-q.custom_op3(k, v, op)
+q.apply_op3(k, v, op)
}
struct FlashAttnVarLen {
@@ -402,5 +402,5 @@ pub fn flash_attn_varlen(
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
};
-q.custom_op3(k, v, op)
+q.apply_op3(k, v, op)
}