File: middleware.py

package info (click to toggle)
python-django-pgschemas 1.0.1-2
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 848 kB
  • sloc: python: 3,887; makefile: 33; sh: 10; sql: 2
file content (205 lines) | stat: -rw-r--r-- 6,947 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import re
from typing import Callable, TypeAlias, cast

from asgiref.sync import iscoroutinefunction, sync_to_async
from django.conf import settings
from django.db.models import Q
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import redirect
from django.urls import clear_url_caches, set_urlconf
from django.utils.decorators import sync_and_async_middleware

from django_pgschemas.models import TenantModel as TenantModelBase
from django_pgschemas.routing.info import DomainInfo, HeadersInfo, SessionInfo
from django_pgschemas.routing.models import get_primary_domain_for_tenant
from django_pgschemas.routing.urlresolvers import get_urlconf_from_schema
from django_pgschemas.schema import Schema, activate, activate_public
from django_pgschemas.settings import get_tenant_header, get_tenant_session_key
from django_pgschemas.utils import get_domain_model, get_tenant_model, remove_www


def strip_tenant_from_path_factory(prefix: str) -> Callable[[str], str]:
    """Build a function that removes a leading ``/<prefix>/`` from a path.

    The returned callable replaces a leading ``/<prefix>/`` with ``/`` and
    leaves every other path unchanged (including ``/<prefix>`` without a
    trailing slash).

    The prefix is matched literally. It used to be interpolated into a regular
    expression unescaped, so a prefix such as ``a.b`` also matched ``/axb/``
    and a prefix such as ``c++`` raised ``re.error``.

    Args:
        prefix: Folder name to strip, without surrounding slashes.

    Returns:
        A function mapping a request path to the path without the prefix.
    """
    leading = f"/{prefix}/"
    # Keep the trailing slash of ``leading`` so the result still starts with "/".
    cut = len(leading) - 1

    def strip_tenant_from_path(path: str) -> str:
        if path.startswith(leading):
            return path[cut:]
        return path

    return strip_tenant_from_path


# Signature used in this module's annotations for Django's ``get_response``
# callable and for the middleware callables built around it: takes an
# ``HttpRequest`` and returns an ``HttpResponse``.
ResponseHandler: TypeAlias = Callable[[HttpRequest], HttpResponse]


def _static_tenant_for_hostname(hostname: str, domains_key: str) -> Schema | None:
    """Return the first static tenant that lists ``hostname`` under ``domains_key``.

    Walks ``settings.TENANTS`` in order, skipping the ``public`` and
    ``default`` entries, and builds a domain-routed ``Schema`` for the first
    entry whose ``domains_key`` list contains ``hostname``. Returns ``None``
    when nothing matches.
    """
    for schema, data in settings.TENANTS.items():
        if schema in ["public", "default"]:
            continue
        if hostname in data.get(domains_key, []):
            return Schema.create(
                schema_name=schema,
                routing=DomainInfo(domain=hostname),
            )
    return None


def route_domain(request: HttpRequest) -> HttpResponse | None:
    """Select and activate the tenant matching the request host and folder.

    Lookup order:

    1. Static tenants in ``settings.TENANTS`` whose ``DOMAINS`` contain the host.
    2. Dynamic tenants through the domain model: first a row matching the host
       and the first path segment as ``folder``, then a row for the host with
       an empty ``folder``.
    3. Static tenants whose ``FALLBACK_DOMAINS`` contain the host.

    On success, ``request.tenant`` and ``request.urlconf`` are set, the urlconf
    is installed with ``set_urlconf`` and the tenant is activated. For dynamic
    tenants, ``request.strip_tenant_from_path`` is also set.

    Args:
        request: Incoming request; the host and path are read from it.

    Returns:
        A permanent redirect to the tenant's primary domain when the matched
        domain has ``redirect_to_primary`` set and a primary domain exists;
        otherwise ``None`` so the request continues down the middleware chain.

    Raises:
        Http404: If no tenant matches the hostname.
    """
    # NOTE(review): splitting on ":" keeps only "[" for a bracketed IPv6 host
    # such as "[::1]:8000" - confirm whether such hosts need to be supported.
    hostname = remove_www(request.get_host().split(":")[0])

    activate_public()

    # Checking for static tenants
    tenant: Schema | None = _static_tenant_for_hostname(hostname, "DOMAINS")

    # Checking for dynamic tenants
    if tenant is None:
        DomainModel = get_domain_model()

        prefix = request.path.split("/")[1]
        domain = None

        if DomainModel is not None:
            # Prefer a folder-specific domain, then the folder-less one. With
            # an empty prefix both lookups are the same query, so run it once.
            folders = (prefix, "") if prefix else ("",)
            for folder in folders:
                try:
                    domain = DomainModel.objects.select_related("tenant").get(
                        domain=hostname, folder=folder
                    )
                except DomainModel.DoesNotExist:
                    continue
                else:
                    break

        if domain is not None:
            tenant = cast(TenantModelBase, domain.tenant)
            tenant.routing = DomainInfo(domain=hostname)
            request.strip_tenant_from_path = lambda x: x

            if prefix and domain.folder == prefix:
                tenant.routing = DomainInfo(domain=hostname, folder=prefix)
                request.strip_tenant_from_path = strip_tenant_from_path_factory(prefix)
                clear_url_caches()  # Required to remove previous tenant prefix from cache (#8)

            if domain.redirect_to_primary:
                primary_domain = get_primary_domain_for_tenant(tenant)
                if primary_domain:
                    # Redirect to the same path, minus the folder prefix, on the primary domain.
                    path = request.strip_tenant_from_path(request.path)
                    return redirect(primary_domain.absolute_url(path), permanent=True)

    # Checking fallback domains
    if not tenant:
        tenant = _static_tenant_for_hostname(hostname, "FALLBACK_DOMAINS")

    # No tenant found from domain / folder
    if not tenant:
        raise Http404("No tenant for hostname '%s'" % hostname)

    urlconf = get_urlconf_from_schema(tenant)

    request.tenant = tenant
    request.urlconf = urlconf
    set_urlconf(urlconf)

    activate(tenant)
    return None


