File: trait_grid_model.py

package info (click to toggle)
python-pyface 8.0.0-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,944 kB
  • sloc: python: 54,107; makefile: 82
file content (703 lines) | stat: -rw-r--r-- 24,092 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
# (C) Copyright 2005-2023 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

""" A TraitGridModel builds a grid from a list of traits objects. Each row
represents one object, each column one trait from those objects. All the
objects must be of the same type. Optionally a user may pass in a list of
trait names defining which traits will be shown in the columns and in which
order. If this list is not passed in, then the first object is inspected and
every (non-event) trait from that object gets a column."""

from functools import cmp_to_key

from traits.api import (
    Any,
    Bool,
    Callable,
    Dict,
    HasTraits,
    Instance,
    Int,
    List,
    Str,
    Type,
    Union,
)
from traits.observation.api import match

from .grid_model import GridColumn, GridModel, GridSortEvent
from .trait_grid_cell_adapter import TraitGridCellAdapter


# The classes below are part of the table specification.
class TraitGridColumn(GridColumn):
    """ Structure for holding column specifications in a TraitGridModel.

    Instances may be used as entries of ``TraitGridModel.columns`` in place
    of plain trait-name strings.
    """

    # The trait name for this column. When set (and present on the row
    # object) this takes precedence over ``method``.
    name = Union(None, Str)

    # The name of a no-argument method on the row object that is called to
    # get the value for this column. TraitGridModel does not write values
    # back through method columns.
    method = Union(None, Str)

    # A two-argument comparison function used to sort on this column. It is
    # wrapped with ``functools.cmp_to_key``, so it must return a negative,
    # zero or positive number.
    sorter = Callable

    # A dictionary mapping value types to display formats. A format is either
    # a %-style format string (applied as ``format % value``) or a callable
    # that accepts a single argument (the value) and returns the result.
    formats = Dict(Type, Union(Str, Callable))

    # A name to designate the type of this column (this is what
    # TraitGridModel.get_type returns for the column).
    typename = Union(None, Str)
    # note: context menus should go in here as well? but we need
    #       more info than we have available at this point

    # The width of the column in pixels; -1 means use the default size.
    size = Int(-1)


class TraitGridSelection(HasTraits):
    """ Structure describing a selection in a TraitGridModel.

    A selection is an object (a row of the grid) and, optionally, one of
    that object's traits (a column of the grid).
    """

    # The selected object, i.e. the object displayed in the selected row.
    obj = Instance(HasTraits)

    # The specific trait selected on the object. None means the whole
    # object (the entire row) is selected.
    trait_name = Union(None, Str)


