# SPDX-FileCopyrightText: Copyright (c) 2021-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# Numeric event-type codes. The four *_RANGE codes are compared against the
# 'eventType' column of the NVTX_EVENTS table by the SQL statements below to
# select range events. EVENT_TYPE_NVTX_DOMAIN_CREATE is not referenced in the
# visible part of this file.
EVENT_TYPE_NVTX_DOMAIN_CREATE = 75
EVENT_TYPE_NVTX_PUSHPOP_RANGE = 59
EVENT_TYPE_NVTX_STARTEND_RANGE = 60
EVENT_TYPE_NVTXT_PUSHPOP_RANGE = 70
EVENT_TYPE_NVTXT_STARTEND_RANGE = 71

# Ordered list of SQL statements that build two temporary objects used to
# match NVTX ranges to CUDA runtime calls:
#   - temp.NVTX_EVENTS_MINMAXTS: a single row with the smallest and largest
#     timestamp found in the NVTX range events.
#   - temp.NVTX_EVENTS_RIDX: an rtree virtual table over the NVTX ranges.
# The statements are meant to be executed one after another, in list order
# (later statements read the objects created by earlier ones). The
# 'rtree_scale' SQL function they call is not defined in this file; it is
# presumably registered on the connection by the caller -- TODO confirm.
CREATE_RUNTIME_RIDX_STATEMENTS = [
# Drop any previous min/max table so the list can be re-run.
f"""
DROP TABLE IF EXISTS temp.NVTX_EVENTS_MINMAXTS
""",

# Compute the overall min and max of 'start'/'end' across the four range
# event types. These bounds are the scaling range passed to rtree_scale.
f"""
CREATE TEMP TABLE NVTX_EVENTS_MINMAXTS
AS SELECT
    min(min(start), min(end)) AS min,
    max(max(start), max(end)) AS max
FROM main.NVTX_EVENTS
WHERE
       eventType == {EVENT_TYPE_NVTX_PUSHPOP_RANGE}
    OR eventType == {EVENT_TYPE_NVTX_STARTEND_RANGE}
    OR eventType == {EVENT_TYPE_NVTXT_PUSHPOP_RANGE}
    OR eventType == {EVENT_TYPE_NVTXT_STARTEND_RANGE}
""",

# Drop any previous range index so the list can be re-run.
f"""
DROP TABLE IF EXISTS temp.NVTX_EVENTS_RIDX
""",

# One-dimensional rtree keyed by rangeId with the interval [startTS, endTS].
# The '+'-prefixed columns are rtree auxiliary columns (see the SQLite R*Tree
# documentation): they carry the unscaled timestamps, thread id and range
# name alongside each indexed interval.
f"""
CREATE VIRTUAL TABLE temp.NVTX_EVENTS_RIDX
USING rtree (
    rangeId,
    startTS,
    endTS,
    +startNS  INTEGER,
    +endNS    INTEGER,
    +tid      INTEGER,
    +name     TEXT,
)
""",

# Populate the index from the NVTX range events:
#   - startTS/endTS are the rtree_scale()d versions of start/end;
#     startNS/endNS keep the original values.
#   - A NULL 'end' is replaced by the global max timestamp.
#   - The name comes from StringIds (via textId) when available, otherwise
#     from the event's inline 'text' column.
#   - Only rows whose endGlobalTid is NULL are indexed.
f"""
INSERT INTO temp.NVTX_EVENTS_RIDX
    SELECT
        e.rowid AS rangeId,
        rtree_scale(e.start,
            (SELECT min FROM temp.NVTX_EVENTS_MINMAXTS),
            (SELECT max FROM temp.NVTX_EVENTS_MINMAXTS)) AS startTS,
        rtree_scale(ifnull(e.end, (SELECT max FROM temp.NVTX_EVENTS_MINMAXTS)),
            (SELECT min FROM temp.NVTX_EVENTS_MINMAXTS),
            (SELECT max FROM temp.NVTX_EVENTS_MINMAXTS)) AS endTS,
        e.start AS startNS,
        ifnull(e.end, (SELECT max FROM temp.NVTX_EVENTS_MINMAXTS)) AS endNS,
        e.globalTid AS tid,
        COALESCE(sid.value, e.text) AS name
    FROM
        main.NVTX_EVENTS AS e
    LEFT JOIN
        StringIds AS sid
        ON e.textId == sid.id
    WHERE
          (e.eventType == {EVENT_TYPE_NVTX_PUSHPOP_RANGE}
        OR e.eventType == {EVENT_TYPE_NVTX_STARTEND_RANGE}
        OR e.eventType == {EVENT_TYPE_NVTXT_PUSHPOP_RANGE}
        OR e.eventType == {EVENT_TYPE_NVTXT_STARTEND_RANGE})
        AND e.endGlobalTid IS NULL
""",

    ]

