File: cloneschema.py

package info (click to toggle)
python-django-pgschemas 1.0.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 848 kB
  • sloc: python: 3,887; makefile: 33; sh: 10; sql: 2
file content (144 lines) | stat: -rw-r--r-- 5,555 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from django.core.checks import Tags, run_checks
from django.core.management.base import BaseCommand, CommandError

from django_pgschemas.utils import clone_schema, get_domain_model, get_tenant_model


class Command(BaseCommand):
    """Clone an existing schema into a new schema.

    When run interactively and ``source`` is the ``schema_name`` of an
    existing row of the tenant model, the user is offered to also create a
    tenant row for ``destination`` together with a primary domain row that
    points at it.
    """

    help = "Clones a schema"

    def _run_checks(self, **kwargs):  # pragma: no cover
        # Django's default check run skips checks tagged ``database``; run
        # them explicitly in addition to the regular checks.
        issues = run_checks(tags=[Tags.database])
        issues.extend(super()._run_checks(**kwargs))
        return issues

    def add_arguments(self, parser):
        super().add_arguments(parser)
        parser.add_argument(
            "source",
            help="The name of the schema you want to clone",
        )
        parser.add_argument(
            "destination",
            help="The name of the schema you want to create as clone",
        )
        parser.add_argument(
            "--noinput",
            "--no-input",
            action="store_false",
            dest="interactive",
            help="Tells Django to NOT prompt the user for input of any kind.",
        )
        parser.add_argument(
            "--dry-run",
            dest="dry_run",
            action="store_true",
            help="Just show what clone would do; without actually cloning.",
        )

    @staticmethod
    def _format_error(error):
        """Return a human-readable string describing ``error``.

        ``messages`` is preferred over ``message``: for a Django
        ``ValidationError`` the former has its ``params`` interpolated while
        the latter is the raw message template. Anything else falls back to
        ``str(error)`` so the result is always a ``str`` (the command's
        output wrappers call string methods on what they are given).
        """
        if hasattr(error, "messages"):
            return " ".join(str(message) for message in error.messages)
        if hasattr(error, "message"):
            return str(error.message)
        return str(error)

    def _ask(self, question):
        """Prompt ``question`` on stdin until a yes/no answer is given.

        An empty answer counts as "yes". Unrecognized answers are reported
        on stderr and the question is asked again.

        Returns:
            ``True`` for a yes-like answer, ``False`` for a no-like answer.
        """
        bool_true = ("y", "yes", "t", "true", "on", "1")
        bool_false = ("n", "no", "f", "false", "off", "0")
        while True:
            raw_answer = input(f"{question.strip()} [Y/n] ").strip().lower() or "y"
            if raw_answer in bool_true:
                return True
            if raw_answer in bool_false:
                return False
            self.stderr.write(f"{raw_answer} is not a valid answer.")

    def _check_required_field(self, field, exclude=None):
        """Return whether ``field`` needs a value supplied by the user.

        A field qualifies when it is editable, is not the primary key, is
        not a relation, cannot be left empty (not nullable, no default, not
        satisfiable by a blank string, not ``auto_now``/``auto_now_add``),
        and its name is not in ``exclude``.
        """
        if exclude is None:
            exclude = ()
        return (
            field.editable
            and not field.primary_key
            and not field.is_relation
            and not (
                field.null
                or field.has_default()
                or (field.blank and field.empty_strings_allowed)
                or getattr(field, "auto_now", False)
                or getattr(field, "auto_now_add", False)
            )
            and field.name not in exclude
        )

    def _get_constructed_instance(self, model_class, data):
        """Build an unsaved ``model_class`` instance, prompting for missing data.

        Every concrete field for which ``_check_required_field`` holds (and
        that is not already a key of ``data``) is requested on stdin until
        both ``field.clean`` and ``instance.clean()`` accept the value; each
        rejection is reported on stderr. ``data`` is updated in place with
        the cleaned values.

        Returns:
            The (unsaved) ``model_class`` instance built from ``data``.
        """
        fields = [
            field
            for field in model_class._meta.fields
            if self._check_required_field(field, data.keys())
        ]
        instance = model_class(**data)
        if fields:
            self.stdout.write(
                self.style.WARNING(f"We need some data for model '{model_class._meta.model_name}':")
            )
            for field in fields:
                while field.name not in data:
                    raw_value = input(f"Value for field '{field.name}': ")
                    try:
                        data[field.name] = field.clean(raw_value, None)
                        instance = model_class(**data)
                        instance.clean()
                    except Exception as e:
                        # Report and discard the rejected value so the loop
                        # asks for this field again.
                        self.stderr.write(self._format_error(e))
                        data.pop(field.name, None)
        return instance

    def get_dynamic_tenant(self, **options):
        """Optionally build tenant and domain objects for the destination schema.

        Asks the user whether a database entry should be created. If so,
        builds an unsaved tenant with ``schema_name`` set to
        ``options["destination"]`` and an unsaved domain with
        ``is_primary=True``, prompting for any other required fields.

        Returns:
            ``(tenant, domain)``, or ``(None, None)`` if the user declined.
        """
        tenant = None
        domain = None
        if self._ask(
            "You are cloning a schema for a dynamic tenant. Would you like to create a database entry for it?"
        ):
            tenant = self._get_constructed_instance(
                get_tenant_model(), {"schema_name": options["destination"]}
            )
            domain = self._get_constructed_instance(get_domain_model(), {"is_primary": True})
            if options["verbosity"] >= 1:
                self.stdout.write(self.style.WARNING("Looks good! Let's get to it!"))
        return tenant, domain

    def handle(self, *args, **options):
        """Clone ``source`` into ``destination`` and optionally save tenant/domain rows.

        In interactive mode, if ``source`` belongs to an existing tenant,
        the user may describe a tenant and domain for the clone; these are
        saved after the clone unless ``--dry-run`` is given.

        Raises:
            CommandError: wrapping (and chained from) any exception raised
                while cloning or saving.
        """
        tenant = None
        domain = None
        dry_run = options.get("dry_run")
        if options.get("interactive", True):
            TenantModel = get_tenant_model()
            if (
                TenantModel is not None
                and TenantModel.objects.filter(schema_name=options["source"]).exists()
            ):
                tenant, domain = self.get_dynamic_tenant(**options)
        try:
            clone_schema(options["source"], options["destination"], dry_run)
            if tenant and domain:
                if options["verbosity"] >= 1:
                    self.stdout.write("Schema cloned.")
                if not dry_run:
                    tenant.save()
                # Link the domain only after the tenant has been saved so
                # the relation can pick up the tenant's primary key.
                domain.tenant = tenant
                if not dry_run:
                    domain.save()
                if options["verbosity"] >= 1:
                    self.stdout.write("Tenant and domain successfully saved.")
            if options["verbosity"] >= 1:
                self.stdout.write("All done!")
        except Exception as e:
            # Preserve the original exception as the cause for debugging.
            raise CommandError(self._format_error(e)) from e