File: __main__.py

package info (click to toggle)
python-datamodel-code-generator 0.26.4-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 712 kB
  • sloc: python: 9,525; makefile: 14
file content (568 lines) | stat: -rw-r--r-- 21,333 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
#! /usr/bin/env python

"""
Main function.
"""

from __future__ import annotations

import json
import signal
import sys
import warnings
from collections import defaultdict
from enum import IntEnum
from io import TextIOBase
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    DefaultDict,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)
from urllib.parse import ParseResult, urlparse

import argcomplete
import black
from pydantic import BaseModel

from datamodel_code_generator.model.pydantic_v2 import UnionMode

if TYPE_CHECKING:
    from argparse import Namespace

    from typing_extensions import Self

from datamodel_code_generator import (
    DataModelType,
    Error,
    InputFileType,
    InvalidClassNameError,
    OpenAPIScope,
    enable_debug_message,
    generate,
)
from datamodel_code_generator.arguments import DEFAULT_ENCODING, arg_parser, namespace
from datamodel_code_generator.format import (
    DatetimeClassType,
    PythonVersion,
    black_find_project_root,
    is_supported_in_black,
)
from datamodel_code_generator.parser import LiteralType
from datamodel_code_generator.reference import is_url
from datamodel_code_generator.types import StrictTypes
from datamodel_code_generator.util import (
    PYDANTIC_V2,
    ConfigDict,
    Model,
    field_validator,
    load_toml,
    model_validator,
)


class Exit(IntEnum):
    """Process exit statuses produced by the command-line entry point.

    Being an ``IntEnum``, every member compares equal to its integer value and
    can be handed directly to ``sys.exit``.
    """

    # Generation (or ``--help``/``--version`` style early exits) succeeded.
    OK = 0
    # Any reported failure: bad options, unreadable inputs, generator errors.
    ERROR = 1
    # Declared for interrupted runs; not referenced elsewhere in this module.
    KeyboardInterrupt = 2


def sig_int_handler(_: int, __: Any) -> None:  # pragma: no cover
    """Terminate the process when SIGINT (Ctrl-C) is received.

    Registered below as the SIGINT handler, replacing Python's default handler
    (which raises ``KeyboardInterrupt``) with a ``SystemExit``.

    Args:
        _: The signal number (unused).
        __: The interrupted stack frame (unused).
    """
    # Use ``sys.exit`` rather than the ``exit`` builtin: the builtin is injected
    # by the ``site`` module for interactive use and is missing when Python runs
    # with ``-S`` or in some embedded/frozen interpreters.
    # NOTE(review): the reported status is Exit.OK (0), not
    # Exit.KeyboardInterrupt; confirm that an interrupted run should look
    # successful to calling scripts.
    sys.exit(Exit.OK)


signal.signal(signal.SIGINT, sig_int_handler)