# Query template (filled with str.format; '{KERNEL_TABLE}' is the only
# placeholder) that yields every kernel row plus a 'name' column prefixed
# with an NVTX range name when one is found.
#
# 'projection' associates each kernel with an NVTX range as follows:
#   1. The kernel is joined to its CUDA runtime call on correlationId, with
#      the kernel's globalPid required to equal the runtime row's globalTid
#      with its low 24 bits cleared (mask 0xFFFFFFFFFF000000).
#   2. The runtime call is joined to the ranges in temp.NVTX_EVENTS_RIDX that
#      share its globalTid and fully enclose [r.start, r.end]. The scaled
#      startTS/endTS comparison goes through the rtree; the startNS/endNS
#      comparison repeats the test on the unscaled values.
#   3. Rows are grouped per kernel with max(rt.startNS); in SQLite the bare
#      'rt.name' column then takes its value from the row holding that
#      maximum, i.e. the enclosing range that started last (presumably the
#      innermost range -- TODO confirm).
# Both joins are LEFT JOINs, so kernels without a match keep a NULL nvtxName
# and the final COALESCE falls back to the plain kernelName.
#
# Requires temp.NVTX_EVENTS_RIDX and temp.NVTX_EVENTS_MINMAXTS to exist.
QUERY_NVTX_KERNEL_NAME = """
WITH
    kernel AS (
        {KERNEL_TABLE}
    ),
    projection AS (
        SELECT
            kernel.rowid,
            rt.name AS nvtxName,
            max(rt.startNS) AS maxStart
        FROM
            kernel
        LEFT JOIN
            main.CUPTI_ACTIVITY_KIND_RUNTIME AS r
            ON      kernel.correlationId == r.correlationId
                AND kernel.globalPid == (r.globalTid & 0xFFFFFFFFFF000000)
        LEFT JOIN
            temp.NVTX_EVENTS_RIDX AS rt
            ON      rt.startTS <= rtree_scale(r.start,
                        (SELECT min FROM temp.NVTX_EVENTS_MINMAXTS),
                        (SELECT max FROM temp.NVTX_EVENTS_MINMAXTS))
                AND rt.endTS >= rtree_scale(r.end,
                        (SELECT min FROM temp.NVTX_EVENTS_MINMAXTS),
                        (SELECT max FROM temp.NVTX_EVENTS_MINMAXTS))
                AND rt.startNS <= r.start
                AND rt.endNS >= r.end
                AND rt.tid == r.globalTid
        GROUP BY kernel.rowid
    )
SELECT
    kernel.*,
    COALESCE(nvtxName || '/' || kernelName, kernelName) AS name
FROM
    kernel
JOIN
    projection
    ON      kernel.rowid == projection.rowid
"""

# Query template (filled with str.format; '{KERNEL_TABLE}' is the only
# placeholder) that yields every column of the kernel sub-query plus a 'name'
# column that is simply a copy of 'kernelName', with no NVTX prefix.
QUERY_KERNEL_NAME = """
WITH
    kernel AS (
        {KERNEL_TABLE}
    )
SELECT
    *,
    kernelName as name
FROM
    kernel
"""

