File: pwiz.py

package info (click to toggle)
peewee 3.13.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye, sid
  • size: 2,848 kB
  • sloc: python: 32,191; makefile: 126; ansic: 109
file content (223 lines) | stat: -rwxr-xr-x 8,089 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#!/usr/bin/env python

import datetime
import sys
from getpass import getpass
from optparse import OptionParser

from peewee import *
from peewee import print_
from peewee import __version__ as peewee_version
from playhouse.cockroachdb import CockroachDatabase
from playhouse.reflection import *


# Template for the top of the generated module. print_models() fills the four
# placeholders, in order, with: additional import text, the database class
# name, the database name, and an optional ", **{...}" keyword-argument suffix.
HEADER = """from peewee import *%s

database = %s('%s'%s)
"""

# Source of the base class that every generated model class inherits from.
BASE_MODEL = """\
class BaseModel(Model):
    class Meta:
        database = database
"""

# Source of a placeholder class that accepts and discards any constructor
# arguments. print_models() emits it unless ignore_unknown is set.
UNKNOWN_FIELD = """\
class UnknownField(object):
    def __init__(self, *_, **__): pass
"""

# Engine names accepted on the command line, grouped by the database class
# that each name selects.
DATABASE_ALIASES = {
    CockroachDatabase: ['cockroach', 'cockroachdb', 'crdb'],
    MySQLDatabase: ['mysql', 'mysqldb'],
    PostgresqlDatabase: ['postgres', 'postgresql'],
    SqliteDatabase: ['sqlite', 'sqlite3'],
}

# Inverse of DATABASE_ALIASES: maps every alias to its database class.
DATABASE_MAP = {
    alias: database_class
    for database_class, aliases in DATABASE_ALIASES.items()
    for alias in aliases
}

def make_introspector(database_type, database_name, **kwargs):
    """Create an ``Introspector`` for the requested database.

    ``database_type`` must be one of the aliases in ``DATABASE_MAP``. If it is
    not, an error listing the accepted aliases is written with ``err`` and the
    process exits with status 1.

    A ``schema`` entry in ``kwargs`` (if any) is removed and forwarded to
    ``Introspector.from_database``; all remaining keyword arguments are passed
    to the database class constructor together with ``database_name``.
    """
    try:
        database_class = DATABASE_MAP[database_type]
    except KeyError:
        known_aliases = ', '.join(DATABASE_MAP)
        err('Unrecognized database, must be one of: %s' % known_aliases)
        sys.exit(1)

    # The schema is not a connection argument, so take it out of kwargs
    # before constructing the database object.
    schema_name = kwargs.pop('schema', None)
    connection = database_class(database_name, **kwargs)
    return Introspector.from_database(connection, schema=schema_name)