class Config(BaseModel):
    """Effective configuration for a code-generation run.

    ``main`` builds it from the ``[tool.datamodel-codegen]`` table of
    ``pyproject.toml`` and then overlays explicitly passed command-line
    arguments with :meth:`merge_args`. The validators normalize raw option
    values (paths, URLs, open files, comma-separated lists, header pairs) and
    reject incompatible option combinations by raising :class:`Error`.
    """

    if PYDANTIC_V2:
        model_config = ConfigDict(arbitrary_types_allowed=True)

        # Mapping-style accessors: the ``mode='after'`` validators below read
        # ``values.get(...)`` / ``values[...]``. Under pydantic v2 ``values`` is
        # presumably the model instance itself (confirm in util.model_validator).
        def get(self, item: str) -> Any:
            return getattr(self, item)

        def __getitem__(self, item: str) -> Any:
            return self.get(item)

        if TYPE_CHECKING:

            @classmethod
            def get_fields(cls) -> Dict[str, Any]: ...

        else:

            # v1-style ``parse_obj`` shim so callers can use one spelling on
            # both pydantic major versions.
            @classmethod
            def parse_obj(cls: type[Model], obj: Any) -> Model:
                return cls.model_validate(obj)

            @classmethod
            def get_fields(cls) -> Dict[str, Any]:
                return cls.model_fields

    else:

        class Config:
            # validate_assignment = True
            # Pydantic 1.5.1 doesn't support validate_assignment correctly
            arbitrary_types_allowed = (TextIOBase,)

        if not TYPE_CHECKING:

            @classmethod
            def get_fields(cls) -> Dict[str, Any]:
                return cls.__fields__

    @field_validator(
        'aliases', 'extra_template_data', 'custom_formatters_kwargs', mode='before'
    )
    def validate_file(cls, value: Any) -> Optional[TextIOBase]:
        """Open a path option as a text stream.

        ``None`` and already-open ``TextIOBase`` streams pass through. Any
        other value is treated as a path (``~`` expanded) and opened for
        reading; the handle is not closed here, so the consumer must close it.
        """
        if value is None or isinstance(value, TextIOBase):
            return value
        return cast(TextIOBase, Path(value).expanduser().resolve().open('rt'))

    @field_validator(
        'input',
        'output',
        'custom_template_dir',
        'custom_file_header_path',
        mode='before',
    )
    def validate_path(cls, value: Any) -> Optional[Path]:
        """Convert a path option to an absolute ``Path`` with ``~`` expanded."""
        if value is None or isinstance(value, Path):
            return value  # pragma: no cover
        return Path(value).expanduser().resolve()

    @field_validator('url', mode='before')
    def validate_url(cls, value: Any) -> Optional[ParseResult]:
        """Parse ``--url`` into a ``ParseResult``.

        Raises:
            Error: if the value is neither ``None`` nor a string accepted by
                ``is_url``.
        """
        if isinstance(value, str) and is_url(value):  # pragma: no cover
            return urlparse(value)
        elif value is None:  # pragma: no cover
            return None
        raise Error(
            f"This protocol doesn't support only http/https. --input={value}"
        )  # pragma: no cover

    @model_validator(mode='after')
    def validate_use_generic_container_types(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Reject ``--use-generic-container-types`` when targeting Python 3.6."""
        if values.get('use_generic_container_types'):
            target_python_version: PythonVersion = values['target_python_version']
            if target_python_version == target_python_version.PY_36:
                raise Error(
                    f'`--use-generic-container-types` can not be used with `--target-python-version` {target_python_version.PY_36.value}.\n'
                    ' The version will be not supported in a future version'
                )
        return values

    @model_validator(mode='after')
    def validate_original_field_name_delimiter(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Require ``--snake-case-field`` whenever a field-name delimiter is set."""
        if values.get('original_field_name_delimiter') is not None:
            if not values.get('snake_case_field'):
                raise Error(
                    '`--original-field-name-delimiter` can not be used without `--snake-case-field`.'
                )
        return values

    @model_validator(mode='after')
    def validate_custom_file_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Forbid giving both an inline custom header and a header file."""
        if values.get('custom_file_header') and values.get('custom_file_header_path'):
            raise Error(
                '`--custom_file_header_path` can not be used with `--custom_file_header`.'
            )  # pragma: no cover
        return values

    @model_validator(mode='after')
    def validate_keyword_only(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject ``--keyword-only`` dataclasses for targets without ``kw_only``."""
        output_model_type: DataModelType = values.get('output_model_type')
        python_target: PythonVersion = values.get('target_python_version')
        if (
            values.get('keyword_only')
            and output_model_type == DataModelType.DataclassesDataclass
            and not python_target.has_kw_only_dataclass
        ):
            raise Error(
                f'`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher.'
            )
        return values

    @model_validator(mode='after')
    def validate_output_datetime_class(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Allow only the plain ``datetime`` class for dataclass output."""
        datetime_class_type: Optional[DatetimeClassType] = values.get(
            'output_datetime_class'
        )
        if (
            datetime_class_type
            and datetime_class_type is not DatetimeClassType.Datetime
            and values.get('output_model_type') == DataModelType.DataclassesDataclass
        ):
            raise Error(
                '`--output-datetime-class` only allows "datetime" for '
                f'`--output-model-type` {DataModelType.DataclassesDataclass.value}'
            )
        return values

    # Pydantic 1.5.1 doesn't support each_item=True correctly
    @field_validator('http_headers', mode='before')
    def validate_http_headers(cls, value: Any) -> Optional[List[Tuple[str, str]]]:
        """Split ``"Name: value"`` header strings into ``(name, value)`` pairs.

        Non-string items are passed through unchanged.

        Raises:
            Error: for a string item that contains no ``:``.
        """

        def validate_each_item(each_item: Any) -> Tuple[str, str]:
            if isinstance(each_item, str):  # pragma: no cover
                try:
                    field_name, field_value = each_item.split(':', maxsplit=1)  # type: str, str
                    return field_name, field_value.lstrip()
                except ValueError:
                    raise Error(f'Invalid http header: {each_item!r}')
            return each_item  # pragma: no cover

        if isinstance(value, list):
            return [validate_each_item(each_item) for each_item in value]
        return value  # pragma: no cover

    @field_validator('http_query_parameters', mode='before')
    def validate_http_query_parameters(
        cls, value: Any
    ) -> Optional[List[Tuple[str, str]]]:
        """Split ``"name=value"`` query-parameter strings into ``(name, value)`` pairs.

        Non-string items are passed through unchanged.

        Raises:
            Error: for a string item that contains no ``=``.
        """

        def validate_each_item(each_item: Any) -> Tuple[str, str]:
            if isinstance(each_item, str):  # pragma: no cover
                try:
                    field_name, field_value = each_item.split('=', maxsplit=1)  # type: str, str
                    return field_name, field_value.lstrip()
                except ValueError:
                    raise Error(f'Invalid http query parameter: {each_item!r}')
            return each_item  # pragma: no cover

        if isinstance(value, list):
            return [validate_each_item(each_item) for each_item in value]
        return value  # pragma: no cover

    @model_validator(mode='before')
    def validate_additional_imports(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Split a comma-separated ``additional_imports`` string into a list.

        Surrounding whitespace is stripped and empty items are dropped, so
        ``"a.b, c.d,"`` becomes ``['a.b', 'c.d']``. Non-string values (e.g. a
        list from ``pyproject.toml``) are left for normal field validation
        instead of failing on ``.split``.
        """
        additional_imports = values.get('additional_imports')
        if isinstance(additional_imports, str):
            values['additional_imports'] = [
                item.strip() for item in additional_imports.split(',') if item.strip()
            ]
        return values

    @model_validator(mode='before')
    def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Split a comma-separated ``custom_formatters`` string into a list.

        Surrounding whitespace is stripped and empty items are dropped.
        Non-string values (e.g. a list from ``pyproject.toml``) are left for
        normal field validation instead of failing on ``.split``.
        """
        custom_formatters = values.get('custom_formatters')
        if isinstance(custom_formatters, str):
            values['custom_formatters'] = [
                item.strip() for item in custom_formatters.split(',') if item.strip()
            ]
        return values

    if PYDANTIC_V2:

        @model_validator(mode='after')  # type: ignore
        def validate_root(self: Self) -> Self:
            """``--use-annotated`` implies ``--field-constraints``."""
            if self.use_annotated:
                self.field_constraints = True
            return self

    else:

        @model_validator(mode='after')
        def validate_root(cls, values: Any) -> Any:
            """``--use-annotated`` implies ``--field-constraints``."""
            if values.get('use_annotated'):
                values['field_constraints'] = True
            return values

    input: Optional[Union[Path, str]] = None
    input_file_type: InputFileType = InputFileType.Auto
    output_model_type: DataModelType = DataModelType.PydanticBaseModel
    output: Optional[Path] = None
    debug: bool = False
    disable_warnings: bool = False
    target_python_version: PythonVersion = PythonVersion.PY_38
    base_class: str = ''
    # ``None`` means "no extra imports"; a list is produced from the
    # comma-separated CLI value by ``validate_additional_imports``.
    additional_imports: Optional[List[str]] = None
    custom_template_dir: Optional[Path] = None
    extra_template_data: Optional[TextIOBase] = None
    validation: bool = False
    field_constraints: bool = False
    snake_case_field: bool = False
    strip_default_none: bool = False
    aliases: Optional[TextIOBase] = None
    disable_timestamp: bool = False
    enable_version_header: bool = False
    allow_population_by_field_name: bool = False
    allow_extra_fields: bool = False
    use_default: bool = False
    force_optional: bool = False
    class_name: Optional[str] = None
    use_standard_collections: bool = False
    use_schema_description: bool = False
    use_field_description: bool = False
    use_default_kwarg: bool = False
    reuse_model: bool = False
    encoding: str = DEFAULT_ENCODING
    enum_field_as_literal: Optional[LiteralType] = None
    use_one_literal_as_default: bool = False
    set_default_enum_member: bool = False
    use_subclass_enum: bool = False
    strict_nullable: bool = False
    use_generic_container_types: bool = False
    use_union_operator: bool = False
    enable_faux_immutability: bool = False
    url: Optional[ParseResult] = None
    disable_appending_item_suffix: bool = False
    strict_types: List[StrictTypes] = []
    empty_enum_field_name: Optional[str] = None
    field_extra_keys: Optional[Set[str]] = None
    field_include_all_keys: bool = False
    field_extra_keys_without_x_prefix: Optional[Set[str]] = None
    openapi_scopes: Optional[List[OpenAPIScope]] = [OpenAPIScope.Schemas]
    wrap_string_literal: Optional[bool] = None
    use_title_as_name: bool = False
    use_operation_id_as_name: bool = False
    use_unique_items_as_set: bool = False
    http_headers: Optional[Sequence[Tuple[str, str]]] = None
    http_ignore_tls: bool = False
    use_annotated: bool = False
    use_non_positive_negative_number_constrained_types: bool = False
    original_field_name_delimiter: Optional[str] = None
    use_double_quotes: bool = False
    collapse_root_models: bool = False
    special_field_name_prefix: Optional[str] = None
    remove_special_field_name_prefix: bool = False
    capitalise_enum_members: bool = False
    keep_model_order: bool = False
    custom_file_header: Optional[str] = None
    custom_file_header_path: Optional[Path] = None
    custom_formatters: Optional[List[str]] = None
    custom_formatters_kwargs: Optional[TextIOBase] = None
    use_pendulum: bool = False
    http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None
    treat_dot_as_module: bool = False
    use_exact_imports: bool = False
    union_mode: Optional[UnionMode] = None
    output_datetime_class: Optional[DatetimeClassType] = None
    keyword_only: bool = False
    no_alias: bool = False

    def merge_args(self, args: Namespace) -> None:
        """Overlay explicitly given command-line arguments onto this config.

        Only attributes of ``args`` that are not ``None`` are considered set.
        That subset is validated through a fresh ``Config``, so the same
        normalization runs, and only those fields are copied back onto
        ``self``.

        NOTE(review): the cross-field validators run against that fresh model,
        so fields not given on the command line are seen at their defaults
        rather than at this config's (pyproject) values. For example,
        ``snake_case_field`` set only in pyproject.toml plus
        ``--original-field-name-delimiter`` on the CLI is rejected.
        """
        set_args = {
            f: getattr(args, f)
            for f in self.get_fields()
            if getattr(args, f) is not None
        }

        # msgspec output is always generated with Annotated fields.
        if set_args.get('output_model_type') == DataModelType.MsgspecStruct.value:
            set_args['use_annotated'] = True

        if set_args.get('use_annotated'):
            set_args['field_constraints'] = True

        parsed_args = Config.parse_obj(set_args)
        for field_name in set_args:
            setattr(self, field_name, getattr(parsed_args, field_name))


def _load_pyproject_settings() -> Dict[str, Any]:
    """Return the ``[tool.datamodel-codegen]`` table of the project's pyproject.toml.

    The project root is found with ``black_find_project_root``, starting from
    the current working directory. Keys are converted from kebab-case to
    snake_case so they match ``Config`` field names. Returns an empty dict
    when the root has no ``pyproject.toml``.
    """
    root = black_find_project_root((Path().resolve(),))
    pyproject_toml_path = root / 'pyproject.toml'
    if not pyproject_toml_path.is_file():
        return {}
    settings = (
        load_toml(pyproject_toml_path).get('tool', {}).get('datamodel-codegen', {})
    )
    return {key.replace('-', '_'): value for key, value in settings.items()}


def _is_str_mapping(value: Any) -> bool:
    """Return True if ``value`` is a dict whose keys and values are all strings."""
    return isinstance(value, dict) and all(
        isinstance(k, str) and isinstance(v, str) for k, v in value.items()
    )


def main(args: Optional[Sequence[str]] = None) -> Exit:
    """Command-line entry point.

    Parses ``args`` (``sys.argv[1:]`` when omitted) and merges them over the
    pyproject.toml settings. It then loads the optional JSON side files
    (extra template data, alias mapping, custom formatter kwargs) and runs
    ``generate``.

    Args:
        args: Command-line arguments, excluding the program name.

    Returns:
        ``Exit.OK`` on success, ``Exit.ERROR`` on any reported failure. Error
        messages are written to stderr.

    Raises:
        SystemExit: with status 0 after printing the version for ``--version``.
    """

    # add cli completion support
    argcomplete.autocomplete(arg_parser)

    if args is None:  # pragma: no cover
        args = sys.argv[1:]

    arg_parser.parse_args(args, namespace=namespace)

    if namespace.version:
        from datamodel_code_generator.version import version

        print(version)
        # ``sys.exit`` rather than the site-provided ``exit`` builtin, which is
        # absent under ``python -S`` and in some embedded interpreters.
        sys.exit(0)

    pyproject_toml = _load_pyproject_settings()

    try:
        config = Config.parse_obj(pyproject_toml)
        config.merge_args(namespace)
    except Error as e:
        print(e.message, file=sys.stderr)
        return Exit.ERROR

    if not config.input and not config.url and sys.stdin.isatty():
        print(
            'Not Found Input: require `stdin` or arguments `--input` or `--url`',
            file=sys.stderr,
        )
        arg_parser.print_help()
        return Exit.ERROR

    if not is_supported_in_black(config.target_python_version):  # pragma: no cover
        print(
            f"Installed black doesn't support Python version {config.target_python_version.value}.\n"  # type: ignore
            f'You have to install a newer black.\n'
            f'Installed black version: {black.__version__}',
            file=sys.stderr,
        )
        return Exit.ERROR

    if config.debug:  # pragma: no cover
        enable_debug_message()

    if config.disable_warnings:
        warnings.simplefilter('ignore')

    extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]]
    if config.extra_template_data is None:
        extra_template_data = None
    else:
        with config.extra_template_data as data:
            try:
                # Every JSON object becomes a defaultdict(dict) so templates can
                # look up missing keys without raising.
                extra_template_data = json.load(
                    data, object_hook=lambda d: defaultdict(dict, **d)
                )
            except json.JSONDecodeError as e:
                print(f'Unable to load extra template data: {e}', file=sys.stderr)
                return Exit.ERROR

    if config.aliases is None:
        aliases = None
    else:
        with config.aliases as data:
            try:
                aliases = json.load(data)
            except json.JSONDecodeError as e:
                print(f'Unable to load alias mapping: {e}', file=sys.stderr)
                return Exit.ERROR
        if not _is_str_mapping(aliases):
            print(
                'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
                file=sys.stderr,
            )
            return Exit.ERROR

    if config.custom_formatters_kwargs is None:
        custom_formatters_kwargs = None
    else:
        with config.custom_formatters_kwargs as data:
            try:
                custom_formatters_kwargs = json.load(data)
            except json.JSONDecodeError as e:  # pragma: no cover
                print(
                    f'Unable to load custom_formatters_kwargs mapping: {e}',
                    file=sys.stderr,
                )
                return Exit.ERROR
        if not _is_str_mapping(custom_formatters_kwargs):  # pragma: no cover
            print(
                'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
                file=sys.stderr,
            )
            return Exit.ERROR

    try:
        generate(
            input_=config.url or config.input or sys.stdin.read(),
            input_file_type=config.input_file_type,
            output=config.output,
            output_model_type=config.output_model_type,
            target_python_version=config.target_python_version,
            base_class=config.base_class,
            additional_imports=config.additional_imports,
            custom_template_dir=config.custom_template_dir,
            validation=config.validation,
            field_constraints=config.field_constraints,
            snake_case_field=config.snake_case_field,
            strip_default_none=config.strip_default_none,
            extra_template_data=extra_template_data,
            aliases=aliases,
            disable_timestamp=config.disable_timestamp,
            enable_version_header=config.enable_version_header,
            allow_population_by_field_name=config.allow_population_by_field_name,
            allow_extra_fields=config.allow_extra_fields,
            apply_default_values_for_required_fields=config.use_default,
            force_optional_for_required_fields=config.force_optional,
            class_name=config.class_name,
            use_standard_collections=config.use_standard_collections,
            use_schema_description=config.use_schema_description,
            use_field_description=config.use_field_description,
            use_default_kwarg=config.use_default_kwarg,
            reuse_model=config.reuse_model,
            encoding=config.encoding,
            enum_field_as_literal=config.enum_field_as_literal,
            use_one_literal_as_default=config.use_one_literal_as_default,
            set_default_enum_member=config.set_default_enum_member,
            use_subclass_enum=config.use_subclass_enum,
            strict_nullable=config.strict_nullable,
            use_generic_container_types=config.use_generic_container_types,
            enable_faux_immutability=config.enable_faux_immutability,
            disable_appending_item_suffix=config.disable_appending_item_suffix,
            strict_types=config.strict_types,
            empty_enum_field_name=config.empty_enum_field_name,
            field_extra_keys=config.field_extra_keys,
            field_include_all_keys=config.field_include_all_keys,
            field_extra_keys_without_x_prefix=config.field_extra_keys_without_x_prefix,
            openapi_scopes=config.openapi_scopes,
            wrap_string_literal=config.wrap_string_literal,
            use_title_as_name=config.use_title_as_name,
            use_operation_id_as_name=config.use_operation_id_as_name,
            use_unique_items_as_set=config.use_unique_items_as_set,
            http_headers=config.http_headers,
            http_ignore_tls=config.http_ignore_tls,
            use_annotated=config.use_annotated,
            use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types,
            original_field_name_delimiter=config.original_field_name_delimiter,
            use_double_quotes=config.use_double_quotes,
            collapse_root_models=config.collapse_root_models,
            use_union_operator=config.use_union_operator,
            special_field_name_prefix=config.special_field_name_prefix,
            remove_special_field_name_prefix=config.remove_special_field_name_prefix,
            capitalise_enum_members=config.capitalise_enum_members,
            keep_model_order=config.keep_model_order,
            custom_file_header=config.custom_file_header,
            custom_file_header_path=config.custom_file_header_path,
            custom_formatters=config.custom_formatters,
            custom_formatters_kwargs=custom_formatters_kwargs,
            use_pendulum=config.use_pendulum,
            http_query_parameters=config.http_query_parameters,
            treat_dots_as_module=config.treat_dot_as_module,
            use_exact_imports=config.use_exact_imports,
            union_mode=config.union_mode,
            output_datetime_class=config.output_datetime_class,
            keyword_only=config.keyword_only,
            no_alias=config.no_alias,
        )
        return Exit.OK
    except InvalidClassNameError as e:
        print(f'{e} You have to set `--class-name` option', file=sys.stderr)
        return Exit.ERROR
    except Error as e:
        print(str(e), file=sys.stderr)
        return Exit.ERROR
    except Exception:
        # Top-level boundary: report any unexpected failure as a traceback on
        # stderr and signal ERROR instead of propagating.
        import traceback

        print(traceback.format_exc(), file=sys.stderr)
        return Exit.ERROR


if __name__ == '__main__':
    # Run as a script/module: report main()'s Exit value as the process status.
    raise SystemExit(main())