File: generators.py

package info (click to toggle)
python-drf-spectacular 0.28.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,748 kB
  • sloc: python: 14,174; javascript: 114; sh: 61; makefile: 30
file content (293 lines) | stat: -rw-r--r-- 13,104 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import os
import posixpath
import re

from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings

from drf_spectacular.drainage import (
    add_trace_message, error, get_override, reset_generator_stats, warn,
)
from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
    ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class,
    is_versioning_supported, modify_for_versioning, normalize_result_object,
    operation_matches_version, process_webhooks, sanitize_result_object,
)
from drf_spectacular.settings import spectacular_settings


class EndpointEnumerator(BaseEndpointEnumerator):
    """
    DRF ``EndpointEnumerator`` variant that also passes through each endpoint's path regex.

    Endpoints are produced as ``(path, path_regex, method, callback)`` tuples, run
    through ``PREPROCESSING_HOOKS``, de-duplicated per ``(path, method)`` and
    optionally sorted according to ``SORT_OPERATIONS``.
    """

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Return the processed list of ``(path, path_regex, method, callback)`` endpoints.

        Every hook in ``PREPROCESSING_HOOKS`` is called with ``endpoints=`` and its
        return value replaces the list. Duplicate ``(path, method)`` pairs are then
        dropped, keeping the first occurrence. ``SORT_OPERATIONS`` may be a callable
        (used as sort key), any other truthy value (alphabetical sorter) or falsy
        (keep the discovered order).
        """
        api_endpoints = self._get_api_endpoints(patterns, prefix)

        for hook in spectacular_settings.PREPROCESSING_HOOKS:
            api_endpoints = hook(endpoints=api_endpoints)

        # dicts preserve insertion order: the first endpoint seen for a (path, method)
        # pair wins and the relative order of the survivors is kept
        api_endpoints_deduplicated = {}
        for path, path_regex, method, callback in api_endpoints:
            if (path, method) not in api_endpoints_deduplicated:
                api_endpoints_deduplicated[path, method] = (path, path_regex, method, callback)

        api_endpoints = list(api_endpoints_deduplicated.values())

        if callable(spectacular_settings.SORT_OPERATIONS):
            return sorted(api_endpoints, key=spectacular_settings.SORT_OPERATIONS)
        elif spectacular_settings.SORT_OPERATIONS:
            return sorted(api_endpoints, key=alpha_operation_sorter)
        else:
            return api_endpoints

    def get_path_from_regex(self, path_regex):
        """
        Convert a URL regex to a path like DRF does, additionally un-escaping dots
        that DRF's regex stripping leaves escaped.
        """
        path = super().get_path_from_regex(path_regex)
        # bugfix oversight in DRF regex stripping
        path = path.replace('\\.', '.')
        return path

    def _get_api_endpoints(self, patterns, prefix):
        """
        Return a list of all available API endpoints by inspecting the URL conf.

        The only modification to the DRF version is that the accumulated
        ``path_regex`` is passed through in each endpoint tuple. Resolvers are
        descended into recursively, extending the regex prefix.
        """
        if patterns is None:
            patterns = self.patterns

        api_endpoints = []

        for pattern in patterns:
            path_regex = prefix + str(pattern.pattern)
            if isinstance(pattern, URLPattern):
                path = self.get_path_from_regex(path_regex)
                callback = pattern.callback
                if self.should_include_endpoint(path, callback):
                    for method in self.get_allowed_methods(callback):
                        endpoint = (path, path_regex, method, callback)
                        api_endpoints.append(endpoint)

            elif isinstance(pattern, URLResolver):
                nested_endpoints = self._get_api_endpoints(
                    patterns=pattern.url_patterns,
                    prefix=path_regex
                )
                api_endpoints.extend(nested_endpoints)

        return api_endpoints

    def get_allowed_methods(self, callback):
        """
        Return the uppercase HTTP methods served by ``callback``.

        For viewset callbacks (those with an ``actions`` mapping) these are the
        mapped methods that are also in ``http_method_names`` (taken from the
        callback's ``initkwargs`` if given, else from the view class). For plain
        views, the view class is instantiated (honouring an ``http_method_names``
        initkwarg) and its ``allowed_methods`` are used. OPTIONS, HEAD, TRACE and
        CONNECT are always excluded.
        """
        if hasattr(callback, 'actions'):
            if 'http_method_names' in callback.initkwargs:
                http_method_names = set(callback.initkwargs['http_method_names'])
            else:
                http_method_names = set(callback.cls.http_method_names)

            # iterate the actions mapping instead of a set intersection so the method
            # order is deterministic and not subject to string hash randomization
            methods = [
                method.upper() for method in callback.actions
                if method in http_method_names
            ]
        else:
            # pass to constructor allowed method names to get valid ones
            kwargs = {}
            if 'http_method_names' in callback.initkwargs:
                kwargs['http_method_names'] = callback.initkwargs['http_method_names']

            methods = callback.cls(**kwargs).allowed_methods

        return [
            method for method in methods
            if method not in ('OPTIONS', 'HEAD', 'TRACE', 'CONNECT')
        ]


class SchemaGenerator(BaseSchemaGenerator):
    """
    drf-spectacular's OpenAPI schema generator.

    Enumerates endpoints with :class:`EndpointEnumerator`, lets each view's
    drf-spectacular ``AutoSchema`` produce the operation for a (path, method) pair
    (sharing one :class:`ComponentRegistry`), and assembles the final document in
    :meth:`get_schema`.
    """
    endpoint_inspector_cls = EndpointEnumerator

    def __init__(self, *args, **kwargs):
        """
        Accept DRF's ``BaseSchemaGenerator`` arguments plus an optional ``api_version``
        keyword, which is removed before delegating to DRF.
        """
        self.registry = ComponentRegistry()
        self.api_version = kwargs.pop('api_version', None)
        # populated by _initialise_endpoints(); parse() reads its patterns for versioning
        self.inspector = None
        super().__init__(*args, **kwargs)

    def coerce_path(self, path, method, view):
        """
        Customized coerce_path which also considers the `_pk` suffix in URL paths
        of nested routers.

        After DRF's own coercion (``{pk}``), ``{<name>_pk}`` parameters are renamed
        to ``{<name>_id}`` when ``SCHEMA_COERCE_PATH_PK_SUFFIX`` is enabled.
        """
        path = super().coerce_path(path, method, view)  # take care of {pk}
        if spectacular_settings.SCHEMA_COERCE_PATH_PK_SUFFIX:
            path = re.sub(pattern=r'{(\w+)_pk}', repl=r'{\1_id}', string=path)
        return path

    def create_view(self, callback, method, request=None):
        """
        customized create_view which is called when all routes are traversed. part of this
        is instantiating views with default params. in case of custom routes (@action) the
        custom AutoSchema is injected properly through 'initkwargs' on view. However, when
        decorating plain views like retrieve, this initialization logic is not running.
        Therefore forcefully set the schema if @extend_schema decorator was used.

        ``request`` is accepted for interface compatibility but not forwarded; the mock
        request is attached later in :meth:`parse`.
        """
        override_view = OpenApiViewExtension.get_match(callback.cls)
        if override_view:
            original_cls = callback.cls
            callback.cls = override_view.view_replacement()

        try:
            # we refrain from passing request and deal with it ourselves in parse()
            view = super().create_view(callback, method, None)
        finally:
            # callback.cls is hosted in urlpatterns and is therefore not an ephemeral
            # modification. restore it (even when view creation raised) so potential
            # revisits have a clean state as basis.
            if override_view:
                callback.cls = original_cls

        # drf-yasg compatibility feature. makes the view aware that we are running
        # schema generation and not a real request.
        view.swagger_fake_view = True

        if isinstance(view, viewsets.ViewSetMixin):
            action = getattr(view, view.action)
        elif isinstance(view, views.APIView):
            action = getattr(view, method.lower())
        else:
            error(
                'Using not supported View class. Class must be derived from APIView '
                'or any of its subclasses like GenericApiView, GenericViewSet.'
            )
            return view

        action_schema = getattr(action, 'kwargs', {}).get('schema', None)
        if not action_schema:
            # there is no method/action customized schema so we are done here.
            return view

        # action_schema is either a class or instance. when @extend_schema is used, it
        # is always a class to prevent the weakref reverse "schema.view" bug for multi
        # annotations. The bug is prevented by delaying the instantiation of the schema
        # class until create_view (here) and not doing it immediately in @extend_schema.
        action_schema_class = get_class(action_schema)
        view_schema_class = get_class(callback.cls.schema)

        if not issubclass(action_schema_class, view_schema_class):
            # this handles the case of having a manually set custom AutoSchema on the
            # view together with extend_schema. In most cases, the decorator mechanics
            # prevent extend_schema from having access to the view's schema class. So
            # extend_schema is forced to use DEFAULT_SCHEMA_CLASS as fallback base class
            # instead of the correct base class set in view. We remedy this chicken-egg
            # problem here by rearranging the class hierarchy.
            mro = tuple(
                cls for cls in action_schema_class.__mro__
                if cls not in api_settings.DEFAULT_SCHEMA_CLASS.__mro__
            ) + view_schema_class.__mro__
            action_schema_class = type('ExtendedRearrangedSchema', mro, {})

        view.schema = action_schema_class()
        return view

    def _initialise_endpoints(self):
        """ Enumerate endpoints once, keeping the inspector around for later use. """
        if self.endpoints is None:
            self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
            self.endpoints = self.inspector.get_api_endpoints()

    def _get_paths_and_endpoints(self):
        """
        Generate (path, path_regex, method, view) given (path, path_regex, method, callback)
        for all enumerated endpoints, instantiating views and coercing paths.
        """
        view_endpoints = []
        for path, path_regex, method, callback in self.endpoints:
            view = self.create_view(callback, method)
            path = self.coerce_path(path, method, view)
            view_endpoints.append((path, path_regex, method, view))

        return view_endpoints

    def _get_path_prefix(self, endpoints):
        """
        Return the ``^``-anchored path prefix regex handed to ``get_operation``.

        ``SCHEMA_PATH_PREFIX`` is used when set. Otherwise the common prefix of all
        endpoint paths is estimated, but only if more than one view class was
        encountered (to prevent emission of erroneous and unnecessary fallback names);
        a single view class falls back to ``/``.
        """
        if spectacular_settings.SCHEMA_PATH_PREFIX is None:
            non_trivial_prefix = len({view.__class__ for _, _, _, view in endpoints}) > 1
            if non_trivial_prefix:
                # URL paths are always '/'-separated. os.path is ntpath on Windows and
                # would return a '\'-separated prefix that never matches a URL path.
                path_prefix = posixpath.commonpath([path for path, _, _, _ in endpoints])
                path_prefix = re.escape(path_prefix)  # guard for RE special chars in path
            else:
                path_prefix = '/'
        else:
            path_prefix = spectacular_settings.SCHEMA_PATH_PREFIX
        if not path_prefix.startswith('^'):
            path_prefix = '^' + path_prefix  # make sure regex only matches from the start
        return path_prefix

    def parse(self, input_request, public):
        """
        Iterate endpoints generating per method path operations.

        Returns a dict mapping each (possibly trimmed/prefixed/camelized) path to a
        dict of lowercase HTTP method -> operation. Endpoints are skipped when the
        requester lacks permission (unless ``public``), when a versioned view has no
        resolvable version or does not match it, or when the operation was removed.

        :raises AssertionError: if a view's schema is not a drf-spectacular AutoSchema.
        """
        result = {}
        self._initialise_endpoints()
        endpoints = self._get_paths_and_endpoints()
        path_prefix = self._get_path_prefix(endpoints)

        for path, path_regex, method, view in endpoints:
            # emit queued up warnings/error that happened prior to generation (decoration)
            for w in get_override(view, 'warnings', []):
                warn(w)
            for e in get_override(view, 'errors', []):
                error(e)

            view.request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request)

            if not (public or self.has_view_permissions(path, method, view)):
                continue

            if view.versioning_class and not is_versioning_supported(view.versioning_class):
                warn(
                    f'using unsupported versioning class "{view.versioning_class}". view will be '
                    f'processed as unversioned view.'
                )
            elif view.versioning_class:
                version = (
                    self.api_version  # explicit version from CLI, SpecView or SpecView request
                    or view.versioning_class.default_version  # fallback
                )
                if not version:
                    continue
                path = modify_for_versioning(self.inspector.patterns, method, path, view, version)
                if not operation_matches_version(view, version):
                    continue

            # explicit raise instead of `assert` so the check survives `python -O`
            if not isinstance(view.schema, AutoSchema):
                raise AssertionError(
                    f'Incompatible AutoSchema used on View {view.__class__}. Is DRF\'s '
                    f'DEFAULT_SCHEMA_CLASS pointing to "drf_spectacular.openapi.AutoSchema" '
                    f'or any other drf-spectacular compatible AutoSchema?'
                )
            with add_trace_message(getattr(view, '__class__', view)):
                operation = view.schema.get_operation(
                    path, path_regex, path_prefix, method, self.registry
                )

            # operation was manually removed via @extend_schema
            if not operation:
                continue

            if spectacular_settings.SCHEMA_PATH_PREFIX_TRIM:
                path = re.sub(pattern=path_prefix, repl='', string=path, flags=re.IGNORECASE)

            if spectacular_settings.SCHEMA_PATH_PREFIX_INSERT:
                path = spectacular_settings.SCHEMA_PATH_PREFIX_INSERT + path

            if not path.startswith('/'):
                path = '/' + path

            if spectacular_settings.CAMELIZE_NAMES:
                path, operation = camelize_operation(path, operation)

            result.setdefault(path, {})
            result[path][method.lower()] = operation

        return result

    def get_schema(self, request=None, public=False):
        """
        Generate a OpenAPI schema.

        Builds the root object from the parsed paths, the registry's components
        (plus ``APPEND_COMPONENTS``) and ``WEBHOOKS``, runs every
        ``POSTPROCESSING_HOOKS`` entry on it and returns the normalized, sanitized
        result.
        """
        reset_generator_stats()
        # keyword arguments are evaluated left to right: parse() (which hands
        # self.registry to every operation) runs before self.registry.build()
        result = build_root_object(
            paths=self.parse(request, public),
            components=self.registry.build(spectacular_settings.APPEND_COMPONENTS),
            webhooks=process_webhooks(spectacular_settings.WEBHOOKS, self.registry),
            version=self.api_version or getattr(request, 'version', None),
        )
        for hook in spectacular_settings.POSTPROCESSING_HOOKS:
            result = hook(result=result, generator=self, request=request, public=public)

        return sanitize_result_object(normalize_result_object(result))