def print_models(introspector, tables=None, preserve_order=False,
                 include_views=False, ignore_unknown=False, snake_case=True):
    """Introspect the database and print model source code via ``print_``.

    Output order: the filled-in ``HEADER``, the ``UNKNOWN_FIELD`` stub (unless
    ``ignore_unknown`` is set), ``BASE_MODEL``, then one class per table.
    Tables referenced by a foreign key are printed before the table that
    references them.

    :param introspector: object providing ``introspect()``, the
        ``get_database_*()`` / ``get_additional_imports()`` accessors, and the
        ``pk_classes`` and ``schema`` attributes used below.
    :param tables: optional collection of table names. It is passed to
        ``introspect()`` and also limits which tables are started from at the
        top level; tables reached through foreign keys are still printed.
    :param preserve_order: when false, columns are printed sorted by name;
        when true, they are printed in the order ``database.columns[table]``
        yields them.
    :param include_views: forwarded to ``introspect()``.
    :param ignore_unknown: when true, the ``UnknownField`` stub is not printed
        and columns whose field class is ``UnknownField`` are emitted as
        comments instead of field definitions.
    :param snake_case: forwarded to ``introspect()``.
    """
    database = introspector.introspect(table_names=tables,
                                       include_views=include_views,
                                       snake_case=snake_case)

    # Fill in the module header; the kwargs suffix is omitted entirely when
    # the introspector reports no database keyword arguments.
    db_kwargs = introspector.get_database_kwargs()
    header = HEADER % (
        introspector.get_additional_imports(),
        introspector.get_database_class().__name__,
        introspector.get_database_name(),
        ', **%s' % repr(db_kwargs) if db_kwargs else '')
    print_(header)

    if not ignore_unknown:
        print_(UNKNOWN_FIELD)

    print_(BASE_MODEL)

    def _print_table(table, seen, accum=None):
        """Print the model for ``table`` after the tables it references.

        ``seen`` is the shared set of table names already printed or queued
        for printing. ``accum`` is the list of tables on the current
        recursion path (the callers that led here).
        """
        accum = accum or []
        foreign_keys = database.foreign_keys[table]
        for foreign_key in foreign_keys:
            dest = foreign_key.dest_table

            # In the event the destination table has already been pushed
            # for printing, then we have a reference cycle.
            if dest in accum and table not in accum:
                print_('# Possible reference cycle: %s' % dest)

            # If this is not a self-referential foreign key, and we have
            # not already processed the destination table, do so now.
            if dest not in seen and dest not in accum:
                seen.add(dest)
                if dest != table:
                    _print_table(dest, seen, accum + [table])

        print_('class %s(BaseModel):' % database.model_names[table])
        columns = database.columns[table].items()
        if not preserve_order:
            # Sorting (name, column) pairs orders the columns by name.
            columns = sorted(columns)
        primary_keys = database.primary_keys[table]
        for name, column in columns:
            # Omit the column when it is the table's only primary key, is
            # named "id", and its field class is one of the introspector's
            # pk_classes -- presumably because peewee supplies such a field
            # implicitly (NOTE(review): confirm against peewee's defaults).
            skip = all([
                name in primary_keys,
                name == 'id',
                len(primary_keys) == 1,
                column.field_class in introspector.pk_classes])
            if skip:
                continue
            if column.primary_key and len(primary_keys) > 1:
                # If we have a CompositeKey, then we do not want to explicitly
                # mark the columns as being primary keys.
                column.primary_key = False

            is_unknown = column.field_class is UnknownField
            if is_unknown and ignore_unknown:
                # Emit a comment naming the column and its raw type (or "?")
                # in place of a field definition.
                disp = '%s - %s' % (column.name, column.raw_column_type or '?')
                print_('    # %s' % disp)
            else:
                print_('    %s' % column.get_field())

        # Inner Meta class: table name, multi-column indexes, schema and
        # primary-key declaration.
        print_('')
        print_('    class Meta:')
        print_('        table_name = \'%s\'' % table)
        multi_column_indexes = database.multi_column_indexes(table)
        if multi_column_indexes:
            print_('        indexes = (')
            for fields, unique in sorted(multi_column_indexes):
                print_('            ((%s), %s),' % (
                    ', '.join("'%s'" % field for field in fields),
                    unique,
                ))
            print_('        )')

        if introspector.schema:
            print_('        schema = \'%s\'' % introspector.schema)
        if len(primary_keys) > 1:
            # More than one primary-key column: declare a CompositeKey over
            # the sorted field names.
            pk_field_names = sorted([
                field.name for col, field in columns
                if col in primary_keys])
            pk_list = ', '.join("'%s'" % pk for pk in pk_field_names)
            print_('        primary_key = CompositeKey(%s)' % pk_list)
        elif not primary_keys:
            # No primary key columns at all.
            print_('        primary_key = False')
        print_('')

        seen.add(table)

    # Walk the tables in name order, skipping any that were already printed
    # as a foreign-key dependency of an earlier table.
    seen = set()
    for table in sorted(database.model_names.keys()):
        if table not in seen:
            if not tables or table in tables:
                _print_table(table, seen)

