File: element_crossovers.py

package info (click to toggle)
python-ase 3.26.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 15,484 kB
  • sloc: python: 148,112; xml: 2,728; makefile: 110; javascript: 47
file content (148 lines) | stat: -rw-r--r-- 5,651 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# fmt: off

"""Crossover classes, that cross the elements in the supplied
atoms objects.

"""
import numpy as np

from ase.ga.offspring_creator import OffspringCreator


class ElementCrossover(OffspringCreator):
    """Base class for all operators where the elements of
    the atoms objects cross.

    """

    def __init__(self, element_pool, max_diff_elements,
                 min_percentage_elements, verbose, rng=np.random):
        OffspringCreator.__init__(self, verbose, rng=rng)

        if not isinstance(element_pool[0], (list, np.ndarray)):
            self.element_pools = [element_pool]
        else:
            self.element_pools = element_pool

        if max_diff_elements is None:
            self.max_diff_elements = [None for _ in self.element_pools]
        elif isinstance(max_diff_elements, int):
            self.max_diff_elements = [max_diff_elements]
        else:
            self.max_diff_elements = max_diff_elements
        assert len(self.max_diff_elements) == len(self.element_pools)

        if min_percentage_elements is None:
            self.min_percentage_elements = [0 for _ in self.element_pools]
        elif isinstance(min_percentage_elements, (int, float)):
            self.min_percentage_elements = [min_percentage_elements]
        else:
            self.min_percentage_elements = min_percentage_elements
        assert len(self.min_percentage_elements) == len(self.element_pools)

        self.min_inputs = 2

    def get_new_individual(self, parents):
        raise NotImplementedError


class OnePointElementCrossover(ElementCrossover):
    """Crossover of the elements in the atoms objects. Point of cross
    is chosen randomly.

    Parameters:

    element_pool: List of elements in the phase space. The elements can be
        grouped if the individual consist of different types of elements.
        The list should then be a list of lists e.g. [[list1], [list2]]

    max_diff_elements: The maximum number of different elements in the
        individual. Default is infinite. If the elements are grouped
        max_diff_elements should be supplied as a list with each input
        corresponding to the elements specified in the same input in
        element_pool.

    min_percentage_elements: The minimum percentage of any element in
        the individual. Default is any number is allowed. If the elements
        are grouped min_percentage_elements should be supplied as a list
        with each input corresponding to the elements specified in the
        same input in element_pool.

    Example: element_pool=[[A,B,C,D],[x,y,z]], max_diff_elements=[3,2],
        min_percentage_elements=[.25, .5]
        An individual could be "D,B,B,C,x,x,x,x,z,z,z,z"

    rng: Random number generator
        By default numpy.random.
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False, rng=np.random):
        ElementCrossover.__init__(self, element_pool, max_diff_elements,
                                  min_percentage_elements, verbose, rng=rng)
        self.descriptor = 'OnePointElementCrossover'

    def get_new_individual(self, parents):
        f, m = parents

        indi = self.initialize_individual(f)
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        cut_choices = [i for i in range(1, len(f) - 1)]
        self.rng.shuffle(cut_choices)
        for cut in cut_choices:
            fsyms = f.get_chemical_symbols()
            msyms = m.get_chemical_symbols()
            syms = fsyms[:cut] + msyms[cut:]
            ok = True
            for i, e in enumerate(self.element_pools):
                elems = e[:]
                elems_in, indices_in = zip(*[(a.symbol, a.index) for a in f
                                             if a.symbol in elems])
                max_diff_elem = self.max_diff_elements[i]
                min_percent_elem = self.min_percentage_elements[i]
                if min_percent_elem == 0:
                    min_percent_elem = 1. / len(elems_in)
                if max_diff_elem is None:
                    max_diff_elem = len(elems_in)

                syms_in = [syms[i] for i in indices_in]
                for s in set(syms_in):
                    percentage = syms_in.count(s) / float(len(syms_in))
                    if percentage < min_percent_elem:
                        ok = False
                        break
                    num_diff = len(set(syms_in))
                    if num_diff > max_diff_elem:
                        ok = False
                        break
                if not ok:
                    break
            if ok:
                break

        # Sufficient or does some individuals appear
        # below min_percentage_elements

        for a in f[:cut] + m[cut:]:
            indi.append(a)

        parent_message = ':Parents {} {}'.format(f.info['confid'],
                                                 m.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)


class TwoPointElementCrossover(ElementCrossover):
    """Crosses two individuals by choosing two cross points
    at random"""

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False):
        ElementCrossover.__init__(self, element_pool,
                                  max_diff_elements,
                                  min_percentage_elements, verbose)
        self.descriptor = 'TwoPointElementCrossover'

    def get_new_individual(self, parents):
        raise NotImplementedError