1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
from django.core.checks import Tags, run_checks
from django.core.management.base import BaseCommand, CommandError
from django_pgschemas.utils import clone_schema, get_domain_model, get_tenant_model
class Command(BaseCommand):
    """Clone an existing schema into a new schema.

    When run interactively and the source schema belongs to a row of the
    tenant model, the user is offered the chance to create (and, unless
    ``--dry-run`` is given, save) a matching tenant and primary domain entry
    for the cloned schema.
    """

    help = "Clones a schema"

    def _run_checks(self, **kwargs):  # pragma: no cover
        # Run database-tagged system checks in addition to the ones
        # BaseCommand runs itself.
        issues = run_checks(tags=[Tags.database])
        issues.extend(super()._run_checks(**kwargs))
        return issues

    def add_arguments(self, parser):
        super().add_arguments(parser)
        parser.add_argument(
            "source",
            help="The name of the schema you want to clone",
        )
        parser.add_argument(
            "destination",
            help="The name of the schema you want to create as clone",
        )
        parser.add_argument(
            "--noinput",
            "--no-input",
            action="store_false",
            dest="interactive",
            help="Tells Django to NOT prompt the user for input of any kind.",
        )
        parser.add_argument(
            "--dry-run",
            dest="dry_run",
            action="store_true",
            help="Just show what clone would do; without actually cloning.",
        )

    @staticmethod
    def _format_error(error):
        """Return a human-readable, plain ``str`` message for *error*.

        ``messages`` is preferred over ``message`` because on a Django
        ``ValidationError`` the former has its ``params`` interpolated while
        the latter is the raw template (e.g. containing ``%(value)s``).
        Always returning ``str`` matters: ``OutputWrapper.write`` and
        ``CommandError`` expect text, not exception objects.
        """
        messages = getattr(error, "messages", None)
        if messages:
            return " ".join(str(message) for message in messages)
        message = getattr(error, "message", None)
        if message:
            return str(message)
        return str(error)

    def _ask(self, question):
        """Ask a yes/no *question* until a recognised answer is entered.

        An empty answer counts as "yes". Unrecognised answers are reported
        on stderr and the question is asked again.

        Returns:
            True for a "yes"-like answer, False for a "no"-like answer.
        """
        bool_true = ("y", "yes", "t", "true", "on", "1")
        bool_false = ("n", "no", "f", "false", "off", "0")
        while True:
            raw_answer = input(f"{question.strip()} [Y/n] ").strip().lower() or "y"
            if raw_answer in bool_true:
                return True
            if raw_answer in bool_false:
                return False
            self.stderr.write(f"{raw_answer} is not a valid answer.")

    def _check_required_field(self, field, exclude=None):
        """Return True if *field* has to be supplied by the user.

        A field is considered required when it is editable, is neither the
        primary key nor a relation, cannot be left empty (not nullable, no
        default, not a blank-allowed string field, not ``auto_now`` /
        ``auto_now_add``), and its name is not listed in *exclude*.
        """
        if exclude is None:
            exclude = ()
        return (
            field.editable
            and not field.primary_key
            and not field.is_relation
            and not (
                field.null
                or field.has_default()
                or (field.blank and field.empty_strings_allowed)
                or getattr(field, "auto_now", False)
                or getattr(field, "auto_now_add", False)
            )
            and field.name not in exclude
        )

    def _get_constructed_instance(self, model_class, data):
        """Build an unsaved ``model_class`` instance from *data*.

        Every field reported as required by ``_check_required_field`` and
        missing from *data* is prompted for. Each entered value goes through
        ``field.clean`` and the whole instance through ``instance.clean()``;
        on failure the error is printed and the same field is asked again.

        Note: *data* is updated in place with the collected values.

        Returns:
            The (unsaved) model instance.
        """
        fields = [
            field
            for field in model_class._meta.fields
            if self._check_required_field(field, data.keys())
        ]
        instance = model_class(**data)
        if not fields:
            return instance
        self.stdout.write(
            self.style.WARNING(f"We need some data for model '{model_class._meta.model_name}':")
        )
        for field in fields:
            while field.name not in data:
                raw_value = input(f"Value for field '{field.name}': ")
                try:
                    data[field.name] = field.clean(raw_value, None)
                    instance = model_class(**data)
                    instance.clean()
                except Exception as e:
                    # NOTE(review): if instance.clean() fails for a reason
                    # unrelated to this field, the user is re-prompted for this
                    # same field until it passes — confirm this is intended.
                    self.stderr.write(self._format_error(e))
                    # Drop the rejected value so the while-loop asks again.
                    data.pop(field.name, None)
        return instance

    def get_dynamic_tenant(self, **options):
        """Optionally construct tenant and domain entries for the clone.

        Asks the user whether database entries should be created; if so,
        builds an unsaved tenant whose ``schema_name`` is the destination
        schema and an unsaved primary domain, prompting for any other
        required fields.

        Returns:
            A ``(tenant, domain)`` tuple; both are None if the user declined.
        """
        tenant = None
        domain = None
        if self._ask(
            "You are cloning a schema for a dynamic tenant. Would you like to create a database entry for it?"
        ):
            tenant = self._get_constructed_instance(
                get_tenant_model(), {"schema_name": options["destination"]}
            )
            domain = self._get_constructed_instance(get_domain_model(), {"is_primary": True})
            if options["verbosity"] >= 1:
                self.stdout.write(self.style.WARNING("Looks good! Let's get to it!"))
        return tenant, domain

    def handle(self, *args, **options):
        """Clone ``source`` into ``destination`` and optionally save a tenant.

        Raises:
            CommandError: wrapping any exception raised while cloning the
                schema or saving the tenant/domain entries.
        """
        tenant = None
        domain = None
        dry_run = options.get("dry_run")
        if options.get("interactive", True):
            TenantModel = get_tenant_model()
            # Only offer tenant creation when the source schema belongs to a
            # tenant row.
            if (
                TenantModel is not None
                and TenantModel.objects.filter(schema_name=options["source"]).exists()
            ):
                tenant, domain = self.get_dynamic_tenant(**options)
        try:
            clone_schema(options["source"], options["destination"], dry_run)
            if tenant is not None and domain is not None:
                if options["verbosity"] >= 1:
                    self.stdout.write("Schema cloned.")
                if not dry_run:
                    tenant.save()
                domain.tenant = tenant
                if not dry_run:
                    domain.save()
                if options["verbosity"] >= 1:
                    self.stdout.write("Tenant and domain successfully saved.")
            if options["verbosity"] >= 1:
                self.stdout.write("All done!")
        except Exception as e:
            raise CommandError(self._format_error(e)) from e
|