1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
|
/*
* Copyright (c) 2018-2020,2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef ARM_COMPUTE_GRAPH_GRAPH_H
#define ARM_COMPUTE_GRAPH_GRAPH_H
#include "arm_compute/graph/Edge.h"
#include "arm_compute/graph/INode.h"
#include "arm_compute/graph/Tensor.h"
#include "arm_compute/graph/Types.h"
#include "support/Mutex.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifndef BARE_METAL
#include <thread>
#endif /* BARE_METAL */
namespace arm_compute
{
namespace graph
{
/** Graph class
*
* Represents a multiple source - multiple sink directed graph.
*
* The graph owns its nodes, edges and tensors through vectors of
* std::unique_ptr. As the accessor warnings below state, entries of these
* vectors can be nullptr once elements have been removed during the mutation
* steps of the graph, so callers have to check before dereferencing.
*
* The class is final and can neither be copied nor moved.
*/
class Graph final
{
public:
/** Default constructor
*
* All members keep their in-class initializers: the graph id is 0, the name
* is empty and the graph holds no nodes, edges or tensors.
*/
Graph() = default;
/** Constructor
*
* @note This overload declares no default arguments; use the default
*       constructor to obtain a graph with id 0 and an empty name.
*
* @param[in] id   Graph identification number. Can be used to differentiate between graphs.
* @param[in] name Graph name.
*/
Graph(GraphID id, std::string name);
/** Prevent instances of this class from being copied (As this class contains pointers) */
Graph(const Graph &) = delete;
/** Prevent instances of this class from being copy assigned (As this class contains pointers) */
Graph &operator=(const Graph &) = delete;
/** Prevent instances of this class from being moved (As this class contains non movable objects) */
Graph(Graph &&) = delete;
/** Prevent instances of this class from being move assigned (As this class contains non movable objects) */
Graph &operator=(Graph &&) = delete;
/** Adds a node to the graph
*
* The node is constructed by the graph from the forwarded arguments; the
* template definition lives further down in this header.
*
* @note Models a single output node
*
* @tparam NT Node operation
* @tparam Ts Arguments to operation
*
* @param[in] args Node arguments, forwarded to the constructor of NT
*
* @return ID of the node
*/
template <typename NT, typename... Ts>
NodeID add_node(Ts &&...args);
/** Remove the node with the given ID
*
* @param[in] nid ID of the node to remove
*
* @return True if the removal took place else false
*/
bool remove_node(NodeID nid);
/** Adds a connection between two nodes
*
* @param[in] source     ID of the source node
* @param[in] source_idx Output index of the source node
* @param[in] sink       ID of the sink node
* @param[in] sink_idx   Input index of the sink node
*
* @return ID of this connection
*/
EdgeID add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx);
/** Removes an edge (connection)
*
* @param[in] eid Connection to remove
*
* @return True if the removal took place else false
*/
bool remove_connection(EdgeID eid);
/** Returns graph name
*
* @return Graph name (returned by value)
*/
std::string name() const;
/** Returns graph id
*
* @return Graph id
*/
GraphID id() const;
/** Returns the IDs of the graph nodes of a given type
*
* @note This overload is not const-qualified.
*       NOTE(review): presumably because the lookup may create an empty
*       entry for a type that has not been seen yet - confirm against the
*       implementation.
*
* @param[in] type Type of nodes to return
*
* @return Vector containing the IDs of the graph nodes of the given type
*/
const std::vector<NodeID> &nodes(NodeType type);
/** Returns nodes of graph (mutable access)
*
* @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
*
* @return Nodes of graph
*/
std::vector<std::unique_ptr<INode>> &nodes();
/** Returns nodes of graph (read-only access)
*
* @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
*
* @return Nodes of graph
*/
const std::vector<std::unique_ptr<INode>> &nodes() const;
/** Returns edges of graph (read-only access)
*
* @warning Edges can be nullptr if they have been removed during the mutation steps of the graph
*
* @return Edges of graph
*/
const std::vector<std::unique_ptr<Edge>> &edges() const;
/** Returns tensors of graph (mutable access)
*
* @warning Tensors can be nullptr if they have been removed during the mutation steps of the graph
*
* @return Tensors of graph
*/
std::vector<std::unique_ptr<Tensor>> &tensors();
/** Returns tensors of graph (read-only access)
*
* @warning Tensors can be nullptr if they have been removed during the mutation steps of the graph
*
* @return Tensors of graph
*/
const std::vector<std::unique_ptr<Tensor>> &tensors() const;
/** Get node object given its id (read-only access)
*
* @warning Can be nullptr if node was removed during the mutation steps of the graph
*
* @param[in] id Node ID
*
* @return The actual node object
*/
const INode *node(NodeID id) const;
/** Get node object given its id (mutable access)
*
* @warning Can be nullptr if node was removed during the mutation steps of the graph
*
* @param[in] id Node ID
*
* @return The actual node object
*/
INode *node(NodeID id);
/** Get edge object given its id (read-only access)
*
* @warning Can be nullptr if edge was removed during the mutation steps of the graph
*
* @param[in] id Edge ID
*
* @return The actual edge object
*/
const Edge *edge(EdgeID id) const;
/** Get edge object given its id (mutable access)
*
* @warning Can be nullptr if edge was removed during the mutation steps of the graph
*
* @param[in] id Edge ID
*
* @return The actual edge object
*/
Edge *edge(EdgeID id);
/** Get tensor object given its id (read-only access)
*
* @warning Can be nullptr if tensor was removed during the mutation steps of the graph
*
* @param[in] id Tensor ID
*
* @return The actual tensor object
*/
const Tensor *tensor(TensorID id) const;
/** Get tensor object given its id (mutable access)
*
* @warning Can be nullptr if tensor was removed during the mutation steps of the graph
*
* @param[in] id Tensor ID
*
* @return The actual tensor object
*/
Tensor *tensor(TensorID id);
private:
/** Creates a tensor object
*
* Used by add_node() to give every output of a new node its own tensor.
*
* @param[in] desc Tensor descriptor. Defaults to a default-constructed descriptor
*
* @return Tensor ID
*/
TensorID create_tensor(const TensorDescriptor &desc = TensorDescriptor());
private:
GraphID _id = GraphID(0); /**< Graph id (0 unless set through the two-argument constructor) */
std::string _name = {}; /**< Graph name (empty unless set through the two-argument constructor) */
std::vector<std::unique_ptr<INode>> _nodes = {}; /**< Graph nodes; add_node() uses the current size as the ID of the new node */
std::vector<std::unique_ptr<Edge>> _edges = {}; /**< Graph edges */
std::vector<std::unique_ptr<Tensor>> _tensors = {}; /**< Graph tensors */
std::map<NodeType, std::vector<NodeID>> _tagged_nodes = {}; /**< Graph nodes map with the node type as key and the IDs of the nodes of that type as value */
arm_compute::Mutex _mtx = {}; /**< Mutex used for graph construction (held by add_node() while it runs) */
};
/** Constructs a node of type NT and registers it with the graph
 *
 * The whole operation runs with the graph construction mutex held.
 *
 * @tparam NT Node operation
 * @tparam Ts Arguments to operation
 *
 * @param[in] args Node arguments, perfectly forwarded to the constructor of NT
 *
 * @return ID of the node, i.e. the size of the node container before the insertion
 */
template <typename NT, typename... Ts>
inline NodeID Graph::add_node(Ts &&...args)
{
    arm_compute::lock_guard<arm_compute::Mutex> construction_guard(_mtx);

    // The next free slot in the node container becomes the ID of the new node
    NodeID new_id = _nodes.size();

    // Build the node and bind it to this graph
    std::unique_ptr<NT> created = std::make_unique<NT>(std::forward<Ts>(args)...);
    NT &new_node = *created;
    new_node.set_graph(this);
    new_node.set_id(new_id);

    // Record the node under its type so it can be looked up by type later
    _tagged_nodes[new_node.type()].push_back(new_id);

    // Every output of the node gets a freshly created tensor
    for (auto &output_slot : new_node._outputs)
    {
        output_slot = create_tensor();
    }

    // Let the node propagate its descriptors now that its outputs exist
    new_node.forward_descriptors();

    // Hand ownership of the node over to the graph
    _nodes.push_back(std::move(created));

    return new_id;
}
} // namespace graph
} // namespace arm_compute
#endif /* ARM_COMPUTE_GRAPH_GRAPH_H */
|