#!/usr/bin/env python3
#
# make-pci-ids - Creates a file containing PCI IDs.
# It uses the database from
# https://github.com/pciutils/pciids/raw/master/pci.ids
# to create the file epan/pci-ids.c
#
# Wireshark - Network traffic analyzer
#
# By Caleb Chiu <caleb.chiu@macnica.com>
# Copyright 2021
#
# SPDX-License-Identifier: GPL-2.0-or-later
#

import string
import sys
import urllib.request, urllib.error, urllib.parse

# Generated C file; opened relative to the current working directory
# (presumably the top of the source tree -- confirm before relying on it).
OUTPUT_FILE = "epan/pci-ids.c"

# Sanity thresholds: main() exits without writing OUTPUT_FILE when the
# downloaded database yields fewer vendors / pci_id_t entries than these.
MIN_VENDOR_COUNT = 2_250   # 2261 vendors on 2021-11-01
MIN_DEVICE_COUNT = 33_000  # 33724 entries on 2021-11-01

# C text emitted right after the header comment lines copied from pci.ids.
# It finishes and closes that comment (" */"), pulls in the headers the
# generated file needs, and declares the two table types:
#   pci_id_t         - one lookup entry; the generator writes 0xFFFF in
#                      did/svid/ssid for vendor-only rows and in svid/ssid
#                      for device-only rows.
#   pci_vid_index_t  - maps a vendor ID to its array of `count` entries.
CODE_PREFIX = """\
 *
 * Generated by tools/make-pci-ids.py
 * By Caleb Chiu <caleb.chiu@macnica.com>
 * Copyright 2021
 *
 *
 * SPDX-License-Identifier: GPL-2.0-or-later
 */

#include <config.h>

#include <stddef.h>
#include <stdlib.h>

#include "wsutil/array.h"

#include "pci-ids.h"

typedef struct
{
  uint16_t vid;
  uint16_t did;
  uint16_t svid;
  uint16_t ssid;
  const char *name;

} pci_id_t;

typedef struct
{
  uint16_t vid;
  uint16_t count;
  pci_id_t const *ids_ptr;

} pci_vid_index_t;

"""

# C text appended at the very end of the generated file, after the
# pci_vid_index[] table.
#   vid_search()  - bsearch() comparator on the vendor ID, so pci_vid_index[]
#                   must be sorted by vid for lookups to work.
#   pci_id_str()  - locates the vendor's pci_vid_XXXX[] array via bsearch(),
#                   then scans it linearly for an exact (vid, did, svid, ssid)
#                   match; returns the static string "Not found" otherwise.
# The \" sequences below are Python escapes that produce plain C quotes.
CODE_POSTFIX = """
static int vid_search(const void *key, const void *tbl_entry)
{
    return (int)*(const uint16_t *)key -
           (int)((const pci_vid_index_t *)tbl_entry)->vid;
}

const char *pci_id_str(uint16_t vid, uint16_t did, uint16_t svid, uint16_t ssid)
{
    unsigned int i;
    static const char *not_found = \"Not found\";
    pci_vid_index_t const *index_ptr;
    pci_id_t const *ids_ptr;

    index_ptr = bsearch(&vid, pci_vid_index, array_length(pci_vid_index), sizeof pci_vid_index[0], vid_search);

    if(index_ptr == NULL)
        return not_found;

    ids_ptr = index_ptr->ids_ptr;
    for(i = 0; i < index_ptr->count; ids_ptr++, i++)
        if(vid == ids_ptr->vid &&
           did == ids_ptr->did &&
           svid == ids_ptr->svid &&
           ssid == ids_ptr->ssid)
           return ids_ptr->name;
    return  not_found;

}
"""


# Vendor IDs, in the order their pci_vid_XXXX[] arrays are emitted, and the
# number of pci_id_t entries in each of those arrays (same order); both are
# populated by main().
id_list, count_list = [], []


def exit_msg(msg=None, status=1):
    """Terminate the script, optionally reporting a message first.

    :param msg: text written to stderr followed by a newline; nothing is
        written when it is None.
    :param status: process exit status (default 1).
    :raises SystemExit: always.
    """
    stream = sys.stderr
    if msg is not None:
        stream.write(msg + '\n')
    raise SystemExit(status)


# Number of leading pci.ids lines (its '#' comment header, including the
# license notice) copied into the header comment of the generated file.
_HEADER_LINE_COUNT = 15

# Escapes applied to pci.ids text before it is embedded in a C string
# literal. Each '?' becomes "?-" so that no "??x" trigraph can be formed.
_C_STRING_ESCAPES = str.maketrans({'\\': '\\\\', '"': '\\"', '?': '?-'})

# Largest value that fits the uint16_t "count" field of pci_vid_index_t.
_MAX_VENDOR_ENTRIES = 0xFFFF


def _is_hex_id(word):
    """Return True if *word* is exactly four hexadecimal digits."""
    return len(word) == 4 and all(c in string.hexdigits for c in word)


def _header_comment_line(line):
    """Convert one '#' line of the pci.ids header into a C comment line."""
    text = line.replace('#', ' ', 1).lstrip()
    text = text.replace("GNU General Public License", "GPL")
    return ' * ' + text if text else ' *'


