// File: graph_executor.hpp
// Package: taskflow 3.9.0+ds-1 (area: main; suites: forky, sid, trixie)
// Package size: 45,948 kB; sloc: cpp: 39,058; xml: 35,572; python: 12,935;
//   javascript: 1,732; makefile: 59; sh: 16
// This file: 78 lines, 1,692 bytes, mode -rw-r--r--
#pragma once

#include "./graph_base.hpp"
#include <taskflow/cuda/cudaflow.hpp>
#include <cassert>

// Runs a traversal of a Graph through a tf::cudaFlowCapturer.
//
// OPT is the optimizer type that traversal() installs on the capturer via
// tf::cudaFlowCapturer::make_optimizer<OPT>(...).
template <typename OPT>
class GraphExecutor {

  public:

    // Binds the executor to `graph` and records `dev_id` (defaults to 0).
    // `graph` is stored by reference, so it must outlive this executor.
    GraphExecutor(Graph& graph, int dev_id = 0);

    // Builds one capturer task per graph node, wires the tasks according to
    // each node's out_nodes, runs the capture on a stream and waits for it.
    // `args` are passed on to make_optimizer<OPT>().
    template <typename... OPT_Args>
    void traversal(OPT_Args&&... args);

  private:

    // Device id supplied at construction.
    // NOTE(review): not read by any code visible in this file - confirm
    // whether a device scope was meant to be set from it.
    int _dev_id;

    // Graph being traversed; referenced, not owned.
    Graph& _g;

};

// Binds the executor to `graph` (kept by reference) and records `dev_id`.
//
// The initializer list follows the member declaration order (_dev_id, then
// _g). Members are always initialized in declaration order, so writing the
// list in that same order keeps the code honest and avoids -Wreorder.
template <typename OPT>
GraphExecutor<OPT>::GraphExecutor(Graph& graph, int dev_id): _dev_id{dev_id}, _g{graph} {
  //TODO: why we cannot put cuda lambda function here?
}

// Traverses the graph once through a tf::cudaFlowCapturer.
//
// A single CPU task is emplaced into a local taskflow. Inside it:
//   1. a capturer is created and given an optimizer of type OPT built
//      from `args`;
//   2. one single_task is created per node, setting that node's `visited`
//      flag to true;
//   3. every task at level l is made to precede the tasks at level l + 1
//      listed in its node's out_nodes;
//   4. the capture is run on a fresh stream, which is then synchronized.
// The call blocks until the taskflow has finished.
//
// `args` are copied into the lambda (captured by value) and handed to
// make_optimizer<OPT>() as lvalues.
//
// NOTE(review): _dev_id is not consulted here - confirm whether the work was
// meant to be pinned to that device.
template <typename OPT>
template <typename... OPT_Args>
void GraphExecutor<OPT>::traversal(OPT_Args&&... args) {

  tf::Taskflow taskflow;
  tf::Executor executor;

  taskflow.emplace([this, args...]() {

    tf::cudaFlowCapturer cf;

    cf.make_optimizer<OPT>(args...);

    // tasks[l][i] is the capturer task created for node i of level l
    std::vector<std::vector<tf::cudaTask>> tasks;
    tasks.resize(_g.get_graph().size());

    for(size_t l = 0; l < _g.get_graph().size(); ++l) {
      tasks[l].resize((_g.get_graph())[l].size());
      for(size_t i = 0; i < (_g.get_graph())[l].size(); ++i) {
        // copy the raw pointer so the device lambda captures it by value
        bool* v = _g.at(l, i).visited;
        tasks[l][i] = cf.single_task([v] __device__ () {
          *v = true;
        });
      }
    }

    // Wire level l to level l + 1. The bound is written as `l + 1 < size`
    // rather than `l < size - 1`: with an unsigned size() (as for standard
    // containers) the subtraction wraps around on an empty graph and the
    // loop would index far out of bounds.
    for(size_t l = 0; l + 1 < _g.get_graph().size(); ++l) {
      for(size_t i = 0; i < (_g.get_graph())[l].size(); ++i) {
        for(auto&& out_node: _g.at(l, i).out_nodes) {
          tasks[l][i].precede(tasks[l + 1][out_node]);
        }
      }
    }

    // run the capture on its own stream and wait for it to finish
    tf::cudaStream stream;
    cf.run(stream);
    stream.synchronize();

  }).name("traverse");

  //auto check_t = taskflow.emplace([this](){
    //assert(_g.traversed());
  //});

  //trav_t.precede(check_t);

  executor.run(taskflow).wait();
}