# Query template (filled with str.format; '{NAME_COL_NAME}' is the only
# placeholder) used as the '{KERNEL_TABLE}' sub-query of the templates above.
# It returns the rowid and all columns of CUPTI_ACTIVITY_KIND_KERNEL plus
# 'kernelName', the StringIds text for the chosen name column. When that
# column is NULL for a row, the row's demangledName id is looked up instead.
QUERY_KERNEL = """
SELECT
    kernel.rowid,
    kernel.*,
    sid.value AS kernelName
FROM
    CUPTI_ACTIVITY_KIND_KERNEL as kernel
LEFT JOIN
    StringIds AS sid
    ON sid.id == coalesce(kernel.{NAME_COL_NAME}, kernel.demangledName)
"""

# Create a temporary view named 'CUPTI_ACTIVITY_KIND_KERNEL_NAMED' by adding
# a new column 'name' to the 'CUPTI_ACTIVITY_KIND_KERNEL' table. This column
# gives the kernel string name, which is the demangled name by default or the
# base or mangled name when requested, optionally prefixed by the
# corresponding NVTX range name based on the provided options.
def create_kernel_view(instance):
    """Build the temp view CUPTI_ACTIVITY_KIND_KERNEL_NAMED for `instance`.

    Args:
        instance: Object exposing `_parsed_args`, `table_exists()`,
            `table_col_exists()` and `_execute_statement()`. The optional
            `base`, `mangled` and `nvtx_name` attributes of `_parsed_args`
            select the name column and whether the NVTX range name is
            prepended; each one is treated as False when absent.

    Returns:
        None on success. Otherwise an error message: either the non-None
        value returned by a failing `_execute_statement()` call, or a
        template containing a literal "{DBFILE}" placeholder when a table
        required by the NVTX option is missing (the placeholder is
        presumably filled in by the caller).
    """
    parsed_args = instance._parsed_args
    use_base = getattr(parsed_args, 'base', False)
    use_mangled = getattr(parsed_args, 'mangled', False)
    use_nvtx_name = getattr(parsed_args, 'nvtx_name', False)

    # 'base' takes precedence over 'mangled'. The mangled name is only used
    # when the table actually has that column; otherwise fall back to the
    # demangled name.
    if use_base:
        name_col_name = 'shortName'
    elif use_mangled and instance.table_col_exists('CUPTI_ACTIVITY_KIND_KERNEL', 'mangledName'):
        name_col_name = 'mangledName'
    else:
        name_col_name = 'demangledName'

    kernel_query = QUERY_KERNEL.format(NAME_COL_NAME=name_col_name)

    if use_nvtx_name:
        # The NVTX variant needs both the NVTX events and the CUDA runtime
        # table to correlate ranges with kernel launches.
        if not instance.table_exists('NVTX_EVENTS'):
            return "{DBFILE} does not contain NV Tools Extension (NVTX) data."
        if not instance.table_exists('CUPTI_ACTIVITY_KIND_RUNTIME'):
            return "{DBFILE} does not contain CUDA API data."

        # Build the temporary range index; stop at the first failure.
        for stmt in CREATE_RUNTIME_RIDX_STATEMENTS:
            errmsg = instance._execute_statement(stmt)
            if errmsg is not None:
                return errmsg

        kernel_name_query = QUERY_NVTX_KERNEL_NAME.format(KERNEL_TABLE=kernel_query)
    else:
        kernel_name_query = QUERY_KERNEL_NAME.format(KERNEL_TABLE=kernel_query)

    # Drop any existing view first so the function can be called repeatedly.
    errmsg = instance._execute_statement(
        'DROP VIEW IF EXISTS temp.CUPTI_ACTIVITY_KIND_KERNEL_NAMED'
    )
    if errmsg is not None:
        return errmsg

    # The result of the final statement is the function's result: None on
    # success, the error message otherwise.
    return instance._execute_statement(
        'CREATE TEMP VIEW CUPTI_ACTIVITY_KIND_KERNEL_NAMED AS {QUERY}'.format(
            QUERY=kernel_name_query
        )
    )