def _generate_pci_ids_c(lines, code_prefix, code_postfix):
    """Build the text of pci-ids.c from the lines of a pci.ids database.

    Vendor lines (no indent), device lines (one tab) and subsystem lines
    (two tabs) whose IDs are four hex digits become pci_id_t entries; all
    other lines after the header are ignored.

    :param lines: the pci.ids file split into lines without terminators.
    :param code_prefix: C text appended right after the copied header
        comment lines (it is expected to close that comment).
    :param code_postfix: C text appended after the vendor index table.
    :returns: ``(source, vendors)``: the generated C source and a list of
        ``(vid, entry_count)`` pairs in the order the vendor arrays were
        emitted.
    """
    out = ['''\
/* pci-ids.c
 *
 * pci-ids.c is based on the pci.ids of The PCI ID Repository at
 * https://pci-ids.ucw.cz/, fetched indirectly via
 * https://github.com/pciutils/pciids
''']
    for line in lines[:_HEADER_LINE_COUNT]:
        out.append(_header_comment_line(line) + '\n')
    out.append(code_prefix)

    vendors = []   # (vid, entry count) per emitted pci_vid_XXXX[] array
    vid = -1       # current vendor ID; -1 until the first vendor line
    did = -1       # current device ID; -1 until a device line of this vendor
    entries = 0    # entries emitted so far for the current vendor

    for raw in lines[_HEADER_LINE_COUNT:]:
        line = raw.translate(_C_STRING_ESCAPES)
        tabs = len(line) - len(line.lstrip('\t'))

        if tabs == 0:
            # Vendor: "VVVV  name"
            words = line.split(" ", 1)
            if len(words) < 2 or not _is_hex_id(words[0]):
                continue
            if vid != -1:
                out.append("}; /* pci_vid_%04X[] */\n\n" % vid)
                vendors.append((vid, entries))
            vid = int(words[0], 16)
            did = -1
            entries = 1
            out.append("static pci_id_t const pci_vid_%04X[] = {\n" % vid)
            out.append("{0x%04X, 0xFFFF, 0xFFFF, 0xFFFF, \"%s(0x%04X)\"},\n"
                       % (vid, words[1].strip(), vid))
        elif tabs == 1:
            # Device: "\tDDDD  name"
            words = line.strip('\t').split(" ", 1)
            # Skip malformed lines and devices that precede any vendor.
            if vid == -1 or len(words) < 2 or not _is_hex_id(words[0]):
                continue
            did = int(words[0], 16)
            out.append("{0x%04X, 0x%04X, 0xFFFF, 0xFFFF, \"%s(0x%04X)\"},\n"
                       % (vid, did, words[1].strip(), did))
            entries += 1
        elif tabs == 2:
            # Subsystem: "\t\tSVID SSID  name"
            words = line.strip('\t').split(" ", 2)
            # Both IDs must be valid and a device must be open; otherwise the
            # entry would be printed with -1 IDs ("0x-001"), which is not C.
            if (did == -1 or len(words) < 3
                    or not _is_hex_id(words[0]) or not _is_hex_id(words[1])):
                continue
            svid = int(words[0], 16)
            ssid = int(words[1], 16)
            out.append("{0x%04X, 0x%04X, 0x%04X, 0x%04X, \"%s(0x%04X-0x%04X)\"},\n"
                       % (vid, did, svid, ssid, words[2].strip(), svid, ssid))
            entries += 1

    if vid != -1:
        out.append("}; /* pci_vid_%04X[] */\n" % vid)
        vendors.append((vid, entries))

    out.append("\nstatic pci_vid_index_t const pci_vid_index[] = {\n")
    # pci_id_str() finds the vendor with bsearch(), which requires the index
    # to be sorted by VID; sort it in case pci.ids lists a vendor out of order.
    for index_vid, count in sorted(vendors):
        out.append("{0x%04X, %d, pci_vid_%04X },\n" % (index_vid, count, index_vid))
    out.append("}; /* We have %d VIDs */\n" % len(vendors))
    out.append(code_postfix)
    return ''.join(out), vendors


def main():
    """Download pci.ids and write the generated C source to OUTPUT_FILE.

    Exits via exit_msg() without writing OUTPUT_FILE if the database yields
    fewer vendors or entries than the MIN_* thresholds, or if a vendor has
    more entries than the uint16_t count field of pci_vid_index_t can hold.
    """
    req_headers = {'User-Agent': 'Wireshark make-pci-ids'}
    req = urllib.request.Request('https://github.com/pciutils/pciids/raw/master/pci.ids', headers=req_headers)
    # The with-block closes the HTTP response as soon as the body is read.
    with urllib.request.urlopen(req) as response:
        lines = response.read().decode('UTF-8', 'replace').splitlines()

    out_text, vendors = _generate_pci_ids_c(lines, CODE_PREFIX, CODE_POSTFIX)

    # Refresh (rather than extend) the module-level lists so that calling
    # main() more than once does not accumulate stale entries.
    id_list[:] = [v for v, _ in vendors]
    count_list[:] = [count for _, count in vendors]

    vendor_count = len(vendors)
    device_count = sum(count for _, count in vendors)

    if vendor_count < MIN_VENDOR_COUNT:
        exit_msg(f'Too few vendors. Wanted {MIN_VENDOR_COUNT}, got {vendor_count}.')

    if device_count < MIN_DEVICE_COUNT:
        exit_msg(f'Too few devices. Wanted {MIN_DEVICE_COUNT}, got {device_count}.')

    oversized = [f'{v:04X}' for v, count in vendors if count > _MAX_VENDOR_ENTRIES]
    if oversized:
        exit_msg(f'Too many entries for vendor(s) {", ".join(oversized)}; '
                 f'pci_vid_index_t.count is a uint16_t.')

    with open(OUTPUT_FILE, "w", encoding="utf-8") as pci_ids_f:
        pci_ids_f.write(out_text)

if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()
