File: tensor-info.py

package info (click to toggle)
llama.cpp 8064+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 76,488 kB
  • sloc: cpp: 353,828; ansic: 51,268; python: 30,090; lisp: 11,788; sh: 6,290; objc: 1,395; javascript: 924; xml: 384; makefile: 233
file content (159 lines) | stat: -rwxr-xr-x 4,668 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3

import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Optional
from safetensors import safe_open


# File name of the single-file safetensors checkpoint looked up inside the model directory.
MODEL_SAFETENSORS_FILE = "model.safetensors"
# File name of the JSON index whose "weight_map" entry is consulted to locate tensors.
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"


def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
    """Load the tensor-name -> file-name mapping from the safetensors index.

    Args:
        model_path: Directory expected to contain the index file.

    Returns:
        The index's "weight_map" entry (an empty dict when the index has no
        such key), or None when the directory has no index file.
    """
    index_file = model_path / MODEL_SAFETENSORS_INDEX

    if not index_file.exists():
        return None

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(index_file, 'r', encoding="utf-8") as f:
        index = json.load(f)

    return index.get("weight_map", {})


def get_all_tensor_names(model_path: Path) -> list[str]:
    """Return the names of all tensors stored in the model directory.

    The weight map from the index file is used when one is present;
    otherwise the single-file checkpoint is opened and its keys are listed.

    Args:
        model_path: Directory containing the safetensors file(s).

    Returns:
        The tensor names, in the order given by the index or the checkpoint.

    Raises:
        SystemExit: With status 1 when the checkpoint cannot be read or when
            neither an index nor a single-file checkpoint exists.
    """
    weight_map = get_weight_map(model_path)
    if weight_map is not None:
        return list(weight_map)

    single_file = model_path / MODEL_SAFETENSORS_FILE
    if not single_file.exists():
        # sys.exit(str) writes the message to stderr and exits with status 1,
        # keeping diagnostics out of the stdout stream used for tensor listings.
        sys.exit(f"Error: No safetensors files found in {model_path}")

    try:
        with safe_open(single_file, framework="pt", device="cpu") as f:
            return list(f.keys())
    except Exception as e:
        sys.exit(f"Error reading {single_file}: {e}")


def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
    """Return the name of the file that holds ``tensor_name``.

    Args:
        model_path: Directory containing the safetensors file(s).
        tensor_name: Name of the tensor to locate.

    Returns:
        The file name recorded for the tensor in the index's weight map when
        an index exists (None if the tensor is not listed there). Without an
        index, the single-file checkpoint's name if that file exists,
        otherwise None.
    """
    shard_index = get_weight_map(model_path)

    if shard_index is None:
        # No index: the only candidate is the single-file checkpoint.
        candidate = model_path / MODEL_SAFETENSORS_FILE
        return candidate.name if candidate.exists() else None

    return shard_index.get(tensor_name)


def normalize_tensor_name(tensor_name: str) -> str:
    """Replace purely numeric dotted components of a tensor name with '#'.

    A component is replaced when it consists only of digits, is preceded by a
    dot, and is followed by a dot or the end of the name. For example,
    "model.layers.12.mlp.weight" becomes "model.layers.#.mlp.weight".

    Args:
        tensor_name: The tensor name to normalize.

    Returns:
        The name with each such numeric component replaced by '#'.
    """
    # The lookahead leaves the trailing dot unconsumed so that adjacent numeric
    # components ("a.0.1.b") are all matched; a pattern that consumes the dot
    # would skip every second one.
    return re.sub(r'\.\d+(?=\.|$)', '.#', tensor_name)


def list_all_tensors(model_path: Path, unique: bool = False):
    """Print the model's tensor names to stdout, one per line, in sorted order.

    Args:
        model_path: Directory containing the safetensors file(s).
        unique: When true, each name is passed through normalize_tensor_name
            and every resulting pattern is printed only once, at the position
            of its first occurrence in the sorted names.
    """
    ordered_names = sorted(get_all_tensor_names(model_path))

    if not unique:
        for name in ordered_names:
            print(name)
        return

    # dict.fromkeys drops repeated patterns while keeping first-seen order.
    patterns = dict.fromkeys(normalize_tensor_name(name) for name in ordered_names)
    for pattern in patterns:
        print(pattern)


def print_tensor_info(model_path: Path, tensor_name: str):
    """Print the name, containing file and shape of a single tensor.

    Args:
        model_path: Directory containing the safetensors file(s).
        tensor_name: Name of the tensor to inspect.

    Raises:
        SystemExit: With status 1 when the tensor cannot be located, the file
            is missing, or reading it fails. The error message is written to
            stderr so stdout only ever carries the tensor report.
    """
    tensor_file = find_tensor_file(model_path, tensor_name)

    if tensor_file is None:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(
            f"Error: Could not find tensor '{tensor_name}' in model index\n"
            f"Model path: {model_path}"
        )

    file_path = model_path / tensor_file

    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            if tensor_name not in f.keys():
                # SystemExit is not an Exception subclass, so it passes
                # through the handlers below untouched.
                sys.exit(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
            shape = f.get_slice(tensor_name).get_shape()
    except FileNotFoundError:
        sys.exit(f"Error: The file '{file_path}' was not found.")
    except Exception as e:
        sys.exit(f"An error occurred: {e}")

    print(f"Tensor: {tensor_name}")
    print(f"File:   {tensor_file}")
    print(f"Shape:  {shape}")


def main():
    """Command-line entry point: list tensor patterns or print one tensor's info.

    The model directory comes from -m/--model-path or, when that option is
    absent, from the MODEL_PATH environment variable. With -l/--list the
    unique tensor patterns are printed; otherwise the positional tensor name
    is required and its information is printed.

    Raises:
        SystemExit: With status 1 and a message on stderr when the model path
            is missing or invalid, or when no tensor name is given without
            --list.
    """
    parser = argparse.ArgumentParser(
        description="Print tensor information from a safetensors model"
    )
    parser.add_argument(
        "tensor_name",
        nargs="?",  # optional (if --list is used for example)
        help="Name of the tensor to inspect"
    )
    parser.add_argument(
        "-m", "--model-path",
        type=Path,
        help="Path to the model directory (default: MODEL_PATH environment variable)"
    )
    parser.add_argument(
        "-l", "--list",
        action="store_true",
        help="List unique tensor patterns in the model (layer numbers replaced with #)"
    )

    args = parser.parse_args()

    model_path = args.model_path
    if model_path is None:
        model_path_str = os.environ.get("MODEL_PATH")
        # An empty value would turn into Path(""), i.e. the current directory,
        # so it is treated the same as an unset variable.
        if not model_path_str:
            # sys.exit(str) prints the message to stderr and exits with status 1.
            sys.exit("Error: --model-path not provided and MODEL_PATH environment variable not set")
        model_path = Path(model_path_str)

    if not model_path.exists():
        sys.exit(f"Error: Model path does not exist: {model_path}")

    if not model_path.is_dir():
        sys.exit(f"Error: Model path is not a directory: {model_path}")

    if args.list:
        list_all_tensors(model_path, unique=True)
        return

    if args.tensor_name is None:
        sys.exit("Error: tensor_name is required when not using --list")
    print_tensor_info(model_path, args.tensor_name)


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()