1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
|
import re
from typing import Callable, TypeAlias, cast
from asgiref.sync import iscoroutinefunction, sync_to_async
from django.conf import settings
from django.db.models import Q
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import redirect
from django.urls import clear_url_caches, set_urlconf
from django.utils.decorators import sync_and_async_middleware
from django_pgschemas.models import TenantModel as TenantModelBase
from django_pgschemas.routing.info import DomainInfo, HeadersInfo, SessionInfo
from django_pgschemas.routing.models import get_primary_domain_for_tenant
from django_pgschemas.routing.urlresolvers import get_urlconf_from_schema
from django_pgschemas.schema import Schema, activate, activate_public
from django_pgschemas.settings import get_tenant_header, get_tenant_session_key
from django_pgschemas.utils import get_domain_model, get_tenant_model, remove_www
def strip_tenant_from_path_factory(prefix: str) -> Callable[[str], str]:
    """Build a function that removes a leading ``/<prefix>/`` segment from a URL path.

    Args:
        prefix: Tenant folder name as it appears in the first path segment.

    Returns:
        A callable mapping ``"/<prefix>/rest"`` to ``"/rest"``. Paths that do not
        start with ``/<prefix>/`` are returned unchanged.
    """
    # Match the folder literally. The previous regex-based version interpolated
    # ``prefix`` unescaped, so characters such as "." or "+" acted as regex
    # operators (e.g. prefix "a+b" stripped "/aab/" but not "/a+b/").
    marker = f"/{prefix}/"

    def strip_tenant_from_path(path: str) -> str:
        if path.startswith(marker):
            return "/" + path[len(marker):]
        return path

    return strip_tenant_from_path
# Type of the ``get_response`` callable that Django hands to the middleware
# built by ``middleware_factory``.
# NOTE(review): in async mode get_response is awaited (it returns an awaitable),
# which this alias does not express — confirm before tightening the type.
ResponseHandler: TypeAlias = Callable[[HttpRequest], HttpResponse]
def _static_tenant_for_hostname(hostname: str, domains_key: str) -> Schema | None:
    """Return a static tenant whose ``settings.TENANTS`` entry lists ``hostname``.

    ``domains_key`` selects which list is searched (``"DOMAINS"`` or
    ``"FALLBACK_DOMAINS"``). The ``public`` and ``default`` entries are never
    considered. Returns ``None`` when no static tenant claims the hostname.
    """
    for schema_name, tenant_config in settings.TENANTS.items():
        if schema_name in ["public", "default"]:
            continue
        if hostname in tenant_config.get(domains_key, []):
            return Schema.create(
                schema_name=schema_name,
                routing=DomainInfo(domain=hostname),
            )
    return None


def route_domain(request: HttpRequest) -> HttpResponse | None:
    """Route ``request`` to a tenant based on its hostname and optional folder.

    The hostname is the request host with any port removed, passed through
    ``remove_www``. Tenants are resolved in this order:

    1. Static tenants whose ``DOMAINS`` setting contains the hostname.
    2. Dynamic tenants: a domain-model record matching the hostname and the
       first path segment, then one matching the hostname with an empty folder.
    3. Static tenants whose ``FALLBACK_DOMAINS`` setting contains the hostname.

    For dynamic matches, ``request.strip_tenant_from_path`` is set to either an
    identity function or a folder-prefix stripper.

    Returns:
        A permanent redirect to the tenant's primary domain when the matched
        domain record has ``redirect_to_primary`` set and a primary domain
        exists. Otherwise ``None``, after setting ``request.tenant``,
        ``request.urlconf`` and the active urlconf, and activating the tenant.

    Raises:
        Http404: if no tenant matches the hostname.
    """
    # NOTE(review): splitting on ":" truncates bracketed IPv6 hosts
    # (e.g. "[::1]:8000" -> "[") — confirm IPv6 hosts are never routed here.
    hostname = remove_www(request.get_host().split(":")[0])

    # Tenant/domain lookups below run with the public schema active.
    activate_public()

    tenant: Schema | None = _static_tenant_for_hostname(hostname, "DOMAINS")

    if tenant is None:
        domain_model = get_domain_model()
        prefix = request.path.split("/")[1]
        domain = None

        if domain_model is not None:
            candidates = domain_model.objects.select_related("tenant")
            # Prefer a folder-specific record, then the bare-domain record.
            for folder in (prefix, ""):
                try:
                    domain = candidates.get(domain=hostname, folder=folder)
                except domain_model.DoesNotExist:
                    continue
                break

        if domain is not None:
            tenant = cast(TenantModelBase, domain.tenant)

            if prefix and domain.folder == prefix:
                # Folder routing: the first path segment selected this tenant.
                tenant.routing = DomainInfo(domain=hostname, folder=prefix)
                request.strip_tenant_from_path = strip_tenant_from_path_factory(prefix)
                # Required to remove previous tenant prefix from cache (#8)
                clear_url_caches()
            else:
                tenant.routing = DomainInfo(domain=hostname)
                request.strip_tenant_from_path = lambda path: path

            if domain.redirect_to_primary:
                primary_domain = get_primary_domain_for_tenant(tenant)
                if primary_domain:
                    stripped_path = request.strip_tenant_from_path(request.path)
                    return redirect(
                        primary_domain.absolute_url(stripped_path), permanent=True
                    )

    if not tenant:
        tenant = _static_tenant_for_hostname(hostname, "FALLBACK_DOMAINS")

    if not tenant:
        raise Http404(f"No tenant for hostname '{hostname}'")

    urlconf = get_urlconf_from_schema(tenant)
    request.tenant = tenant
    request.urlconf = urlconf
    set_urlconf(urlconf)
    activate(tenant)
    return None
