File: opencl.py

package info (click to toggle)
compyle 0.8.1-11
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,100 kB
  • sloc: python: 12,337; makefile: 21
file content (115 lines) | stat: -rw-r--r-- 2,935 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Common OpenCL related functionality.
"""
from __future__ import print_function
import pyopencl as cl

from .config import get_config
from .profile import profile_kernel, named_profile

# Shared OpenCL context; created on first use by get_context() or installed
# explicitly with set_context().
_ctx = None
# Shared command queue; created on first use by get_queue() or installed
# explicitly with set_queue().
_queue = None


class DeviceWGSException(Exception):
    """Signals that a kernel cannot be launched with the requested work
    group size on the target device."""


def get_context():
    """Return the module-wide OpenCL context, creating it on first use.

    When no context has been installed yet, one is obtained from
    ``cl.create_some_context()`` and cached so later calls reuse it.
    """
    global _ctx
    if _ctx is not None:
        return _ctx
    _ctx = cl.create_some_context()
    return _ctx


def set_context(ctx):
    """Install *ctx* as the module-wide OpenCL context.

    If a shared command queue has already been created for a different
    context, it is discarded. The next ``get_queue()`` call then builds a
    queue on *ctx* instead of returning one that cannot run kernels built
    for *ctx*. A cached queue that already belongs to *ctx* is kept.
    """
    global _ctx, _queue
    _ctx = ctx
    # Check for None first so the context comparison is only made between
    # two real contexts.
    if _queue is not None and (ctx is None or _queue.context != ctx):
        _queue = None


def get_queue():
    """Return the module-wide command queue, creating it on first use.

    A new queue is built on ``get_context()``. If the configuration has
    ``profile`` enabled at that moment, the queue is created with
    ``PROFILING_ENABLE``.
    """
    global _queue
    if _queue is not None:
        return _queue
    props = {}
    if get_config().profile:
        props['properties'] = cl.command_queue_properties.PROFILING_ENABLE
    _queue = cl.CommandQueue(get_context(), **props)
    return _queue


def set_queue(q):
    """Install *q* as the module-wide command queue.

    The shared context is set to the queue's own context. This keeps
    ``get_context()`` consistent with ``get_queue()``, and it means a
    context is never created separately (possibly on another device) once
    a queue has been supplied. Passing ``None`` only clears the queue.
    """
    global _queue, _ctx
    _queue = q
    if q is not None:
        _ctx = q.context


class SimpleKernel(object):
    """ElementwiseKernel substitute that supports a custom work group size.

    The generated ``__kernel`` wraps *operation* and pre-declares these
    variables for it to use::

        int lid = get_local_id(0);
        int gsize = get_global_size(0);
        int work_group_start = get_local_size(0)*get_group_id(0);
        long i = get_global_id(0);

    Parameters
    ----------
    ctx : pyopencl.Context
        Context the program is built for.
    args : str
        Comma separated C parameter declarations, e.g. ``"int n, double *x"``.
        Pointer parameters without an explicit address-space qualifier are
        made ``__global``.
    operation : str
        OpenCL C statements forming the kernel body.
    wgs : int
        Work group size the caller intends to launch with.
    name : str
        Name of the generated kernel function.
    preamble : str
        OpenCL C source emitted before the kernel.
    options : list of str, optional
        Build options passed to ``Program.build``; no options by default.

    Raises
    ------
    DeviceWGSException
        If the compiled kernel allows fewer than *wgs* work items per group
        (as reported by ``get_max_wgs``).
    """

    # Qualifiers that already fix a pointer's address space. Pointers that
    # carry one must not get an additional ``__global`` prefix.
    _ADDRESS_SPACES = frozenset((
        "__global", "global", "__local", "local",
        "__constant", "constant", "__private", "private",
    ))

    def __init__(self, ctx, args, operation, wgs,
                 name="", preamble="", options=None):
        # Use None instead of a shared mutable ``[]`` default.
        if options is None:
            options = []
        self.args = args
        self.operation = operation
        self.name = name
        self.preamble = preamble
        self.options = options

        self.prg = cl.Program(ctx, self._generate()).build(options)
        self.knl = getattr(self.prg, name)

        max_wgs = self.get_max_wgs()
        if max_wgs < wgs:
            raise DeviceWGSException(
                "requested work group size %s exceeds the maximum of %s "
                "supported for kernel '%s'" % (wgs, max_wgs, name)
            )

    def _massage_arg(self, arg):
        """Return the parameter declaration *arg*, stripped of surrounding
        whitespace.

        Pointer declarations are prefixed with ``__global`` unless they
        already carry an address-space qualifier.
        """
        arg = arg.strip()
        if '*' not in arg:
            return arg
        tokens = set(arg.replace('*', ' ').split())
        if tokens & self._ADDRESS_SPACES:
            return arg
        return "__global " + arg

    def _generate(self):
        """Return the OpenCL C source of the program (preamble + kernel)."""
        params = ", ".join(
            self._massage_arg(arg) for arg in self.args.split(",")
        )

        source = r"""
        %(preamble)s

        __kernel void %(name)s(%(args)s)
        {
          int lid = get_local_id(0);
          int gsize = get_global_size(0);
          int work_group_start = get_local_size(0)*get_group_id(0);
          long i = get_global_id(0);

          %(body)s
        }
        """ % {
            "args": params,
            "name": self.name,
            "preamble": self.preamble,
            "body": self.operation
        }

        return source

    def get_max_wgs(self):
        """Return the kernel's ``WORK_GROUP_SIZE`` for ``get_queue().device``."""
        # NOTE(review): the device comes from the module-wide queue, not from
        # the ``ctx`` given to ``__init__``. If they differ, or no shared
        # queue exists yet (get_queue() then creates one), the limit may be
        # for a different device. Confirm this is intended.
        return self.knl.get_work_group_info(
            cl.kernel_work_group_info.WORK_GROUP_SIZE,
            get_queue().device
        )

    def __call__(self, *args, **kwargs):
        """Set the kernel arguments and enqueue the kernel.

        Positional *args* are the kernel arguments. ``pyopencl.array.Array``
        instances are replaced by their ``.data`` buffer. Keyword arguments
        ``queue``, ``gs`` (global size) and ``ls`` (local size) are
        required; ``wait_for`` is optional.

        Returns the result of ``cl.enqueue_nd_range_kernel``.

        Raises
        ------
        ValueError
            If ``queue``, ``gs`` or ``ls`` is missing.
        TypeError
            If any other keyword argument is given.
        """
        wait_for = kwargs.pop("wait_for", None)
        queue = kwargs.pop("queue", None)
        gs = kwargs.pop("gs", None)
        ls = kwargs.pop("ls", None)

        if queue is None or gs is None or ls is None:
            raise ValueError("queue, gs and ls can not be empty")

        if kwargs:
            raise TypeError("unknown keyword arguments: '%s'"
                            % ", ".join(kwargs))

        # ``import pyopencl`` alone does not load the ``pyopencl.array``
        # submodule, so ``cl.array`` is only present if some other module
        # happened to import it. Import it explicitly.
        import pyopencl.array as cl_array

        def unwrap(arg):
            return arg.data if isinstance(arg, cl_array.Array) else arg

        self.knl.set_args(*[unwrap(arg) for arg in args])
        return cl.enqueue_nd_range_kernel(queue, self.knl, gs, ls,
                                          wait_for=wait_for)