1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
# (C) Copyright 2005-2023 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
"""
Unit tests for the `HasTraits.class_traits` class method.
"""
import unittest
from traits.api import HasTraits, Int, List, Str
class A(HasTraits):
    """Base fixture: one plain trait and one trait tagged ``marked=True``."""

    # Plain integer trait carrying no extra metadata.
    x = Int()

    # String trait carrying ``marked=True`` metadata; it is one of the
    # traits the metadata-filtering test expects to select.
    name = Str(marked=True)
class B(A):
    """Subclass of ``A`` that declares no traits of its own.

    Used to check that ``class_traits`` reports inherited traits.
    """
class C(B):
    """Grandchild of ``A`` adding two traits that carry ``marked`` metadata."""

    # List trait whose ``marked`` metadata is present but set to False.
    lst = List(marked=False)

    # Integer trait tagged ``marked=True``.
    y = Int(marked=True)
class TestClassTraits(unittest.TestCase):
    """Tests for ``HasTraits.class_traits`` on the A -> B -> C hierarchy."""

    def test_all_class_traits(self):
        # With no metadata filter, class_traits reports every trait on the
        # class. That includes the implicit ``trait_added`` and
        # ``trait_modified`` traits inherited from HasTraits.
        base_expected = ["x", "name", "trait_added", "trait_modified"]
        self.assertCountEqual(A.class_traits(), base_expected)

        # Derived classes must report inherited traits; B adds none.
        self.assertCountEqual(B.class_traits(), base_expected)

        # C reports its own traits on top of those inherited from A via B.
        # Build a new list rather than mutating ``base_expected`` in place.
        self.assertCountEqual(C.class_traits(), base_expected + ["lst", "y"])

    def test_class_traits_with_metadata(self):
        # A non-callable metadata value selects the traits whose metadata
        # attribute equals that value.
        traits = C.class_traits(marked=True)
        self.assertCountEqual(traits, ("y", "name"))

        # Only ``lst`` sets ``marked=False``. Traits with no ``marked``
        # metadata at all must not match a False filter.
        unmarked_traits = C.class_traits(marked=False)
        self.assertCountEqual(unmarked_traits, ("lst",))

        # A callable metadata value acts as a predicate on the metadata
        # attribute. Here it selects every trait that has ``marked`` set,
        # regardless of its value.
        marked_traits = C.class_traits(marked=lambda attr: attr is not None)
        self.assertCountEqual(marked_traits, ("y", "name", "lst"))
|