def route_session(request: HttpRequest) -> HttpResponse | None:
    """Activate the tenant referenced by the current session, if any.

    The value stored in the session under ``get_tenant_session_key()`` is
    matched in this order:

    1. Static tenants in ``settings.TENANTS``, by schema name or by their
       ``SESSION_KEY`` entry (``public`` and ``default`` are skipped).
    2. The dynamic tenant model, by primary key (case-insensitive ``iexact``)
       or by ``schema_name``.

    On a match, ``tenant.routing`` becomes a ``SessionInfo``, the tenant is
    stored on ``request.tenant``, and it is activated. This router never
    short-circuits the request, so it always returns ``None``.
    """
    session_key = get_tenant_session_key()

    if not hasattr(request, "session"):
        return None
    tenant_ref = request.session.get(session_key)
    if not tenant_ref:
        return None

    # NOTE(review): unlike route_domain, activate_public() is not called here,
    # so the lookup (and a miss) runs under whichever schema is currently active.
    matched_schema = next(
        (
            name
            for name, config in settings.TENANTS.items()
            if name not in ("public", "default")
            and (tenant_ref == name or tenant_ref == config.get("SESSION_KEY"))
        ),
        None,
    )

    tenant: Schema | None
    if matched_schema is not None:
        tenant = Schema.create(schema_name=matched_schema)
    else:
        tenant_model = get_tenant_model()
        tenant = (
            None
            if tenant_model is None
            else tenant_model._default_manager.filter(
                Q(pk__iexact=tenant_ref) | Q(schema_name=tenant_ref)
            ).first()
        )

    if tenant is None:
        return None

    tenant.routing = SessionInfo(reference=tenant_ref)
    request.tenant = tenant
    activate(tenant)
    return None
def route_headers(request: HttpRequest) -> HttpResponse | None:
    """Activate the tenant named in the tenant request header, if present.

    The header value (header name from ``get_tenant_header()``) is
    client-supplied. It is matched in this order:

    1. Static tenants in ``settings.TENANTS``, by schema name or by their
       ``HEADER`` entry (``public`` and ``default`` are skipped).
    2. The dynamic tenant model, by primary key (case-insensitive ``iexact``)
       or by ``schema_name``.

    On a match, ``tenant.routing`` becomes a ``HeadersInfo``, the tenant is
    stored on ``request.tenant``, and it is activated. Always returns ``None``.
    """
    header_name = get_tenant_header()

    tenant_ref = request.headers.get(header_name)
    if not tenant_ref:
        return None

    # NOTE(review): unlike route_domain, activate_public() is not called here,
    # so the lookup (and a miss) runs under whichever schema is currently active.
    matched_schema = next(
        (
            name
            for name, config in settings.TENANTS.items()
            if name not in ("public", "default")
            and (tenant_ref == name or tenant_ref == config.get("HEADER"))
        ),
        None,
    )

    tenant: Schema | None
    if matched_schema is not None:
        tenant = Schema.create(schema_name=matched_schema)
    else:
        tenant_model = get_tenant_model()
        tenant = (
            None
            if tenant_model is None
            else tenant_model._default_manager.filter(
                Q(pk__iexact=tenant_ref) | Q(schema_name=tenant_ref)
            ).first()
        )

    if tenant is None:
        return None

    tenant.routing = HeadersInfo(reference=tenant_ref)
    request.tenant = tenant
    activate(tenant)
    return None
def middleware_factory(
    handler: Callable[[HttpRequest], HttpResponse | None],
) -> Callable[[ResponseHandler], ResponseHandler]:
    """Wrap a routing ``handler`` as a Django middleware usable in sync or async mode.

    ``handler`` runs before the rest of the middleware chain. If it returns a
    response, that response is returned immediately. Otherwise the request is
    passed on to ``get_response``.

    Args:
        handler: Synchronous routing function returning an early response
            or ``None``.

    Returns:
        A middleware factory, decorated with ``sync_and_async_middleware``,
        that Django calls with ``get_response``.
    """

    @sync_and_async_middleware
    def middleware(get_response: ResponseHandler) -> ResponseHandler:
        if iscoroutinefunction(get_response):
            # The handler is synchronous (the routers in this module run ORM
            # queries), so it is wrapped with sync_to_async for the async chain.
            async_handler = sync_to_async(handler)

            async def async_middleware(request: HttpRequest) -> HttpResponse:
                response = await async_handler(request)
                if response is not None:
                    return response
                return await get_response(request)

            return async_middleware

        def sync_middleware(request: HttpRequest) -> HttpResponse:
            response = handler(request)
            if response is not None:
                return response
            return get_response(request)

        return sync_middleware

    return middleware
# Middleware entry points for settings.MIDDLEWARE, one per routing strategy:
# by hostname/folder, by session value, and by request header.
DomainRoutingMiddleware = middleware_factory(handler=route_domain)
SessionRoutingMiddleware = middleware_factory(handler=route_session)
HeadersRoutingMiddleware = middleware_factory(handler=route_headers)
|