discard """
output: '''heavy_calc_impl is called
sub_calc1_impl is called
sub_calc2_impl is called
** no changes recompute effectively
** change one input and recompute effectively
heavy_calc_impl is called
sub_calc2_impl is called'''
"""
# Sample: incremental computation with cached graph nodes and
# compile-time dependency discovery
import tables
import macros
var
  # run-time values of the graph's leaf inputs, keyed by input name
  inputs = initTable[string, float]()
  # memoized results of graph nodes, keyed by node key
  cache = initTable[string, float]()
  # compile-time map filled during dependency discovery:
  # graph-node name -> key of the node that depends on it
  dep_tree {.compileTime.} = initTable[string, string]()
macro symHash(s: typed{nkSym}): string =
  ## Expands to a string literal holding the `symBodyHash` of symbol `s`.
  newLit(s.symBodyHash)
#######################################################################################
# Custom pragma attached to every graph-node function; `key` carries the
# node's unique name (an input name or a body hash).
template graph_node(key: string) {.pragma.}
proc tag(n: NimNode): NimNode =
  ## Given a proc/func definition, yields the argument of its `graph_node`
  ## pragma (the node's unique name); yields nil if no such pragma is present.
  n.expectKind({nnkProcDef, nnkFuncDef})
  result = nil
  for prag in n.pragma:
    # entries without children cannot be a `graph_node(...)` call
    if prag.len == 0 or prag[0] != bindSym"graph_node":
      continue
    result = prag[1]
    break
macro graph_node_key(n: typed{nkSym}): untyped =
  ## Expands to the cache key of a graph node: a string literal holding the
  ## `symBodyHash` of symbol `n`.
  newLit(symBodyHash(n))
macro graph_discovery(n: typed{nkSym}): untyped =
  ## Discovers the graph nodes that the routine `n` depends on and records
  ## them in the compile-time `dep_tree` variable; expands to nothing.
  ##
  ## The body of `n` is walked recursively. Every proc/func symbol that is
  ## tagged as a graph node gets an entry `tag -> key of n`; untagged
  ## routines are searched transitively through their own bodies.
  ## Each tag maps to a single dependent, so a later discovery overwrites an
  ## earlier entry for the same tag.
  let ownerKey = n.symBodyHash
  var seen: seq[NimNode]

  proc walk(node: NimNode) =
    if node.kind == nnkSym:
      # only routines are of interest, and each one is inspected once
      if node.symKind notin {nskFunc, nskProc} or node in seen:
        return
      seen.add node
      let impl = node.getImpl
      let nodeTag = impl.tag
      if nodeTag == nil:
        # not a graph node itself: look through its implementation
        walk(impl.body)
      else:
        dep_tree[nodeTag.strVal] = ownerKey
    elif node.kind notin {nnkNone..nnkNilLit}:
      # the remaining leaf kinds carry no children; descend into the rest
      for child in node:
        walk(child)

  walk(n.getImpl.body)
  result = newEmptyNode()
#######################################################################################
macro incremental_input(key: static[string], n: untyped{nkFuncDef}): untyped =
  ## Marks the annotated func as a leaf node of the graph: its body becomes
  ## a lookup of `inputs[key]` and its pragma list is replaced with
  ## `graph_node(key)`.
  template readInput(inputKey) {.dirty.} =
    {.noSideEffect.}:
      inputs[inputKey]

  let marker = newCall(bindSym"graph_node", newLit(key))
  n.pragma = nnkPragma.newTree(marker)
  n.body = getAst(readInput(key))
  result = n
macro incremental(n: untyped{nkFuncDef}): untyped =
  ## Incrementalizes a side-effect-free computation.
  ##
  ## The annotated func is renamed to `<name>_impl`; a wrapper with the
  ## original name and signature is generated that memoizes the result in
  ## `cache`, is marked as a `graph_node`, and triggers dependency
  ## discovery for the implementation.
  ##
  ## NOTE: the cache key is derived from the implementation symbol only, so
  ## calls with different arguments share a single cache entry.
  template cache_func_body(func_name, func_name_str, func_call) {.dirty.} =
    {.noSideEffect.}:
      graph_discovery(func_name)
      let key = graph_node_key(func_name)
      if key in cache:
        result = cache[key]
      else:
        echo func_name_str & " is called"
        result = func_call
        cache[key] = result

  let func_name = n.name.strVal & "_impl"
  let func_call = nnkCall.newTree(ident func_name)
  for i in 1..<n.params.len:
    let identDefs = n.params[i]
    # An IdentDefs node is `name1, name2, ..., type, default`; forward every
    # name so grouped parameters such as `a, b: float` are all passed on.
    for j in 0 ..< identDefs.len - 2:
      func_call.add identDefs[j]
  let cache_func = n.copyNimTree
  cache_func.body = getAst(cache_func_body(ident func_name, func_name, func_call))
  cache_func.pragma = nnkPragma.newTree(newCall(bindSym"graph_node",
                                        newCall(bindSym"symHash", ident func_name)))

  # emit the renamed implementation first so the wrapper can refer to it
  n.name = ident(func_name)
  result = nnkStmtList.newTree(n, cache_func)
###########################################################################
### Example
###########################################################################
# Leaf node of the graph: the macro turns the body into a read of inputs["a1"].
func input1(): float {.incremental_input("a1").}
# Leaf node of the graph: the macro turns the body into a read of inputs["a2"].
func input2(): float {.incremental_input("a2").}
# Cached graph node: adds the "a1" input to `a`.
func sub_calc1(a: float): float {.incremental.} =
  result = a + input1()
# Cached graph node: adds the "a2" input to `b`.
func sub_calc2(b: float): float {.incremental.} =
  result = b + input2()
# Cached graph node at the top of the example: sums both sub-calculations.
func heavy_calc(a: float, b: float): float {.incremental.} =
  result = sub_calc1(a) + sub_calc2(b)
###########################################################################
## graph finalize and inputs
###########################################################################
macro finalize_dep_tree(): untyped =
  ## Expands to an expression that builds a `Table[string, string]` holding
  ## a snapshot of the compile-time `dep_tree` at the point of expansion.
  if dep_tree.len == 0:
    # `toTable` cannot infer key/value types from an empty constructor, so
    # emit an explicitly typed empty table instead.
    result = quote do:
      initTable[string, string]()
    return
  let entries = nnkTableConstr.newNimNode
  for key, val in dep_tree:
    entries.add nnkExprColonExpr.newTree(newStrLitNode key, newStrLitNode val)
  result = nnkCall.newTree(bindSym"toTable", entries)
# Freeze the dependency map collected at compile time into a constant table
# that run-time code can query.
const dep_tree_final = finalize_dep_tree()
proc set_input(key: string, val: float) =
  ## Stores `val` as the value of input `key` and drops the cached result of
  ## every graph node reachable from `key` by following `dep_tree_final`.
  inputs[key] = val
  var node = key
  while node.len > 0:
    # step to the dependent node; "" marks the end of the chain
    node = getOrDefault(dep_tree_final, node, "")
    del(cache, node)
###########################################################################
## demo
###########################################################################
# Seed both inputs, then run once: nothing is cached yet, so every node runs.
set_input("a1", 5.0)
set_input("a2", 2.0)
discard heavy_calc(a = 5.0, b = 10.0)

# Second run with unchanged inputs: served from the cache.
echo "** no changes recompute effectively"
discard heavy_calc(a = 5.0, b = 10.0)

# Changing "a2" invalidates only the nodes that depend on it.
echo "** change one input and recompute effectively"
set_input("a2", 10.0)
discard heavy_calc(a = 5.0, b = 10.0)