Load the weights for llama.

laurent
2023-06-26 07:23:59 +01:00
parent 7a3101f15f
commit d867155ef2
3 changed files with 29 additions and 9 deletions


@@ -8,7 +8,7 @@ from pathlib import Path
 def tr(v):
     return np.ascontiguousarray(np.transpose(v))
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]:
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
     print("start conv")
     def get_and_remove(key, transpose=False):
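
The tr helper visible in this hunk copies the transposed array into contiguous memory, since np.transpose only returns a strided view, presumably so the serialized buffers can later be read back as plain row-major data. A minimal standalone sketch of that behavior (the array shape is illustrative, not taken from the checkpoint):

import numpy as np

v = np.zeros((2, 3), dtype=np.float32)
t = np.transpose(v)             # a view with swapped strides, not a copy
print(t.flags["C_CONTIGUOUS"])  # False
c = np.ascontiguousarray(t)     # materializes a C-contiguous copy
print(c.flags["C_CONTIGUOUS"])  # True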
@@ -53,7 +53,7 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype =
         converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
     return converted
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float16") -> None:
+def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
     dt = getattr(torch, dtype, None)
     if not isinstance(dt, torch.dtype):
         raise ValueError(f"{dtype} is not a valid dtype.")
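
Both hunks flip the default conversion dtype from float16 to float32. The check at the end of the second hunk resolves the dtype string by attribute lookup on the torch module; a standalone sketch of that pattern (resolve_dtype is a hypothetical name, not part of the script):

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # getattr returns None for unknown names instead of raising, and the
    # isinstance check also rejects torch attributes that exist but are
    # not dtypes (e.g. torch.nn).
    dt = getattr(torch, name, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{name} is not a valid dtype.")
    return dt

assert resolve_dtype("float32") is torch.float32
# resolve_dtype("float99") raises ValueError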