File: subprocesscalculator.py

package info (click to toggle)
python-ase 3.26.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 15,484 kB
  • sloc: python: 148,112; xml: 2,728; makefile: 110; javascript: 47
file content (361 lines) | stat: -rw-r--r-- 10,637 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# fmt: off

import os
import pickle
import sys
from abc import ABC, abstractmethod
from subprocess import PIPE, Popen

from ase.calculators.calculator import Calculator, all_properties


class PackedCalculator(ABC):
    """Portable description of a calculator for PythonSubProcessCalculator.

    A packed calculator is pickled to a separate process (which may be
    launched through MPI or srun) and turned into a real calculator
    there.  This lets most of an ASE script run serially while selected
    calculations run in a parallel Python environment.

    NamedPackedCalculator covers most existing calculators.  Subclass
    this class directly when a calculator needs custom construction.

    Example::

      from ase.build import bulk

      atoms = bulk('Au')
      pack = NamedPackedCalculator('emt')

      with pack.calculator() as atoms.calc:
          energy = atoms.get_potential_energy()

    The subprocess doing the work exists for the duration of the
    ``with`` statement.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Build and return the real calculator.

        Invoked inside the subprocess that carries out the
        computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this pack in a PythonSubProcessCalculator.

        The returned calculator forwards every computation to a
        subprocess holding the unpacked calculator; *mpi_command*
        selects how that subprocess is launched (None means the
        PythonSubProcessCalculator default)."""
        subprocess_calc = PythonSubProcessCalculator(
            calc_input=self, mpi_command=mpi_command)
        return subprocess_calc


class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator implementation which works with standard calculators.

    This works with calculators known by ase.calculators.calculator:
    inside the subprocess, *name* is resolved with
    ``get_calculator_class`` and the resulting class is instantiated
    with ``**kwargs``."""

    def __init__(self, name, kwargs=None):
        self._name = name
        # Copy, so that later changes to the caller's dict do not alter
        # what is sent to the subprocess.
        self._kwargs = {} if kwargs is None else dict(kwargs)

    def unpack_calculator(self):
        """Instantiate the named calculator (called in the subprocess)."""
        from ase.calculators.calculator import get_calculator_class
        cls = get_calculator_class(self._name)
        return cls(**self._kwargs)

    def __repr__(self):
        # !r so the name is quoted, e.g. NamedPackedCalculator('emt', {}).
        return f'{self.__class__.__name__}({self._name!r}, {self._kwargs!r})'


class MPICommand:
    """Command line used to launch the calculator subprocess.

    ``argv`` is the full argument list handed to Popen.  The launched
    program is expected to run this module with a run mode
    ('standard' or 'mpi4py') as its first argument."""

    def __init__(self, argv):
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return the argv that runs this module with the current interpreter."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=()):
        """Return a command running this module under ``mpiexec -n nprocs``.

        *mpi_argv* holds extra mpiexec options; they are inserted before
        the Python command.  The module is started in 'mpi4py' mode."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Return a command running this module as one plain process."""
        # Use cls (as parallel() does) so subclasses get their own type.
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the command with piped stdin/stdout and return the Popen."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ.  Not sure if this is a machine thing
        # or in general.  --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)


def gpaw_process(ncores=1, **kwargs):
    """Return a PythonSubProcessCalculator that runs GPAW in a subprocess.

    *kwargs* are given to the GPAW calculator constructor inside the
    subprocess.  The subprocess command is
    ``python -m gpaw -P <ncores> python -m
    ase.calculators.subprocesscalculator standard``; ``-P`` presumably
    sets the number of parallel processes (see the GPAW CLI)."""
    command = [sys.executable, '-m', 'gpaw', '-P', str(ncores),
               'python', '-m', 'ase.calculators.subprocesscalculator',
               'standard']
    return PythonSubProcessCalculator(
        NamedPackedCalculator('gpaw', kwargs),
        MPICommand(command),
    )