def print_header(cmd_line, introspector):
    """Print a comment banner describing how the output was generated.

    The banner contains the pwiz command line (``cmd_line``), the current
    local date and time, the database name reported by ``introspector`` and
    the peewee version, followed by an empty line.
    """
    generated_at = datetime.datetime.now().strftime('%B %d, %Y %I:%M%p')
    print_('# Code generated by:')
    print_('# python -m pwiz %s' % cmd_line)
    print_('# Date: %s' % generated_at)
    database_name = introspector.get_database_name()
    print_('# Database: %s' % database_name)
    print_('# Peewee version: %s' % peewee_version)
    print_('')


def err(msg):
    """Write ``msg`` to standard error wrapped in ANSI escapes, then flush.

    The message is preceded by the ``ESC[91m`` sequence and followed by
    ``ESC[0m`` and a newline.
    """
    stream = sys.stderr
    stream.write('\033[91m%s\033[0m\n' % msg)
    stream.flush()

def get_option_parser():
    """Build and return the ``OptionParser`` for the pwiz command line.

    Options are registered in the order listed below, which is also the
    order in which they appear in the generated help text. The choices for
    ``--engine`` are the sorted aliases from ``DATABASE_MAP``.
    """
    parser = OptionParser(usage='usage: %prog [options] database_name')

    # Each entry is (option flags, keyword arguments for add_option).
    option_specs = (
        (('-H', '--host'), dict(dest='host')),
        (('-p', '--port'), dict(dest='port', type='int')),
        (('-u', '--user'), dict(dest='user')),
        (('-P', '--password'), dict(dest='password', action='store_true')),
        (('-e', '--engine'), dict(
            dest='engine',
            default='postgresql',
            choices=sorted(DATABASE_MAP),
            help=('Database type, e.g. sqlite, mysql, postgresql or '
                  'cockroachdb. Default is "postgresql".'))),
        (('-s', '--schema'), dict(dest='schema')),
        (('-t', '--tables'), dict(
            dest='tables',
            help=('Only generate the specified tables. Multiple table names '
                  'should be separated by commas.'))),
        (('-v', '--views'), dict(
            dest='views',
            action='store_true',
            help='Generate model classes for VIEWs in addition to tables.')),
        (('-i', '--info'), dict(
            dest='info',
            action='store_true',
            help=('Add database information and other metadata to top of '
                  'the generated file.'))),
        (('-o', '--preserve-order'), dict(
            dest='preserve_order',
            action='store_true',
            help='Model definition column ordering matches source table.')),
        (('-I', '--ignore-unknown'), dict(
            dest='ignore_unknown',
            action='store_true',
            help='Ignore fields whose type cannot be determined.')),
        (('-L', '--legacy-naming'), dict(
            dest='legacy_naming',
            action='store_true',
            help='Use legacy table- and column-name generation.')),
    )
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    return parser

def get_connect_kwargs(options):
    """Collect connection keyword arguments from parsed options.

    The ``host``, ``port``, ``user`` and ``schema`` attributes of ``options``
    are copied into the result only when they hold a truthy value. When
    ``options.password`` is truthy, the password is read interactively with
    ``getpass()`` and stored under ``'password'``.
    """
    kwargs = {}
    for name in ('host', 'port', 'user', 'schema'):
        value = getattr(options, name)
        if value:
            kwargs[name] = value
    if options.password:
        kwargs['password'] = getpass()
    return kwargs


if __name__ == '__main__':
    # Script entry point: parse the command line, build an introspector for
    # the requested database and print the generated model definitions.
    raw_argv = sys.argv

    parser = get_option_parser()
    options, args = parser.parse_args()

    # The database name is a required positional argument.
    if not args:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    connect = get_connect_kwargs(options)
    database = args[-1]

    # --tables takes a comma-separated list; blank entries are dropped.
    tables = None
    if options.tables:
        tables = [
            name
            for name in (part.strip() for part in options.tables.split(','))
            if name]

    introspector = make_introspector(options.engine, database, **connect)
    if options.info:
        # Reproduce the arguments the script was invoked with in the banner.
        cmd_line = ' '.join(raw_argv[1:])
        print_header(cmd_line, introspector)

    print_models(
        introspector,
        tables=tables,
        preserve_order=options.preserve_order,
        include_views=options.views,
        ignore_unknown=options.ignore_unknown,
        snake_case=not options.legacy_naming)