File: functional.h

package info (click to toggle)
yosys 0.52-2
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 69,796 kB
  • sloc: ansic: 696,955; cpp: 239,736; python: 14,617; yacc: 3,529; sh: 2,175; makefile: 1,945; lex: 697; perl: 445; javascript: 323; tcl: 162; vhdl: 115
file content (645 lines) | stat: -rw-r--r-- 30,761 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
/*
 *  yosys -- Yosys Open SYnthesis Suite
 *
 *  Copyright (C) 2024  Emily Schmidt <emily@yosyshq.com>
 *  Copyright (C) 2024 National Technology and Engineering Solutions of Sandia, LLC
 *
 *  Permission to use, copy, modify, and/or distribute this software for any
 *  purpose with or without fee is hereby granted, provided that the above
 *  copyright notice and this permission notice appear in all copies.
 *
 *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#ifndef FUNCTIONAL_H
#define FUNCTIONAL_H

#include "kernel/yosys.h"
#include "kernel/compute_graph.h"
#include "kernel/drivertools.h"
#include "kernel/mem.h"
#include "kernel/utils.h"

USING_YOSYS_NAMESPACE
YOSYS_NAMESPACE_BEGIN

namespace Functional {
	// each function is documented with a short pseudocode declaration or definition
	// standard C/Verilog operators are used to describe the result
	// 
	// the sorts used in this are:
	// - bit[N]: a bitvector of N bits
	//   bit[N] can be indicated as signed or unsigned. this is not tracked by the functional backend
	//   but is meant to indicate how the value is interpreted
	//   if a bit[N] is marked as neither signed nor unsigned, this means the result should be valid with *either* interpretation
	// - memory[N, M]: a memory with N address and M data bits
	// - int: C++ int
	// - Const[N]: yosys RTLIL::Const (with size() == N)
	// - IdString: yosys IdString
	// - any: used in documentation to indicate that the sort is unconstrained
	//
	// nodes in the functional backend are either of sort bit[N] or memory[N,M] (for some N, M: int)
	// additionally, they can carry a constant of sort int, Const[N] or IdString
	// each node has a 'sort' field that stores the sort of the node
	// slice, zero_extend, sign_extend use the sort field to store out_width
	// Fn enumerates every operation a node of the functional IR can represent.
	// the pseudocode comment above each enumerator documents its operands and result
	// (see the description of the sorts above)
	enum class Fn {
		// invalid() = known-invalid/shouldn't happen value
		// TODO: maybe remove this and use e.g. std::optional instead?
		invalid,
		// buf(a: any): any = a
		// no-op operation
		// when constructing the compute graph we generate invalid buf() nodes as a placeholder
		// and later insert the argument
		buf,
		// slice(a: bit[in_width], offset: int, out_width: int): bit[out_width] = a[offset +: out_width]
		// required: offset + out_width <= in_width
		slice,
		// zero_extend(a: unsigned bit[in_width], out_width: int): unsigned bit[out_width] = a (zero extended)
		// required: out_width > in_width
		zero_extend,
		// sign_extend(a: signed bit[in_width], out_width: int): signed bit[out_width] = a (sign extended)
		// required: out_width > in_width
		sign_extend,
		// concat(a: bit[N], b: bit[M]): bit[N+M] = {b, a} (verilog syntax)
		// concatenates two bitvectors, with a in the least significant position and b in the more significant position
		concat,
		// add(a: bit[N], b: bit[N]): bit[N] = a + b
		add,
		// sub(a: bit[N], b: bit[N]): bit[N] = a - b
		sub,
		// mul(a: bit[N], b: bit[N]): bit[N] = a * b
		mul,
		// unsigned_div(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a / b
		unsigned_div,
		// unsigned_mod(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a % b
		unsigned_mod,
		// bitwise_and(a: bit[N], b: bit[N]): bit[N] = a & b
		bitwise_and,
		// bitwise_or(a: bit[N], b: bit[N]): bit[N] = a | b
		bitwise_or,
		// bitwise_xor(a: bit[N], b: bit[N]): bit[N] = a ^ b
		bitwise_xor,
		// bitwise_not(a: bit[N]): bit[N] = ~a
		bitwise_not,
		// reduce_and(a: bit[N]): bit[1] = &a
		reduce_and,
		// reduce_or(a: bit[N]): bit[1] = |a
		reduce_or,
		// reduce_xor(a: bit[N]): bit[1] = ^a
		reduce_xor,
		// unary_minus(a: bit[N]): bit[N] = -a
		unary_minus,
		// equal(a: bit[N], b: bit[N]): bit[1] = (a == b)
		equal,
		// not_equal(a: bit[N], b: bit[N]): bit[1] = (a != b)
		not_equal,
		// signed_greater_than(a: signed bit[N], b: signed bit[N]): bit[1] = (a > b)
		signed_greater_than,
		// signed_greater_equal(a: signed bit[N], b: signed bit[N]): bit[1] = (a >= b)
		signed_greater_equal,
		// unsigned_greater_than(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a > b)
		unsigned_greater_than,
		// unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b)
		unsigned_greater_equal,
		// logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b
		// required: M == clog2(N)
		logical_shift_left,
		// logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b
		// required: M == clog2(N)
		logical_shift_right,
		// arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b
		// required: M == clog2(N)
		arithmetic_shift_right,
		// mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a
		mux,
		// constant(a: Const[N]): bit[N] = a
		constant,
		// input(a: IdString): any
		// returns the current value of the input with the specified name
		input,
		// state(a: IdString): any
		// returns the current value of the state variable with the specified name
		state,
		// memory_read(memory: memory[addr_width, data_width], addr: bit[addr_width]): bit[data_width] = memory[addr]
		memory_read,
		// memory_write(memory: memory[addr_width, data_width], addr: bit[addr_width], data: bit[data_width]): memory[addr_width, data_width]
		// returns a copy of `memory` but with the value at `addr` changed to `data`
		memory_write
	};
	// returns the name of a Fn value, as a string literal
	// (only declared in this header; the definition is not part of it)
	const char *fn_to_string(Fn);
	// Sort is the sort (type) of a node.
	// there are currently two kinds of sort:
	// - signal/bit sorts, described by a single width
	// - memory sorts, described by an address width and a data width
	class Sort {
		// holds an int (the width) for a signal sort
		// and a pair {address width, data width} for a memory sort
		std::variant<int, std::pair<int, int>> _v;
	public:
		// constructs a signal (bitvector) sort of the given width
		explicit Sort(int width) : _v(width) { }
		// constructs a memory sort with the given address and data widths
		Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { }
		// true iff this is a signal (bitvector) sort
		bool is_signal() const { return std::holds_alternative<int>(_v); }
		// true iff this is a memory sort
		bool is_memory() const { return std::holds_alternative<std::pair<int, int>>(_v); }
		// returns the width of a bitvector sort, errors out for other sorts
		int width() const { return std::get<int>(_v); }
		// returns the address width of a memory sort, errors out for other sorts
		int addr_width() const { return std::get<std::pair<int, int>>(_v).first; }
		// returns the data width of a memory sort, errors out for other sorts
		int data_width() const { return std::get<std::pair<int, int>>(_v).second; }
		// two sorts are equal iff they are of the same kind and have the same widths
		bool operator==(Sort const& rhs) const { return _v == rhs._v; }
		// feeds the stored variant into the hasher and returns the updated hasher
		[[nodiscard]] Hasher hash_into(Hasher h) const {
			h.eat(_v);
			return h;
		}
	};
	class IR;
	class Factory;
	class Node;
	// IRInput describes one input of the IR, identified by its name and kind.
	// the constructor is private; Factory is the only friend and therefore the only creator.
	class IRInput {
		friend class Factory;
	public:
		IdString name;
		IdString kind;
		Sort sort;
	private:
		// the IR reference is accepted but not stored
		IRInput(IR &, IdString input_name, IdString input_kind, Sort input_sort)
			: name(input_name)
			, kind(input_kind)
			, sort(std::move(input_sort))
		{}
	};
	// IROutput describes one output of the IR, identified by its name and kind.
	// it keeps a reference to the owning IR so that the node driving the output
	// can be looked up and assigned through the graph.
	// the constructor is private; Factory is the only friend and therefore the only creator.
	class IROutput {
		friend class Factory;
		IR &_ir;
	public:
		IdString name;
		IdString kind;
		Sort sort;
	private:
		IROutput(IR &ir, IdString output_name, IdString output_kind, Sort output_sort)
			: _ir(ir)
			, name(output_name)
			, kind(output_kind)
			, sort(std::move(output_sort))
		{}
	public:
		// the node currently registered as the value of this output
		Node value() const;
		// whether a value node has been registered for this output
		bool has_value() const;
		// registers `value` as the node driving this output (its sort must equal `sort`)
		void set_value(Node value);
	};
	// IRState describes one state variable of the IR, identified by its name and kind.
	// besides its sort it stores an initial value, which is either an RTLIL::Const
	// (signal states) or a MemContents (memory states).
	// the constructor is private; Factory is the only friend and therefore the only creator.
	class IRState {
		friend class Factory;
		IR &_ir;
	public:
		IdString name;
		IdString kind;
		Sort sort;
	private:
		// alternative 0: initial value of a signal, alternative 1: initial contents of a memory
		std::variant<RTLIL::Const, MemContents> _initial;
		IRState(IR &ir, IdString state_name, IdString state_kind, Sort state_sort)
			: _ir(ir)
			, name(state_name)
			, kind(state_kind)
			, sort(std::move(state_sort))
		{}
	public:
		// the node currently registered as the next value of this state
		Node next_value() const;
		// whether a next-value node has been registered for this state
		bool has_next_value() const;
		// the initial value as a constant; errors out if a MemContents is stored
		RTLIL::Const const& initial_value_signal() const {
			return std::get<0>(_initial);
		}
		// the initial value as memory contents; errors out if a Const is stored
		MemContents const& initial_value_memory() const {
			return std::get<1>(_initial);
		}
		// registers `value` as the next value of this state (its sort must equal `sort`)
		void set_next_value(Node value);
		// stores a constant initial value after resizing it with extu() to the width of the state
		void set_initial_value(RTLIL::Const value) {
			value.extu(sort.width());
			_initial = std::move(value);
		}
		// stores initial memory contents; their address/data widths must match the sort of the state
		void set_initial_value(MemContents value) {
			log_assert(Sort(value.addr_width(), value.data_width()) == sort);
			_initial = std::move(value);
		}
	};
	// IR owns the compute graph of the functional IR together with the tables of
	// inputs, outputs and state variables, each of which is keyed by a (name, kind) pair.
	// nodes are exposed as Node values; new nodes are created through the Factory returned by factory().
	class IR {
		friend class Factory;
		friend class Node;
		friend class IRInput;
		friend class IROutput;
		friend class IRState;
		// one NodeData is stored per Node, containing the function and non-node arguments
		// note that NodeData is deduplicated by ComputeGraph
		class NodeData {
			Fn _fn;
			// the non-node argument of the function, if any:
			// nothing, a constant, a (name, kind) pair or an int
			std::variant<
				std::monostate,
				RTLIL::Const,
				std::pair<IdString, IdString>,
				int
			> _extra;
		public:
			// default-constructed NodeData is marked as Fn::invalid
			NodeData() : _fn(Fn::invalid) {}
			// intentionally implicit: allows passing a bare Fn where a NodeData is expected
			NodeData(Fn fn) : _fn(fn) {}
			template<class T> NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward<T>(extra)) {}
			Fn fn() const { return _fn; }
			// the accessors below error out if _extra holds a different alternative
			const RTLIL::Const &as_const() const { return std::get<RTLIL::Const>(_extra); }
			std::pair<IdString, IdString> as_idstring_pair() const { return std::get<std::pair<IdString, IdString>>(_extra); }
			int as_int() const { return std::get<int>(_extra); }
			// hashes both the function and the extra argument
			[[nodiscard]] Hasher hash_into(Hasher h) const {
				h.eat((unsigned int) _fn);
				h.eat(_extra);
				return h;
			}
			// equal iff both the function and the extra argument are equal
			bool operator==(NodeData const &other) const {
				return _fn == other._fn && _extra == other._extra;
			}
		};
		// Attr contains all the information about a node that should not be deduplicated
		struct Attr {
			Sort sort;
		};
		// our specialised version of ComputeGraph
		// the sparse_attr IdString stores a naming suggestion, retrieved with name()
		// the key is currently used to identify the nodes that represent output and next state values
		// the bool is true for next state values
		using Graph = ComputeGraph<NodeData, Attr, IdString, std::tuple<IdString, IdString, bool>>;
		Graph _graph;
		// inputs, outputs and states, each keyed by (name, kind)
		dict<std::pair<IdString, IdString>, IRInput> _inputs;
		dict<std::pair<IdString, IdString>, IROutput> _outputs;
		dict<std::pair<IdString, IdString>, IRState> _states;
		// returns a mutable graph reference to the node with the same index as n
		IR::Graph::Ref mutate(Node n);
	public:
		static IR from_module(Module *module);
		// returns a Factory bound to this IR
		Factory factory();
		// number of nodes in the graph
		int size() const { return _graph.size(); }
		// the node at index i
		Node operator[](int i);
		void topological_sort();
		void forward_buf();
		// lookup by (name, kind); the single-argument overloads use the
		// kinds $input, $output and $state respectively
		IRInput const& input(IdString name, IdString kind) const { return _inputs.at({name, kind}); }
		IRInput const& input(IdString name) const { return input(name, ID($input)); }
		IROutput const& output(IdString name, IdString kind) const { return _outputs.at({name, kind}); }
		IROutput const& output(IdString name) const { return output(name, ID($output)); }
		IRState const& state(IdString name, IdString kind) const { return _states.at({name, kind}); }
		IRState const& state(IdString name) const { return state(name, ID($state)); }
		// existence checks by (name, kind)
		bool has_input(IdString name, IdString kind) const { return _inputs.count({name, kind}); }
		bool has_output(IdString name, IdString kind) const { return _outputs.count({name, kind}); }
		bool has_state(IdString name, IdString kind) const { return _states.count({name, kind}); }
		// enumeration by kind; the overloads without argument use the
		// kinds $input, $output and $state respectively
		vector<IRInput const*> inputs(IdString kind) const;
		vector<IRInput const*> inputs() const { return inputs(ID($input)); }
		vector<IROutput const*> outputs(IdString kind) const;
		vector<IROutput const*> outputs() const { return outputs(ID($output)); }
		vector<IRState const*> states(IdString kind) const;
		vector<IRState const*> states() const { return states(ID($state)); }
		vector<IRInput const*> all_inputs() const;
		vector<IROutput const*> all_outputs() const;
		vector<IRState const*> all_states() const;
		// iterator over the nodes of the IR by index; dereferencing yields a Node by value
		class iterator {
			friend class IR;
			IR *_ir;
			int _index;
			iterator(IR *ir, int index) : _ir(ir), _index(index) {}
		public:
			using iterator_category = std::input_iterator_tag;
			using value_type = Node;
			using pointer = arrow_proxy<Node>;
			using reference = Node;
			using difference_type = ptrdiff_t;
			Node operator*();
			iterator &operator++() { _index++; return *this; }
			// iterators are equal iff they refer to the same IR and the same index
			bool operator!=(iterator const &other) const { return _ir != other._ir || _index != other._index; }
			bool operator==(iterator const &other) const { return !(*this != other); }
			pointer operator->();
		};
		// iteration covers the node indices 0 .. _graph.size()-1
		iterator begin() { return iterator(this, 0); }
		iterator end() { return iterator(this, _graph.size()); }
	};
	// Node is an immutable reference to a FunctionalIR node
	// it wraps a const reference into the compute graph of an IR and can only be
	// created by the friend classes listed below
	class Node {
		friend class Factory;
		friend class IR;
		friend class IRInput;
		friend class IROutput;
		friend class IRState;
		IR::Graph::ConstRef _ref;
		explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { }
		explicit operator IR::Graph::ConstRef() { return _ref; }
	public:
		// the node's index. may change if nodes are added or removed
		int id() const { return _ref.index(); }
		// a name suggestion for the node, which need not be unique
		// falls back to "\n<id>" if no name was suggested for the node
		IdString name() const {
			if(_ref.has_sparse_attr())
				return _ref.sparse_attr();
			else
				return std::string("\\n") + std::to_string(id());
		}
		// the operation this node performs
		Fn fn() const { return _ref.function().fn(); }
		// the sort of the value this node produces
		Sort sort() const { return _ref.attr().sort; }
		// returns the width of a bitvector node, errors out for other nodes
		int width() const { return sort().width(); }
		// number of node arguments
		size_t arg_count() const { return _ref.size(); }
		// the n-th node argument
		Node arg(int n) const { return Node(_ref.arg(n)); }
		// visit calls the appropriate visitor method depending on the type of the node
		// and returns whatever that method returns (the return type is deduced from the visitor).
		// the visitor method receives the node itself, followed by its node arguments and
		// any non-node arguments (offset/width, constant, name/kind).
		// visiting an Fn::invalid node is reported with log_error
		template<class Visitor> auto visit(Visitor v) const
		{
			// currently templated but could be switched to AbstractVisitor &
			switch(_ref.function().fn()) {
			case Fn::invalid: log_error("invalid node in visit\n"); break;
			case Fn::buf: return v.buf(*this, arg(0));
			case Fn::slice: return v.slice(*this, arg(0), _ref.function().as_int(), sort().width());
			case Fn::zero_extend: return v.zero_extend(*this, arg(0), width());
			case Fn::sign_extend: return v.sign_extend(*this, arg(0), width());
			case Fn::concat: return v.concat(*this, arg(0), arg(1));
			case Fn::add: return v.add(*this, arg(0), arg(1));
			case Fn::sub: return v.sub(*this, arg(0), arg(1));
			case Fn::mul: return v.mul(*this, arg(0), arg(1));
			case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1));
			case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1));
			case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1));
			case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1));
			case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1));
			case Fn::bitwise_not: return v.bitwise_not(*this, arg(0));
			case Fn::unary_minus: return v.unary_minus(*this, arg(0));
			case Fn::reduce_and: return v.reduce_and(*this, arg(0));
			case Fn::reduce_or: return v.reduce_or(*this, arg(0));
			case Fn::reduce_xor: return v.reduce_xor(*this, arg(0));
			case Fn::equal: return v.equal(*this, arg(0), arg(1));
			case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1));
			case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1));
			case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1));
			case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1));
			case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1));
			case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1));
			case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1));
			case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1));
			case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2));
			case Fn::constant: return v.constant(*this, _ref.function().as_const());
			case Fn::input: return v.input(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second);
			case Fn::state: return v.state(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second);
			case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1));
			case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2));
			}
			// only reached if no case above returned (e.g. an out-of-range Fn value)
			log_abort();
		}
		std::string to_string();
		std::string to_string(std::function<std::string(Node)>);
	};
	// out-of-class definitions of members that need the complete Node type

	// mutable graph reference to the node with the same index as n
	inline IR::Graph::Ref IR::mutate(Node n) {
		return _graph[n._ref.index()];
	}
	// the node at index i
	inline Node IR::operator[](int i) {
		return Node(_graph[i]);
	}
	// output values are registered in the graph under the key {name, kind, false}
	inline Node IROutput::value() const {
		return Node(_ir._graph({name, kind, false}));
	}
	inline bool IROutput::has_value() const {
		return _ir._graph.has_key({name, kind, false});
	}
	inline void IROutput::set_value(Node value) {
		log_assert(sort == value.sort());
		IR::Graph::Ref target = _ir.mutate(value);
		target.assign_key({name, kind, false});
	}
	// next state values are registered in the graph under the key {name, kind, true}
	inline Node IRState::next_value() const {
		return Node(_ir._graph({name, kind, true}));
	}
	inline bool IRState::has_next_value() const {
		return _ir._graph.has_key({name, kind, true});
	}
	inline void IRState::set_next_value(Node value) {
		log_assert(sort == value.sort());
		IR::Graph::Ref target = _ir.mutate(value);
		target.assign_key({name, kind, true});
	}
	// dereferencing an iterator yields the node at the iterator's current index
	inline Node IR::iterator::operator*() {
		return Node(_ir->_graph[_index]);
	}
	// the Node is returned by value, so operator-> hands it out wrapped in an arrow_proxy
	inline arrow_proxy<Node> IR::iterator::operator->() {
		return arrow_proxy<Node>(**this);
	}
	// AbstractVisitor provides an abstract base class for visitors
	// there is one pure virtual method per Fn value (except Fn::invalid), with the
	// parameters that Node::visit passes for that kind of node:
	// the visited node itself (self), its node arguments and any non-node arguments
	// T is the result type produced by visiting a node
	template<class T> struct AbstractVisitor {
		// polymorphic base class: make destruction through a base pointer well-defined
		virtual ~AbstractVisitor() = default;
		virtual T buf(Node self, Node n) = 0;
		virtual T slice(Node self, Node a, int offset, int out_width) = 0;
		virtual T zero_extend(Node self, Node a, int out_width) = 0;
		virtual T sign_extend(Node self, Node a, int out_width) = 0;
		virtual T concat(Node self, Node a, Node b) = 0;
		virtual T add(Node self, Node a, Node b) = 0;
		virtual T sub(Node self, Node a, Node b) = 0;
		virtual T mul(Node self, Node a, Node b) = 0;
		virtual T unsigned_div(Node self, Node a, Node b) = 0;
		virtual T unsigned_mod(Node self, Node a, Node b) = 0;
		virtual T bitwise_and(Node self, Node a, Node b) = 0;
		virtual T bitwise_or(Node self, Node a, Node b) = 0;
		virtual T bitwise_xor(Node self, Node a, Node b) = 0;
		virtual T bitwise_not(Node self, Node a) = 0;
		virtual T unary_minus(Node self, Node a) = 0;
		virtual T reduce_and(Node self, Node a) = 0;
		virtual T reduce_or(Node self, Node a) = 0;
		virtual T reduce_xor(Node self, Node a) = 0;
		virtual T equal(Node self, Node a, Node b) = 0;
		virtual T not_equal(Node self, Node a, Node b) = 0;
		virtual T signed_greater_than(Node self, Node a, Node b) = 0;
		virtual T signed_greater_equal(Node self, Node a, Node b) = 0;
		virtual T unsigned_greater_than(Node self, Node a, Node b) = 0;
		virtual T unsigned_greater_equal(Node self, Node a, Node b) = 0;
		virtual T logical_shift_left(Node self, Node a, Node b) = 0;
		virtual T logical_shift_right(Node self, Node a, Node b) = 0;
		virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0;
		virtual T mux(Node self, Node a, Node b, Node s) = 0;
		virtual T constant(Node self, RTLIL::Const const & value) = 0;
		virtual T input(Node self, IdString name, IdString kind) = 0;
		virtual T state(Node self, IdString name, IdString kind) = 0;
		virtual T memory_read(Node self, Node mem, Node addr) = 0;
		virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0;
	};
	// DefaultVisitor implements every visitor method of AbstractVisitor by forwarding
	// the visited node to default_handler and ignoring all other arguments.
	// derived classes implement default_handler and override only the methods they care about.
	template<class T> struct DefaultVisitor : public AbstractVisitor<T> {
		// invoked with the visited node by every method that is not overridden
		virtual T default_handler(Node self) = 0;
		T buf(Node self, Node /*n*/) override { return default_handler(self); }
		T slice(Node self, Node /*a*/, int /*offset*/, int /*out_width*/) override { return default_handler(self); }
		T zero_extend(Node self, Node /*a*/, int /*out_width*/) override { return default_handler(self); }
		T sign_extend(Node self, Node /*a*/, int /*out_width*/) override { return default_handler(self); }
		T concat(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T add(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T sub(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T mul(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T unsigned_div(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T unsigned_mod(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T bitwise_and(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T bitwise_or(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T bitwise_xor(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T bitwise_not(Node self, Node /*a*/) override { return default_handler(self); }
		T unary_minus(Node self, Node /*a*/) override { return default_handler(self); }
		T reduce_and(Node self, Node /*a*/) override { return default_handler(self); }
		T reduce_or(Node self, Node /*a*/) override { return default_handler(self); }
		T reduce_xor(Node self, Node /*a*/) override { return default_handler(self); }
		T equal(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T not_equal(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T signed_greater_than(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T signed_greater_equal(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T unsigned_greater_than(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T unsigned_greater_equal(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T logical_shift_left(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T logical_shift_right(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T arithmetic_shift_right(Node self, Node /*a*/, Node /*b*/) override { return default_handler(self); }
		T mux(Node self, Node /*a*/, Node /*b*/, Node /*s*/) override { return default_handler(self); }
		T constant(Node self, RTLIL::Const const & /*value*/) override { return default_handler(self); }
		T input(Node self, IdString /*name*/, IdString /*kind*/) override { return default_handler(self); }
		T state(Node self, IdString /*name*/, IdString /*kind*/) override { return default_handler(self); }
		T memory_read(Node self, Node /*mem*/, Node /*addr*/) override { return default_handler(self); }
		T memory_write(Node self, Node /*mem*/, Node /*addr*/, Node /*data*/) override { return default_handler(self); }
	};
	// a factory is used to modify a FunctionalIR. it creates new nodes and allows for some modification of existing nodes.
	class Factory {
		friend class IR;
		IR &_ir;
		explicit Factory(IR &ir) : _ir(ir) {}
		Node add(IR::NodeData &&fn, Sort const &sort, std::initializer_list<Node> args) {
			log_assert(!sort.is_signal() || sort.width() > 0);
			log_assert(!sort.is_memory() || (sort.addr_width() > 0 && sort.data_width() > 0));
			IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)});
			for (auto arg : args)
				ref.append_arg(IR::Graph::ConstRef(arg));
			return Node(ref);
		}
		void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); }
		void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); }
		void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
	public:
		IR &ir() { return _ir; }
		Node slice(Node a, int offset, int out_width) {
			log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
			if(offset == 0 && out_width == a.width())
				return a;
			return add(IR::NodeData(Fn::slice, offset), Sort(out_width), {a});
		}
		// extend will either extend or truncate the provided value to reach the desired width
		Node extend(Node a, int out_width, bool is_signed) {
			int in_width = a.sort().width();
			log_assert(a.sort().is_signal());
			if(in_width == out_width)
				return a;
			if(in_width > out_width)
				return slice(a, 0, out_width);
			if(is_signed)
				return add(Fn::sign_extend, Sort(out_width), {a});
			else
				return add(Fn::zero_extend, Sort(out_width), {a});
		}
		Node concat(Node a, Node b) {
			log_assert(a.sort().is_signal() && b.sort().is_signal());
			return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b});
		}
		Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); }
		Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); }
		Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); }
		Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); }
		Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); }
		Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); }
		Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); }
		Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); }
		Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); }
		Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); }
		Node reduce_and(Node a) {
			check_unary(a);
			if(a.width() == 1)
				return a;
			return add(Fn::reduce_and, Sort(1), {a});
		}
		Node reduce_or(Node a) {
			check_unary(a);
			if(a.width() == 1)
				return a;
			return add(Fn::reduce_or, Sort(1), {a});
		}
		Node reduce_xor(Node a) { 
			check_unary(a);
			if(a.width() == 1)
				return a;
			return add(Fn::reduce_xor, Sort(1), {a});
		}
		Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); }
		Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); }
		Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); }
		Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); }
		Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); }
		Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); }
		Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); }
		Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); }
		Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); }
		Node mux(Node a, Node b, Node s) {
			log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1));
			return add(Fn::mux, a.sort(), {a, b, s});
		}
		Node memory_read(Node mem, Node addr) {
			log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width());
			return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr});
		}
		Node memory_write(Node mem, Node addr, Node data) {
			log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() &&
				mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width());
			return add(Fn::memory_write, mem.sort(), {mem, addr, data});
		}
		Node constant(RTLIL::Const value) {
			int s = value.size();
			return add(IR::NodeData(Fn::constant, std::move(value)), Sort(s), {});
		}
		Node create_pending(int width) {
			return add(Fn::buf, Sort(width), {});
		}
		void update_pending(Node node, Node value) {
			log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0);
			log_assert(node.sort() == value.sort());
			_ir.mutate(node).append_arg(value._ref);
		}
		IRInput &add_input(IdString name, IdString kind, Sort sort) {
			auto [it, inserted] = _ir._inputs.emplace({name, kind}, IRInput(_ir, name, kind, std::move(sort)));
			if (!inserted) log_error("input `%s` was re-defined", name.c_str());
			return it->second;
		}
		IROutput &add_output(IdString name, IdString kind, Sort sort) {
			auto [it, inserted] = _ir._outputs.emplace({name, kind}, IROutput(_ir, name, kind, std::move(sort)));
			if (!inserted) log_error("output `%s` was re-defined", name.c_str());
			return it->second;
		}
		IRState &add_state(IdString name, IdString kind, Sort sort) {
			auto [it, inserted] = _ir._states.emplace({name, kind}, IRState(_ir, name, kind, std::move(sort)));
			if (!inserted) log_error("state `%s` was re-defined", name.c_str());
			return it->second;
		}
		Node value(IRInput const& input) {
			return add(IR::NodeData(Fn::input, std::pair(input.name, input.kind)), input.sort, {});
		}
		Node value(IRState const& state) {
			return add(IR::NodeData(Fn::state, std::pair(state.name, state.kind)), state.sort, {});
		}
		void suggest_name(Node node, IdString name) {
			_ir.mutate(node).sparse_attr() = name;
		}
	};
	// returns a Factory bound to this IR (defined here because it needs the complete Factory type)
	inline Factory IR::factory()
	{
		return Factory(*this);
	}
	// Scope hands out unique names derived from IdString suggestions.
	// derived classes decide which characters are legal by implementing is_character_legal;
	// illegal characters are replaced by substitution_character.
	// Id is the type of the keys under which operator() remembers the names it assigned.
	template<class Id> class Scope {
	protected:
		// replacement for characters rejected by is_character_legal
		char substitution_character = '_';
		// called with a character of the suggested name and its index within that name
		virtual bool is_character_legal(char, int) = 0;
	private:
		// every name that was reserved or handed out so far
		pool<std::string> _used_names;
		// names already assigned by operator(), by id
		dict<Id, std::string> _by_id;
	public:
		// polymorphic base class: make destruction through a base pointer well-defined.
		// declaring the destructor would suppress the implicit move operations,
		// so all special members are defaulted explicitly to keep the class copyable and movable.
		Scope() = default;
		Scope(Scope const &) = default;
		Scope(Scope &&) = default;
		Scope &operator=(Scope const &) = default;
		Scope &operator=(Scope &&) = default;
		virtual ~Scope() = default;
		// marks a name as used so that it is never returned by unique_name
		void reserve(std::string name) {
			_used_names.insert(std::move(name));
		}
		// returns a name based on the suggestion that was not used before and marks it as used.
		// the suggestion is unescaped and its illegal characters are substituted;
		// if the result is already taken, the suffixes _0, _1, ... are tried in order.
		std::string unique_name(IdString suggestion) {
			std::string str = RTLIL::unescape_id(suggestion);
			for(size_t i = 0; i < str.size(); i++)
				if(!is_character_legal(str[i], static_cast<int>(i)))
					str[i] = substitution_character;
			// insert() reports whether the name was new, which avoids a separate lookup
			if(_used_names.insert(str).second)
				return str;
			for (int idx = 0 ; ; idx++){
				std::string suffixed = str + "_" + std::to_string(idx);
				if(_used_names.insert(suffixed).second)
					return suffixed;
			}
		}
		// returns the name assigned to id; on first use a unique name is
		// derived from the suggestion and remembered for later calls
		std::string operator()(Id id, IdString suggestion) {
			auto it = _by_id.find(id);
			if(it != _by_id.end())
				return it->second;
			std::string str = unique_name(suggestion);
			_by_id.insert({id, str});
			return str;
		}
	};
	// Writer wraps an output stream (held by pointer, not owned) and adds
	// format-string based printing on top of the usual << operator.
	class Writer {
		std::ostream *os;
		// declared here, defined elsewhere: receives the format string and one closure
		// per argument; calling a closure writes the corresponding argument to the stream
		void print_impl(const char *fmt, vector<std::function<void()>>& fns);
	public:
		Writer(std::ostream &os) : os(&os) {}
		// forwards the argument to the underlying stream and returns this Writer for chaining
		template<class T> Writer& operator <<(T&& arg)
		{
			(*os) << std::forward<T>(arg);
			return *this;
		}
		// wraps every argument in a closure that streams it and hands the closures to print_impl
		template<typename... Args>
		void print(const char *fmt, Args&&... args)
		{
			vector<std::function<void()>> printers {
				[&]() { *this << args; }...
			};
			print_impl(fmt, printers);
		}
		// like print, but arguments that fn can be invoked with are passed
		// through fn first and the result is streamed; all others are streamed as they are
		template<typename Fn, typename... Args>
		void print_with(Fn fn, const char *fmt, Args&&... args)
		{
			vector<std::function<void()>> printers {
				[&]() {
					if constexpr (std::is_invocable_v<Fn, Args>)
						*this << fn(args);
					else
						*this << args;
				}...
			};
			print_impl(fmt, printers);
		}
	};

}

YOSYS_NAMESPACE_END

#endif