class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator.  Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    The subprocess exists only inside a ``with`` block: ``__enter__``
    starts it and sends *calc_input*; ``__exit__`` sends ``'stop'`` and
    waits for the subprocess to terminate."""
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        """Parameters:

        calc_input:
            PackedCalculator pickled to the subprocess, where its
            ``unpack_calculator()`` builds the real calculator.
        mpi_command:
            MPICommand used to launch the subprocess; defaults to
            ``MPICommand.serial()``.
        """
        super().__init__()

        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

        # Protocol to the running subprocess; None outside ``with``.
        self.protocol = None

    def set(self, **kwargs):
        """Reject parameter changes while the subprocess is running.

        Parameters cannot be forwarded to the calculator living in the
        subprocess.  Outside a ``with`` block the arguments are ignored.
        """
        # getattr(): set() may be invoked by the base-class constructor
        # before ``self.protocol`` has been assigned.
        if kwargs and getattr(self, 'protocol', None) is not None:
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return '{}({})'.format(type(self).__name__,
                               self.calc_input)

    def __enter__(self):
        assert self.protocol is None
        proc = self.mpi_command.execute()
        protocol = Protocol(proc)
        try:
            protocol.send(self.calc_input)
        except BaseException:
            # Don't leave an orphaned subprocess blocked on stdin when the
            # packed calculator cannot be delivered (e.g. unpicklable).
            proc.kill()
            proc.communicate()
            raise
        self.protocol = protocol
        return self

    def __exit__(self, *args):
        protocol, self.protocol = self.protocol, None
        try:
            protocol.send('stop')
        finally:
            # Reap the subprocess even if 'stop' could not be delivered,
            # e.g. BrokenPipeError because it already exited.
            protocol.proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        """Send a 'calculate' request; the reply is read by the caller."""
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return a proxy whose method calls run on the remote calculator."""
        return ParallelBackendInterface(self)


class Protocol:
    """Pickle-based request/response channel to a subprocess.

    Requests are pickled to ``proc.stdin``.  Replies are read from
    ``proc.stdout`` as ``(response_type, value)`` pairs, where
    response_type is ``'return'`` or ``'raise'``.

    NOTE: unpickling can execute arbitrary code; only talk to a
    subprocess you started yourself.
    """

    def __init__(self, proc):
        # proc: Popen-like object with binary stdin/stdout pipes.
        self.proc = proc

    def send(self, obj):
        """Pickle *obj* to the subprocess and flush so it is seen at once."""
        pickle.dump(obj, self.proc.stdin)
        self.proc.stdin.flush()

    def recv(self):
        """Return the next reply value, or raise it if it is an error.

        Raises RuntimeError for an unrecognized response type (a
        desynchronized stream).  An explicit check is used instead of
        ``assert``, which ``python -O`` removes.
        """
        response_type, value = pickle.load(self.proc.stdout)

        if response_type == 'raise':
            raise value

        if response_type != 'return':
            raise RuntimeError(
                f'Unexpected response type from subprocess: '
                f'{response_type!r}')
        return value


class MockMethod:
    """Callable standing in for a method of the remote calculator.

    Calling it sends ``'callmethod'`` and then ``[name, args, kwargs]``
    over ``calc.protocol``, and returns the reply."""

    def __init__(self, name, calc):
        self.name = name
        self.calc = calc

    def __call__(self, *args, **kwargs):
        channel = self.calc.protocol
        request = [self.name, args, kwargs]
        for message in ('callmethod', request):
            channel.send(message)
        return channel.recv()


class ParallelBackendInterface:
    """Attribute proxy: ``backend.foo(...)`` calls ``foo`` remotely.

    Every attribute not found normally becomes a MockMethod bound to
    *calc*.  Nothing is checked locally, so a misspelled name only
    fails when the remote call raises."""

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        proxy = MockMethod(name=name, calc=self.calc)
        return proxy


run_modes = {'mpi4py', 'standard'}  # accepted values for sys.argv[1]


def callmethod(calc, attrname, args, kwargs):
    """Call ``calc.<attrname>(*args, **kwargs)`` and return the result."""
    bound = getattr(calc, attrname)
    return bound(*args, **kwargs)


def callfunction(func, args, kwargs):
    """Call *func* with positional *args* and keyword *kwargs*."""
    result = func(*args, **kwargs)
    return result


def calculate(calc, atoms, properties, system_changes):
    """Run *calc* on *atoms* and return ``calc.results`` (not a copy).

    ``calc.results`` is cleared before calculating.  According to the
    original author, skipping the clear breaks caching for stress (but
    not for forces); the cause was not understood.
    """
    # TODO: formalize results/outputs and give programmatic access to all
    # available properties instead of relying on this workaround.
    calc.results.clear()
    calc.calculate(atoms=atoms, properties=properties,
                   system_changes=system_changes)
    return calc.results