def route_session(request: HttpRequest) -> HttpResponse | None:
    """Activate the tenant referenced by the request's session, if any.

    The reference is read from the session under the key returned by
    ``get_tenant_session_key()``. It is matched first against static tenants
    in ``settings.TENANTS`` (by schema name or by the entry's ``SESSION_KEY``)
    and, if no static tenant matches, against the tenant model (by primary key
    or ``schema_name``).

    When a tenant is found it is tagged with ``SessionInfo`` routing, stored
    on ``request.tenant`` and activated. If the request has no session, no
    reference, or no matching tenant, nothing is changed.

    Args:
        request: Incoming request, expected to carry a ``session`` attribute.

    Returns:
        Always ``None``; this handler never short-circuits the response.
    """
    session_key = get_tenant_session_key()

    if not hasattr(request, "session"):
        return None

    reference = request.session.get(session_key)
    if not reference:
        return None

    # Checking for static tenants: first configured schema matching the reference
    static_schema = next(
        (
            name
            for name, config in settings.TENANTS.items()
            if name not in ("public", "default")
            and (reference == name or reference == config.get("SESSION_KEY"))
        ),
        None,
    )

    found: Schema | None = None
    if static_schema is not None:
        found = Schema.create(schema_name=static_schema)
    else:
        # Checking for dynamic tenants
        TenantModel = get_tenant_model()
        if TenantModel is not None:
            matches = TenantModel._default_manager.filter(
                Q(pk__iexact=reference) | Q(schema_name=reference)
            )
            found = matches.first()

    if found is None:
        return None

    found.routing = SessionInfo(reference=reference)
    request.tenant = found
    activate(found)
    return None


def route_headers(request: HttpRequest) -> HttpResponse | None:
    """Activate the tenant referenced by a request header, if any.

    The reference is read from the header named by ``get_tenant_header()``.
    It is matched first against static tenants in ``settings.TENANTS`` (by
    schema name or by the entry's ``HEADER``) and, if no static tenant
    matches, against the tenant model (by primary key or ``schema_name``).

    When a tenant is found it is tagged with ``HeadersInfo`` routing, stored
    on ``request.tenant`` and activated. If the header is missing or empty, or
    no tenant matches, nothing is changed.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        Always ``None``; this handler never short-circuits the response.
    """
    header_name = get_tenant_header()

    header_value = request.headers.get(header_name)
    if not header_value:
        return None

    # Checking for static tenants: first configured schema matching the header value
    matched_schema = next(
        (
            schema_name
            for schema_name, options in settings.TENANTS.items()
            if schema_name not in ("public", "default")
            and (header_value == schema_name or header_value == options.get("HEADER"))
        ),
        None,
    )

    selected: Schema | None = None
    if matched_schema is not None:
        selected = Schema.create(schema_name=matched_schema)
    else:
        # Checking for dynamic tenants
        TenantModel = get_tenant_model()
        if TenantModel is not None:
            candidates = TenantModel._default_manager.filter(
                Q(pk__iexact=header_value) | Q(schema_name=header_value)
            )
            selected = candidates.first()

    if selected is None:
        return None

    selected.routing = HeadersInfo(reference=header_value)
    request.tenant = selected
    activate(selected)
    return None


def middleware_factory(
    handler: Callable[[HttpRequest], HttpResponse | None],
) -> Callable[[ResponseHandler], ResponseHandler]:
    """Wrap a routing ``handler`` into a Django middleware factory.

    The resulting middleware calls ``handler(request)`` first. If the handler
    returns a truthy response, that response is returned and ``get_response``
    is not called; otherwise the request is passed on to ``get_response``.

    The middleware is decorated with ``sync_and_async_middleware`` and picks
    the coroutine or plain-function variant according to whether
    ``get_response`` is a coroutine function.

    Args:
        handler: Synchronous callable that inspects the request and returns
            either a response to short-circuit with or ``None``.

    Returns:
        A middleware factory taking ``get_response``.
    """

    @sync_and_async_middleware
    def middleware(get_response: ResponseHandler) -> ResponseHandler:
        if iscoroutinefunction(get_response):
            # The handler is synchronous, so make it awaitable for the async chain.
            async_handler = sync_to_async(handler)

            async def async_middleware(request: HttpRequest) -> HttpResponse | None:
                if response := await async_handler(request):
                    return response

                return await get_response(request)

            return async_middleware

        def sync_middleware(request: HttpRequest) -> HttpResponse | None:
            if response := handler(request):
                return response

            return get_response(request)

        return sync_middleware

    return middleware


# Ready-to-use middlewares, one per routing handler defined in this module.
# Routes by request host and optional leading path folder.
DomainRoutingMiddleware = middleware_factory(route_domain)
# Routes by a tenant reference stored in the session.
SessionRoutingMiddleware = middleware_factory(route_session)
# Routes by a tenant reference sent in a request header.
HeadersRoutingMiddleware = middleware_factory(route_headers)