Mirror of https://github.com/huggingface/candle.git
Load the weights for llama.
@@ -8,7 +8,7 @@ from pathlib import Path
 def tr(v):
     return np.ascontiguousarray(np.transpose(v))

-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]:
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
     print("start conv")

     def get_and_remove(key, transpose=False):
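The hunk above only shows the changed signature; the helper bodies are not part of this commit. For orientation, here is a minimal, self-contained sketch of what a get_and_remove-style helper typically does in checkpoint conversion scripts like this one. The body and the key mapping below are assumptions, not the script's actual code:

import numpy as np
import torch
from typing import Dict

def tr(v):
    # Transposed, C-contiguous copy, as defined in the script above.
    return np.ascontiguousarray(np.transpose(v))

def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32):
    converted = {}

    def get_and_remove(key, transpose=False):
        # Hypothetical body: pop the tensor so the source dict releases it,
        # cast to the requested dtype, and optionally transpose into the
        # layout the loader expects.
        v = state_dict.pop(key).to(dtype).numpy()
        return tr(v) if transpose else v

    # Hypothetical key mapping, for illustration only.
    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
    return converted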
@@ -53,7 +53,7 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype =
         converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
     return converted

-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float16") -> None:
+def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
     dt = getattr(torch, dtype, None)
     if not isinstance(dt, torch.dtype):
         raise ValueError(f"{dtype} is not a valid dtype.")
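convert_weights resolves the dtype flag by looking the string up as an attribute of the torch module; the isinstance check then rejects both missing names and attributes that exist but are not dtypes. A small, self-contained demonstration of the same pattern (the function name resolve_dtype and the sample inputs are mine, not the script's):

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # Same validation pattern as in convert_weights above.
    dt = getattr(torch, name, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{name} is not a valid dtype.")
    return dt

print(resolve_dtype("float32"))  # torch.float32
print(resolve_dtype("float16"))  # torch.float16
# resolve_dtype("nn") raises: torch.nn exists but is a module, not a dtype.
# resolve_dtype("float12") raises: torch has no such attribute.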