Move the llama2-c model in transformers. (#1205)
@@ -6,10 +6,10 @@ extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-mod model;
-mod qmodel;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
 mod training;
-mod weights;
 
 use clap::{Parser, Subcommand};
 
 use anyhow::{Error as E, Result};
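The example binary keeps its old module names (model, weights, qmodel) by aliasing the implementations that now live in candle-transformers, so the rest of main.rs compiles unchanged. A minimal sketch of the pattern, assuming llama2_c::Config exposes an n_layers field; the describe helper is hypothetical:

use candle_transformers::models::llama2_c as model;

// Downstream code keeps the old `model::` prefix even though the
// types now come from candle-transformers (hypothetical helper).
fn describe(cfg: &model::Config) -> String {
    format!("llama2.c transformer with {} layers", cfg.n_layers)
}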
@@ -11,6 +11,7 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
+byteorder = { workspace = true }
 candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.3.0" }
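byteorder joins the crate's dependencies because the relocated weight loader parses the little-endian binary checkpoints produced by llama2.c. A minimal sketch of that kind of read, assuming the header begins with two i32 fields; read_header is hypothetical:

use byteorder::{LittleEndian, ReadBytesExt};
use std::io::Cursor;

// Hypothetical reader for the first two little-endian i32 header
// fields of a llama2.c-style checkpoint.
fn read_header(bytes: &[u8]) -> std::io::Result<(i32, i32)> {
    let mut cursor = Cursor::new(bytes);
    let dim = cursor.read_i32::<LittleEndian>()?;
    let hidden_dim = cursor.read_i32::<LittleEndian>()?;
    Ok((dim, hidden_dim))
}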
@@ -1,9 +1,8 @@
-use anyhow::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
-use candle::{DType, Device, IndexOp, Shape, Tensor};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
 use candle_nn::VarBuilder;
 
-use crate::model::Config;
+use super::llama2_c::Config;
 
 pub struct TransformerWeights {
     // token embedding table
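Note the error-type swap above: as library code, the weight loader now returns candle::Result instead of anyhow::Result, so candle-transformers avoids an anyhow dependency. A sketch of the resulting signature style; token_embedding is a hypothetical helper:

use candle::{DType, Device, Result, Tensor};

// Errors surface as candle::Error rather than anyhow::Error
// (hypothetical helper).
fn token_embedding(vocab_size: usize, dim: usize, dev: &Device) -> Result<Tensor> {
    Tensor::zeros((vocab_size, dim), DType::F32, dev)
}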
@@ -8,6 +8,8 @@ pub mod efficientnet;
 pub mod falcon;
 pub mod jina_bert;
 pub mod llama;
+pub mod llama2_c;
+pub mod llama2_c_weights;
 pub mod mistral;
 pub mod mixformer;
 pub mod mpt;
@@ -15,6 +17,7 @@ pub mod persimmon;
 pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_llama;
+pub mod quantized_llama2_c;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
 pub mod quantized_mpt;
@@ -1,7 +1,7 @@
-use super::model::{Cache, Config};
+use super::llama2_c::{Cache, Config};
+use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
+pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, IndexOp, Module, Result, Tensor, D};
-use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
-pub use candle_transformers::quantized_var_builder::VarBuilder;
 
 fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
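The unchanged silu context line implements the SiLU activation, silu(x) = x * sigmoid(x) = x / (1 + e^(-x)), which is exactly the tensor expression above. A self-contained sketch of calling it; the main wrapper is illustrative, not part of the diff:

use candle::{Device, Result, Tensor};

// SiLU activation: silu(x) = x * sigmoid(x) = x / (1 + e^(-x)).
fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1.0f32, 0.0, 1.0], &Device::Cpu)?;
    // Expected roughly [-0.269, 0.0, 0.731].
    println!("{}", silu(&xs)?);
    Ok(())
}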