# The meat.
class TraitGridModel(GridModel):
    """ A TraitGridModel builds a grid from a list of traits objects. Each row
    represents one object, each column one trait from those objects. All the
    objects must be of the same type. Optionally a user may pass in a list of
    trait names defining which traits will be shown in the columns and in
    which order. If this list is not passed in, then the first object is
    inspected and every non-event trait from that object gets a column."""

    # The objects displayed by the grid, one object per row.
    data = List(Any)

    # The column definitions. Each entry is a trait name or a
    # TraitGridColumn. If None or empty, columns are generated from the
    # traits of the first object in ``data``.
    columns = Union(None, List(Union(None, Str, Instance(TraitGridColumn))))

    # The trait to look at to get the row name
    row_name_trait = Union(None, Str)

    # Allow column sorting?
    allow_column_sort = Bool(True)

    # A factory to generate new rows. If this is not None then it must
    # be a no-argument function.
    row_factory = Callable

    # ------------------------------------------------------------------------
    # 'object' interface.
    # ------------------------------------------------------------------------
    def __init__(self, **traits):
        """ Create a TraitGridModel object. """

        # Base class constructor
        super().__init__(**traits)

        # The effective column list: either the user-supplied columns or,
        # if none were supplied, columns derived from the first data object.
        self._auto_columns = self._compute_auto_columns()

        # attach trait handlers to the list object
        self.observe(self._on_data_changed, "data")
        self.observe(self._on_data_items_changed, "data:items")

        # attach appropriate trait handlers to objects in the list
        self.__manage_data_listeners(self.data)

        # attach a listener to the column definitions so we refresh when
        # they change
        self.observe(self._on_columns_changed, "columns")
        self.observe(self._on_columns_items_changed, "columns:items")
        # attach listeners to the column definitions themselves
        self.__manage_column_listeners(self.columns)

        # attach a listener to the row_name_trait
        self.observe(self._on_row_name_trait_changed, "row_name_trait")

    # ------------------------------------------------------------------------
    # 'GridModel' interface.
    # ------------------------------------------------------------------------

    def get_column_count(self):
        """ Return the number of columns for this table. """

        return len(self._auto_columns)

    def get_column_name(self, index):
        """ Return the label of the column specified by the
        (zero-based) index.

        For a TraitGridColumn this is its ``label`` if set, otherwise its
        ``name``; for a plain string column it is the string itself. Returns
        an empty string if the index is out of range. """

        try:
            col = self._auto_columns[index]
        except IndexError:
            return ""

        if isinstance(col, TraitGridColumn):
            return col.label if col.label is not None else col.name
        return col

    def get_column_size(self, index):
        """ Return the size in pixels of the column indexed by col.
            A value of -1 or None means use the default. """

        column = self.__get_column(index)
        if isinstance(column, TraitGridColumn):
            return column.size
        return -1

    def get_cols_drag_value(self, cols):
        """ Return the value to use when the specified columns are dragged or
        copied and pasted. cols is a list of column indexes.

        The result is a list with one entry per column, each being the list
        of that column's (formatted) values. """

        return [self.__get_data_column(col) for col in cols]

    def get_cols_selection_value(self, cols):
        """ Returns a list of TraitGridSelection objects containing the
        object corresponding to the grid rows and the traits corresponding
        to the specified columns. """

        # Resolve the column names once rather than once per row.
        names = [self.__get_column_name(col) for col in cols]

        values = []
        for obj in self.data:
            for name in names:
                values.append(TraitGridSelection(obj=obj, trait_name=name))

        return values

    def sort_by_column(self, col, reverse=False):
        """ Sort model data by the column indexed by col.

        If the column is a TraitGridColumn with a ``sorter``, that
        two-argument comparison function orders the trait values; otherwise
        the trait values are compared directly. Rows lacking the trait sort
        with a value of None. Fires ``column_sorted`` afterwards.

        Does nothing if sorting is disabled, ``col`` is out of range, or the
        column has no trait name to sort on (e.g. a method-only column). """

        # first check to see if we allow sorts by column
        if not self.allow_column_sort:
            return

        column = self.__get_column(col)
        if column is None:
            return

        name = self.__get_column_name(col)
        if name is None:
            # Without a trait name there is nothing to read from each row.
            return

        # see if a sorter is specified for this column
        key = None
        if isinstance(column, TraitGridColumn) and column.sorter is not None:
            key = cmp_to_key(column.sorter)

        def key_function(row):
            value = getattr(row, name, None)
            # Without a sorter, sort on the raw trait value (returning None
            # here would make every key None, which cannot be ordered).
            return value if key is None else key(value)

        self.data.sort(key=key_function, reverse=reverse)

        # now fire an event to tell the grid we're sorted
        self.column_sorted = GridSortEvent(index=col, reversed=reverse)

    def is_column_read_only(self, index):
        """ Return True if the column specified by the zero-based index
        is read-only. """

        return self.__get_column_readonly(index)

    def get_row_count(self):
        """ Return the number of rows for this table. """

        return len(self.data) if self.data is not None else 0

    def get_row_name(self, index):
        """ Return the name of the row specified by the
        (zero-based) index.

        If ``row_name_trait`` is set and the row object has that attribute,
        its value is returned; otherwise the 1-based row number as a string.
        """

        name = str(index + 1)
        if self.row_name_trait is not None:
            try:
                row = self._get_row(index)
            except IndexError:
                pass
            else:
                # Previously ``name`` was left unbound when the row lacked
                # the attribute; fall back to the row number instead.
                if hasattr(row, self.row_name_trait):
                    name = getattr(row, self.row_name_trait)

        return name

    def get_rows_drag_value(self, rows):
        """ Return the value to use when the specified rows are dragged or
        copied and pasted. rows is a list of row indexes.

        Always returns a list with one entry per index: the row object, or
        None if the index is out of range. """

        value = []
        for index in rows:
            try:
                # note that we can't use get_value for this because it
                # sometimes returns strings instead of the actual value,
                # e.g. in cases where a float_format is specified
                value.append(self._get_row(index))
            except IndexError:
                value.append(None)

        return value

    def get_rows_selection_value(self, rows):
        """ Returns a list of TraitGridSelection objects containing the
        object corresponding to the selected rows. """

        return [TraitGridSelection(obj=self.data[index]) for index in rows]

    def is_row_read_only(self, index):
        """ Return True if the row specified by the zero-based index
        is read-only. """

        return False

    def get_cell_editor(self, row, col):
        """ Return the editor for the specified cell, or None if the row
        object has no trait for the column. """

        obj = self.data[row]
        trait_name = self.__get_column_name(col)
        trait = obj.base_trait(trait_name)
        if trait is None:
            return None

        factory = trait.get_editor()

        return TraitGridCellAdapter(factory, obj, trait_name, "")

    def get_cell_drag_value(self, row, col):
        """ Return the value to use when the specified cell is dragged or
        copied and pasted (the value returned by ``_get_data_from_row``,
        without the column ``formats`` applied). """

        column = self.__get_column(col)
        obj = self._get_row(row)

        return self._get_data_from_row(obj, column)

    def get_cell_selection_value(self, row, col):
        """ Returns a TraitGridSelection object specifying the data stored
        in the table at (row, col). """

        obj = self.data[row]
        trait_name = self.__get_column_name(col)

        return TraitGridSelection(obj=obj, trait_name=trait_name)

    def resolve_selection(self, selection_list):
        """ Returns a list of (row, col) grid-cell coordinates that
        correspond to the objects in objlist. For each coordinate, if the
        row is -1 it indicates that the entire column is selected. Likewise
        coordinates with a column of -1 indicate an entire row that is
        selected. For the TraitGridModel, the objects in objlist must
        be TraitGridSelection objects. Selections whose object is not in
        ``data`` or whose trait has no column are skipped. """

        cells = []
        for selection in selection_list:
            try:
                row = self.data.index(selection.obj)
            except ValueError:
                continue

            column = -1
            if selection.trait_name is not None:
                column = self._get_column_index_by_trait(selection.trait_name)
                if column is None:
                    continue

            cells.append((row, column))

        return cells

    def get_type(self, row, col):
        """ Return the ``typename`` of the column indexed by col (None for
        plain string columns). The row is ignored. """

        return self.__get_column_typename(col)

    def get_value(self, row, col):
        """ Return the value stored in the table at (row, col), formatted
        with the column's ``formats`` entry for the value's type, if any.

        If applying the format raises TypeError the unformatted value is
        returned. """

        value = self.get_cell_drag_value(row, col)
        formats = self.__get_column_formats(col)
        if value is None or not formats:
            return value

        fmt = formats.get(type(value))
        if fmt is None:
            return value

        try:
            return fmt(value) if callable(fmt) else fmt % value
        except TypeError:
            # not enough arguments? wrong kind of arguments?
            return value

    def is_cell_empty(self, row, col):
        """ Returns True if the cell at (row, col) has a None value,
        False otherwise."""

        return self.get_value(row, col) is None

    def is_cell_editable(self, row, col):
        """ Returns True if the cell at (row, col) is editable,
        False otherwise. """
        return not self.is_column_read_only(col)

    # ------------------------------------------------------------------------
    # protected 'GridModel' interface.
    # ------------------------------------------------------------------------
    def _insert_rows(self, pos, num_rows):
        """ Inserts num_rows at pos and fires an event iff a factory method
        for new rows is defined. Otherwise returns 0.

        Returns the number of rows inserted. """

        count = 0
        if self.row_factory is not None:
            new_data = [self.row_factory() for _ in range(num_rows)]
            count = self._insert_rows_into_model(pos, new_data)
            self.rows_added = ("added", pos, new_data)

        return count

    def _delete_rows(self, pos, num_rows):
        """ Removes rows pos through pos + num_rows from the model, clamping
        the count to the rows that exist. Returns the number removed. """

        row_count = self.get_row_count()
        if pos + num_rows >= row_count:
            num_rows = row_count - pos

        return self._delete_rows_from_model(pos, num_rows)

    def _set_value(self, row, col, value):
        """ Sets the value of the cell at (row, col) to value.

        If the row's object is None a new row is appended (via
        ``_insert_rows``) and the value is set on it instead. An out-of-range
        row raises IndexError from ``_get_row``.

        Returns the number of rows that were added. """

        new_rows = 0
        # find the column indexed by col
        column = self.__get_column(col)
        obj = self._get_row(row)
        success = False
        if obj is not None:
            success = self._set_data_on_row(obj, column, value)
        else:
            # Add a new row.
            new_rows = self._insert_rows(self.get_row_count(), 1)
            if new_rows > 0:
                # now set the value on the new object
                obj = self._get_row(self.get_row_count() - 1)
                success = self._set_data_on_row(obj, column, value)

        if not success:
            # fixme: what do we do in this case? veto the set somehow? raise
            #        an exception?
            pass

        return new_rows

    # ------------------------------------------------------------------------
    # protected interface.
    # ------------------------------------------------------------------------
    def _get_row(self, index):
        """ Return the object that corresponds to the row at index. Override
        this to handle very large data sets. """

        return self.data[index]

    def _get_data_from_row(self, row, column):
        """ Retrieve the data specified by column for this row. Attribute
        can be either a member of the row object, or a no-argument method
        on that object. Override this method to provide alternative ways
        of accessing the data in the object.

        Returns the value converted with ``str``, or None if there is no
        value. """

        value = None

        if row is not None and column is not None:
            if not isinstance(column, TraitGridColumn):
                # first handle the case where the column
                # definition might be just a string
                if hasattr(row, column):
                    value = getattr(row, column)
            elif column.name is not None and hasattr(row, column.name):
                # this is the case when the trait name is specified
                value = getattr(row, column.name)
            elif column.method is not None and hasattr(row, column.method):
                # this is the case when an object method is specified
                value = getattr(row, column.method)()

        return None if value is None else str(value)

    def _set_data_on_row(self, row, column, value):
        """ Set the data specified by column on this row. Only attribute
        (trait name) columns can be set; method columns are left untouched.
        Override this method to provide alternative ways of setting the data
        in the object.

        Returns True if the value was set, False otherwise. """

        # NOTE: the underlying grid may hand us 0/1 instead of True/False
        # for boolean traits; no conversion is currently performed.
        if row is None or column is None:
            return False

        if not isinstance(column, TraitGridColumn):
            if hasattr(row, column):
                setattr(row, column, value)
                return True
        elif column.name is not None and hasattr(row, column.name):
            setattr(row, column.name, value)
            return True

        # do nothing in the method case as we don't allow rows
        # defined to return a method value to set the value
        return False

    def _insert_rows_into_model(self, pos, new_data):
        """ Insert the given new rows into the model at pos. Override this
        method to handle very large data sets.

        Returns the number of rows inserted. """

        inserted = 0
        for row in new_data:
            self.data.insert(pos + inserted, row)
            inserted += 1

        return inserted

    def _delete_rows_from_model(self, pos, num_rows):
        """ Delete the specified rows from the model. Override this method
        to handle very large data sets. Returns num_rows. """

        # A slice is required here: ``data[pos, pos + n]`` would be a
        # tuple index and raise TypeError.
        del self.data[pos:pos + num_rows]

        return num_rows

    def _compute_auto_columns(self):
        """ Return the effective column list.

        If ``columns`` is a non-empty list it is returned as-is. Otherwise a
        TraitGridColumn is created for every non-event trait of the first
        object in ``data``; with no data the result is an empty list. """

        if self.columns:
            return self.columns

        if not self.data:
            return []

        # we only add traits that aren't events, since events
        # are write-only
        return [
            TraitGridColumn(name=name)
            for name, trait in self.data[0].traits().items()
            if trait.type != "event"
        ]

    @staticmethod
    def _match_any_trait(name, trait):
        """ Observer filter accepting every trait.

        A single shared function is used because Traits can only remove an
        observer registered with ``match`` when given an equal filter. """
        return True

    # ------------------------------------------------------------------------
    # trait handlers
    # ------------------------------------------------------------------------

    def _on_row_name_trait_changed(self, event):
        """ Force the grid to refresh when the row name trait changes. """
        self.fire_content_changed()

    def _on_columns_changed(self, event):
        """ Rewire listeners and refresh the grid structure when the
        ``columns`` list is replaced. """
        self.__manage_column_listeners(event.old, remove=True)
        self.__manage_column_listeners(self.columns)
        # Fall back to data-derived columns if columns became None/empty.
        self._auto_columns = self._compute_auto_columns()
        self.fire_structure_changed()

    def _on_columns_items_changed(self, event):
        """ Rewire listeners and refresh the grid structure when items of
        the ``columns`` list are added or removed. """

        self.__manage_column_listeners(event.removed, remove=True)
        self.__manage_column_listeners(event.added)
        self._auto_columns = self._compute_auto_columns()
        self.fire_structure_changed()

    def _on_column_trait_changed(self, event):
        """ Refresh the grid structure when a trait of a TraitGridColumn
        changes. """
        self.fire_structure_changed()

    def _on_contained_trait_changed(self, event):
        """ Force the grid to refresh when any underlying trait changes. """
        self.fire_content_changed()

    def _on_data_changed(self, event):
        """ Force the grid to refresh when the underlying list changes. """

        self.__manage_data_listeners(event.old, remove=True)
        self.__manage_data_listeners(self.data)
        self.fire_structure_changed()

    def _on_data_items_changed(self, event):
        """ Force the grid to refresh when the underlying list changes. """

        # if an item was removed then remove that item's listener
        self.__manage_data_listeners(event.removed, remove=True)

        # if items were added then add trait change listeners on those items
        self.__manage_data_listeners(event.added)

        self.fire_content_changed()

    # ------------------------------------------------------------------------
    # private interface.
    # ------------------------------------------------------------------------

    def __get_data_column(self, col):
        """ Return a 1-d list of (formatted) data from the column indexed by
        col, with None for cells that cannot be read. """

        coldata = []
        for row in range(self.get_row_count()):
            try:
                coldata.append(self.get_value(row, col))
            except IndexError:
                coldata.append(None)

        return coldata

    def __get_column(self, col):
        """ Return the column definition at col, or None if out of range. """

        try:
            return self._auto_columns[col]
        except IndexError:
            return None

    def __get_column_name(self, col):
        """ Return the trait name for column col (the string itself for
        plain string columns). """

        column = self.__get_column(col)
        if isinstance(column, TraitGridColumn):
            return column.name
        return column

    def __get_column_typename(self, col):
        """ Return the typename of column col, or None. """

        column = self.__get_column(col)
        if isinstance(column, TraitGridColumn):
            return column.typename
        return None

    def __get_column_readonly(self, col):
        """ Return the read_only flag of column col (False if not a
        TraitGridColumn). """

        column = self.__get_column(col)
        if isinstance(column, TraitGridColumn):
            return column.read_only
        return False

    def __get_column_formats(self, col):
        """ Return the formats dict of column col, or None. """

        column = self.__get_column(col)
        if isinstance(column, TraitGridColumn):
            return column.formats
        return None

    def _get_column_index_by_trait(self, trait_name):
        """ Return the index of the first column showing trait_name, or None.
        """

        for index, col in enumerate(self._auto_columns):
            col_name = col.name if isinstance(col, TraitGridColumn) else col
            if col_name == trait_name:
                return index

        return None

    def __manage_data_listeners(self, items, remove=False):
        """ Attach (or, with remove=True, detach) the content-changed
        listener on every trait of each object in items. """
        if items is not None:
            for item in items:
                item.observe(
                    self._on_contained_trait_changed,
                    match(self._match_any_trait),
                    remove=remove,
                )

    def __manage_column_listeners(self, collist, remove=False):
        """ Attach (or, with remove=True, detach) the structure-changed
        listener on every trait of each TraitGridColumn in collist. """

        if collist is not None:
            for col in collist:
                if isinstance(col, TraitGridColumn):
                    col.observe(
                        self._on_column_trait_changed,
                        match(self._match_any_trait),
                        remove=remove,
                    )