File: rust_enum.py

package info (click to toggle)
lsprotocol 2025.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,408 kB
  • sloc: python: 7,567; cs: 1,225; sh: 15; makefile: 4
file content (84 lines) | stat: -rw-r--r-- 2,582 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import List, Union

import generator.model as model

from .rust_commons import TypeData, generate_extras
from .rust_lang_utils import indent_lines, lines_to_doc_comments, to_upper_camel_case


def _get_enum_docs(enum: Union[model.Enum, model.EnumItem]) -> List[str]:
    """Return the Rust doc-comment lines for an enum or one of its items.

    The ``documentation`` text (when present and non-empty) is split into
    lines without their terminators and passed to ``lines_to_doc_comments``.
    A missing or empty documentation string is passed as an empty list.
    """
    text = enum.documentation
    if not text:
        return lines_to_doc_comments([])
    return lines_to_doc_comments(text.splitlines(keepends=False))


def generate_serde(enum: model.Enum) -> List[str]:
    """Build hand-written serde impls for an integer-valued enum.

    The generated ``Serialize`` impl writes each variant with
    ``serializer.serialize_i32(<value>)``. The generated ``Deserialize`` impl
    reads an ``i32``, maps it back to the matching variant, and returns
    ``serde::de::Error::custom("Unexpected value")`` for anything else.

    Args:
        enum: Enum model; ``enum.name`` names the Rust type and each entry of
            ``enum.values`` supplies a variant ``name`` and its ``value``.

    Returns:
        Rust source lines: the ``Serialize`` impl followed by the
        ``Deserialize`` impl.
    """
    type_name = enum.name
    ser_arms = []
    de_arms = []
    for entry in enum.values:
        variant = f"{type_name}::{to_upper_camel_case(entry.name)}"
        ser_arms.append(f"{variant} => serializer.serialize_i32({entry.value}),")
        de_arms.append(f"{entry.value} => Ok({variant}),")

    # Closes, in order: the `match`, the `fn`, and the `impl`.
    closers = ["}", "}", "}"]

    ser_impl = (
        [
            f"impl Serialize for {type_name} {{",
            "fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer,{",
            "match self {",
        ]
        + ser_arms
        + closers
    )
    de_impl = (
        [
            f"impl<'de> Deserialize<'de> for {type_name} {{",
            # Signature, where-clause and the fn body's opening brace share one line.
            f"fn deserialize<D>(deserializer: D) -> Result<{type_name}, D::Error> where D: serde::Deserializer<'de>,{{",
            "let value = i32::deserialize(deserializer)?;",
            "match value {",
        ]
        + de_arms
        + ['_ => Err(serde::de::Error::custom("Unexpected value"))']
        + closers
    )
    return ser_impl + de_impl


def generate_enum(enum: model.Enum, types: TypeData) -> None:
    """Generate the Rust definition of one enum and register it with ``types``.

    If every value of the enum is a Python ``int``, the Rust enum gets explicit
    discriminants and hand-written serde impls from ``generate_serde``.
    Otherwise it derives ``Serialize``/``Deserialize``, and each variant carries
    ``#[serde(rename = "<value>")]``.

    Args:
        enum: Enum model to translate.
        types: Collector that receives the result via
            ``types.add_type_info(enum, enum.name, lines)``.
    """
    integer_valued = all(isinstance(entry.value, int) for entry in enum.values)

    lines = _get_enum_docs(enum) + generate_extras(enum)
    # Integer enums skip the serde derives; their impls are appended below.
    lines.append(
        "#[derive(PartialEq, Debug, Eq, Clone)]"
        if integer_valued
        else "#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)]"
    )
    lines.append(f"pub enum {enum.name} {{")

    for entry in enum.values:
        variant = to_upper_camel_case(entry.name)
        if integer_valued:
            body = [f"{variant} = {entry.value},"]
        else:
            body = [f'#[serde(rename = "{entry.value}")]', f"{variant},"]
        # Each variant is followed by an empty line inside the enum body.
        lines.extend(
            indent_lines(_get_enum_docs(entry) + generate_extras(entry) + body + [""])
        )

    lines.append("}")

    if integer_valued:
        lines.extend(generate_serde(enum))

    types.add_type_info(enum, enum.name, lines)


def generate_enums(enums: List[model.Enum], types: TypeData) -> None:
    """Generate Rust definitions for each enum in ``enums``, in order.

    Every enum is handed to ``generate_enum``, which registers its output
    with ``types``.
    """
    for current in enums:
        generate_enum(current, types)