# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import pytest

import awkward as ak

# Import ROOT through pytest so that the whole module is skipped (rather than
# failing at collection time) when ROOT is not installed.
ROOT = pytest.importorskip("ROOT")


# Short alias for ROOT's interpreter ``Declare`` entry point.
# NOTE(review): not referenced in the visible part of this module; presumably
# kept for declaring C++ snippets in tests -- confirm before removing.
compiler = ROOT.gInterpreter.Declare


def test_data_frame_filter():
    """Round-trip three columns through a filtered RDataFrame.

    Builds an RDataFrame from three 30-row Awkward Arrays (a record column
    ``x``, an integer column ``y`` and a list-of-floats column ``z``), keeps
    only the rows with an even ``y`` via ``Filter``, and converts the result
    back with ``ak.from_rdataframe(..., keep_order=True)``.  The test then
    checks that every column contains exactly the 15 surviving rows, in the
    original row order.

    Implicit multi-threading is switched on for the duration of the test and
    restored afterwards so that the process-wide ROOT setting does not leak
    into other tests.
    """
    # EnableImplicitMT() is a process-global switch; remember the previous
    # state so it can be restored in the ``finally`` clause below.
    imt_was_enabled = ROOT.ROOT.IsImplicitMTEnabled()
    ROOT.ROOT.EnableImplicitMT()
    try:
        # Row i (1-based) holds values of the form ``i.k`` so that a misplaced
        # or duplicated row is easy to spot in a failure message.
        array_x = ak.Array(
            [
                {"x": [1.1, 1.2, 1.3]},
                {"x": [2.1, 2.2]},
                {"x": [3.1]},
                {"x": [4.1, 4.2, 4.3, 4.4]},
                {"x": [5.1]},
                {"x": [6.1, 6.2, 6.3]},
                {"x": [7.1, 7.2]},
                {"x": [8.1]},
                {"x": [9.1, 9.2, 9.3, 9.4]},
                {"x": [10.1]},
                {"x": [11.1, 11.2, 11.3]},
                {"x": [12.1, 12.2]},
                {"x": [13.1]},
                {"x": [14.1, 14.2, 14.3, 14.4]},
                {"x": [15.1]},
                {"x": [16.1, 16.2, 16.3]},
                {"x": [17.1, 17.2]},
                {"x": [18.1]},
                {"x": [19.1, 19.2, 19.3, 19.4]},
                {"x": [20.1]},
                {"x": [21.1, 21.2, 21.3]},
                {"x": [22.1, 22.2]},
                {"x": [23.1]},
                {"x": [24.1, 24.2, 24.3, 24.4]},
                {"x": [25.1]},
                {"x": [26.1, 26.2, 26.3]},
                {"x": [27.1, 27.2]},
                {"x": [28.1]},
                {"x": [29.1, 29.2, 29.3, 29.4]},
                {"x": [30.1]},
            ]
        )
        # The row number itself: 1, 2, ..., 30.
        array_y = ak.Array(list(range(1, 31)))
        array_z = ak.Array(
            [
                [1.1],
                [2.1, 2.3, 2.4],
                [3.1],
                [4.1, 4.2, 4.3],
                [5.1],
                [6.1],
                [7.1, 7.3, 7.4],
                [8.1],
                [9.1, 9.2, 9.3],
                [10.1],
                [11.1],
                [12.1, 12.3, 12.4],
                [13.1],
                [14.1, 14.2, 14.3],
                [15.1],
                [16.1],
                [17.1, 17.3, 17.4],
                [18.1],
                [19.1, 19.2, 19.3],
                [20.1],
                [21.1],
                [22.1, 22.3, 22.4],
                [23.1],
                [24.1, 24.2, 24.3],
                [25.1],
                [26.1],
                [27.1, 27.3, 27.4],
                [28.1],
                [29.1, 29.2, 29.3],
                [30.1],
            ]
        )

        df = ak.to_rdataframe({"x": array_x, "y": array_y, "z": array_z})

        # Column types as seen from the C++ side of the data frame.
        assert str(df.GetColumnType("x")).startswith("awkward::Record_")
        assert df.GetColumnType("y") == "int64_t"
        assert df.GetColumnType("z") == "ROOT::VecOps::RVec<double>"

        # Keep only the even-numbered rows (15 of the 30).
        df = df.Filter("y % 2 == 0")

        # keep_order=True is requested because the assertions below compare
        # each column against the surviving rows in their original order.
        out = ak.from_rdataframe(
            df,
            columns=(
                "x",
                "y",
                "z",
            ),
            keep_order=True,
        )

        assert out["x"].tolist() == [
            {"x": [2.1, 2.2]},
            {"x": [4.1, 4.2, 4.3, 4.4]},
            {"x": [6.1, 6.2, 6.3]},
            {"x": [8.1]},
            {"x": [10.1]},
            {"x": [12.1, 12.2]},
            {"x": [14.1, 14.2, 14.3, 14.4]},
            {"x": [16.1, 16.2, 16.3]},
            {"x": [18.1]},
            {"x": [20.1]},
            {"x": [22.1, 22.2]},
            {"x": [24.1, 24.2, 24.3, 24.4]},
            {"x": [26.1, 26.2, 26.3]},
            {"x": [28.1]},
            {"x": [30.1]},
        ]
        assert out["y"].tolist() == list(range(2, 31, 2))
        assert out["z"].tolist() == [
            [2.1, 2.3, 2.4],
            [4.1, 4.2, 4.3],
            [6.1],
            [8.1],
            [10.1],
            [12.1, 12.3, 12.4],
            [14.1, 14.2, 14.3],
            [16.1],
            [18.1],
            [20.1],
            [22.1, 22.3, 22.4],
            [24.1, 24.2, 24.3],
            [26.1],
            [28.1],
            [30.1],
        ]

        assert len(out) == 15
    finally:
        # Only undo what this test changed: leave implicit MT on if it was
        # already enabled before the test started.
        if not imt_was_enabled:
            ROOT.ROOT.DisableImplicitMT()