def bad_mode():
    """Build (without raising) the SystemExit for an invalid run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)


def parallel_startup():
    """Validate the run mode in sys.argv[1] and return a stdio Client.

    Raises the SystemExit built by bad_mode() when sys.argv[1] is
    missing or not in run_modes.  On return, sys.stdout points to
    stderr; the Client writes to the original binary stdout.
    """
    if len(sys.argv) < 2 or sys.argv[1] not in run_modes:
        raise bad_mode()
    run_mode = sys.argv[1]

    if run_mode == 'mpi4py':
        # mpi4py has to be imported before the rest of ASE; otherwise
        # world is not initialized correctly.
        import mpi4py  # noqa

    # Keep the real binary stdout for protocol traffic, then send stray
    # print() output to stderr so it cannot corrupt the pickle stream.
    protocol_out = sys.stdout.buffer
    sys.stdout = sys.stderr

    return Client(input_fd=sys.stdin.buffer, output_fd=protocol_out)


class Client:
    """Server side of the protocol, running inside the subprocess.

    Rank 0 of ``ase.parallel.world`` reads pickled requests from
    *input_fd* and broadcasts them to all ranks; only rank 0 writes
    pickled replies to *output_fd*."""

    def __init__(self, input_fd, output_fd):
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Return the next request, identical on every rank."""
        from ase.parallel import broadcast
        is_root = self._world.rank == 0
        message = pickle.load(self.input_fd) if is_root else None
        return broadcast(message, 0, self._world)

    def send(self, obj):
        """Pickle *obj* to the parent; ranks other than 0 do nothing."""
        if self._world.rank != 0:
            return
        pickle.dump(obj, self.output_fd)
        self.output_fd.flush()

    def mainloop(self, calc):
        """Serve requests against *calc* until 'stop' arrives.

        Each request is an instruction followed by its payload.  A known
        instruction gets exactly one ``(response_type, value)`` reply; an
        unknown one raises RuntimeError and ends the loop."""
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                break

            payload = self.recv()
            kind, result = self.process_instruction(
                calc, instruction, payload)
            self.send((kind, result))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return ``(response_type, value)``.

        response_type is 'return' with the call's result, or 'raise' with
        the exception instance when the call fails (its traceback is
        printed locally first).  Raises RuntimeError for an unknown
        instruction."""
        if instruction == 'callfunction':
            function, args = callfunction, instruction_data
        elif instruction == 'callmethod':
            function, args = callmethod, (calc, *instruction_data)
        elif instruction == 'calculate':
            function, args = calculate, (calc, *instruction_data)
        else:
            raise RuntimeError(f'Bad instruction: {instruction}')

        try:
            value = function(*args)
        except Exception as ex:
            import traceback
            traceback.print_exc()
            return 'raise', ex
        return 'return', value


class ParallelDispatch:
    """Utility class to run functions in parallel.

    with ParallelDispatch(mpicommand) as parallel:
        result = parallel.call(function, *args, **kwargs)

    """

    def __init__(self, mpicommand):
        # mpicommand: object whose execute() returns a Popen-like
        # process (e.g. an MPICommand).
        self._mpicommand = mpicommand
        # Protocol to the running subprocess; None outside ``with``.
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` in the subprocess; return its result.

        *func* and the arguments are pickled, so *func* must be picklable
        by reference (e.g. a module-level function in an importable
        module).  Exceptions raised remotely are re-raised here."""
        self._protocol.send('callfunction')
        self._protocol.send((func, args, kwargs))
        return self._protocol.recv()

    def __enter__(self):
        assert self._protocol is None
        proc = self._mpicommand.execute()
        protocol = Protocol(proc)
        try:
            # Even if we are not using a calculator, we have to send one:
            # the subprocess first expects a packed calculator.
            # (We should get rid of that requirement.)
            protocol.send(NamedPackedCalculator('emt', {}))
        except BaseException:
            # Don't leave an orphaned subprocess blocked on stdin.
            proc.kill()
            proc.communicate()
            raise
        self._protocol = protocol
        return self

    def __exit__(self, *args):
        protocol, self._protocol = self._protocol, None
        try:
            protocol.send('stop')
        finally:
            # Reap the subprocess even if 'stop' could not be delivered,
            # e.g. BrokenPipeError because it already exited.
            protocol.proc.communicate()


def main():
    """Subprocess entry point: serve requests for one packed calculator.

    Reads a PackedCalculator from the parent, unpacks it, and processes
    instructions until 'stop' is received."""
    client = parallel_startup()
    calc = client.recv().unpack_calculator()
    client.mainloop(calc)


if __name__ == '__main__':
    main()