1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
|
# This file is a part of Julia. License is MIT: https://julialang.org/license
module Cartesian
export @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @nany, @ntuple, @nif
### Cartesian-specific macros
"""
    @nloops N itersym rangeexpr bodyexpr
    @nloops N itersym rangeexpr preexpr bodyexpr
    @nloops N itersym rangeexpr preexpr postexpr bodyexpr

Generate `N` nested loops, using `itersym` as the prefix for the iteration variables.
`rangeexpr` may be an anonymous-function expression, or a simple symbol `var` in which case
the range is `axes(var, d)` for dimension `d`.

Optionally, you can provide "pre" and "post" expressions. These get executed first and last,
respectively, in the body of each loop. For example:

    @nloops 2 i A d -> j_d = min(i_d, 5) begin
        s += @nref 2 A j
    end

would generate:

    for i_2 = axes(A, 2)
        j_2 = min(i_2, 5)
        for i_1 = axes(A, 1)
            j_1 = min(i_1, 5)
            s += A[j_1, j_2]
        end
    end

If you want just a post-expression, supply [`nothing`](@ref) for the pre-expression. Using
parentheses and semicolons, you can supply multi-statement expressions.
"""
macro nloops(N, itersym, rangeexpr, args...)
    # Validation and code generation live in `_nloops`, which dispatches on whether
    # `rangeexpr` is a bare array symbol or an anonymous-function expression.
    return _nloops(N, itersym, rangeexpr, args...)
end
# Shorthand form of `@nloops`: a bare array symbol `A` stands for the range expression
# `d -> Base.axes(A, d)`, after which the anonymous-function method does the real work.
function _nloops(N::Int, itersym::Symbol, arraysym::Symbol, args::Expr...)
    @gensym d
    axis_range = :($d -> Base.axes($arraysym, $d))
    return _nloops(N, itersym, axis_range, args...)
end
# Core generator behind `@nloops`.
#
# `rangeexpr` must be an anonymous function `d -> range`. `args` holds, in order, an optional
# pre-expression, an optional post-expression and the mandatory loop body (1 to 3 items).
# The loops are built from the inside out: dimension 1 wraps the body first, so it ends up as
# the innermost loop and dimension `N` as the outermost.
#
# Throws `ArgumentError` when `rangeexpr` is not an anonymous function or when the number of
# trailing arguments is outside 1:3.
function _nloops(N::Int, itersym::Symbol, rangeexpr::Expr, args::Expr...)
    if rangeexpr.head !== :->
        throw(ArgumentError("second argument must be an anonymous function expression to compute the range"))
    end
    nargs = length(args)
    if !(1 <= nargs <= 3)
        # `nargs` is bound above; this message used to reference an undefined variable and
        # therefore threw `UndefVarError` instead of the intended `ArgumentError`.
        throw(ArgumentError("number of arguments must be 1 ≤ length(args) ≤ 3, got $nargs"))
    end
    body = args[end]
    ex = Expr(:escape, body)
    for dim = 1:N
        itervar = inlineanonymous(itersym, dim)
        rng = inlineanonymous(rangeexpr, dim)
        preexpr = nargs > 1 ? inlineanonymous(args[1], dim) : (:(nothing))
        postexpr = nargs > 2 ? inlineanonymous(args[2], dim) : (:(nothing))
        ex = quote
            for $(esc(itervar)) = $(esc(rng))
                $(esc(preexpr))
                $ex
                $(esc(postexpr))
            end
        end
    end
    ex
end
"""
    @nref N A indexexpr

Generate expressions like `A[i_1, i_2, ...]`. `indexexpr` can either be an iteration-symbol
prefix, or an anonymous-function expression.

# Examples
```jldoctest
julia> @macroexpand Base.Cartesian.@nref 3 A i
:(A[i_1, i_2, i_3])
```
"""
macro nref(N::Int, A::Symbol, ex)
    # One index expression per dimension, then a single escaped `A[...]` reference.
    indices = map(k -> inlineanonymous(ex, k), 1:N)
    return Expr(:escape, Expr(:ref, A, indices...))
end
"""
    @ncall N f sym...

Generate a function call expression. `sym` represents any number of function arguments, the
last of which may be an anonymous-function expression and is expanded into `N` arguments.

For example, `@ncall 3 func a` generates

    func(a_1, a_2, a_3)

while `@ncall 2 func a b i->c[i]` yields

    func(a, b, c[1], c[2])
"""
macro ncall(N::Int, f, args...)
    # Every argument but the last is passed through verbatim; the last one is expanded N times.
    fixed = args[1:end-1]
    expanded = [inlineanonymous(args[end], k) for k in 1:N]
    call = Expr(:call, f, fixed..., expanded...)
    return Expr(:escape, call)
end
"""
    @nexprs N expr

Generate `N` expressions. `expr` should be an anonymous-function expression.

# Examples
```jldoctest
julia> @macroexpand Base.Cartesian.@nexprs 4 i -> y[i] = A[i+j]
quote
    y[1] = A[1 + j]
    y[2] = A[2 + j]
    y[3] = A[3 + j]
    y[4] = A[4 + j]
end
```
"""
macro nexprs(N::Int, ex::Expr)
    # Collect the N instantiations of `ex` into one escaped block.
    block = Expr(:block)
    for k in 1:N
        push!(block.args, inlineanonymous(ex, k))
    end
    return Expr(:escape, block)
end
"""
    @nextract N esym isym

Generate `N` variables `esym_1`, `esym_2`, ..., `esym_N` to extract values from `isym`.
`isym` can be either a `Symbol` or anonymous-function expression.

`@nextract 2 x y` would generate

    x_1 = y[1]
    x_2 = y[2]

while `@nextract 3 x d->y[2d-1]` yields

    x_1 = y[1]
    x_2 = y[3]
    x_3 = y[5]
"""
macro nextract(N::Int, esym::Symbol, isym::Symbol)
    # Symbol source: `esym_k = isym[k]` for k = 1..N, each assignment escaped separately.
    assignments = Any[]
    for k in 1:N
        lhs = inlineanonymous(esym, k)
        push!(assignments, Expr(:escape, Expr(:(=), lhs, :(($isym)[$k]))))
    end
    return Expr(:block, assignments...)
end
macro nextract(N::Int, esym::Symbol, ex::Expr)
    # Anonymous-function source: `esym_k = <ex evaluated at d = k>` for k = 1..N.
    assignments = Any[]
    for k in 1:N
        lhs = inlineanonymous(esym, k)
        push!(assignments, Expr(:escape, Expr(:(=), lhs, inlineanonymous(ex, k))))
    end
    return Expr(:block, assignments...)
end
"""
    @nall N expr

Check whether all of the expressions generated by the anonymous-function expression `expr`
evaluate to `true`.

`@nall 3 d->(i_d > 1)` would generate the expression `(i_1 > 1 && i_2 > 1 && i_3 > 1)`. This
can be convenient for bounds-checking.
"""
macro nall(N::Int, criterion::Expr)
    if criterion.head != :->
        throw(ArgumentError("second argument must be an anonymous function expression yielding the criterion"))
    end
    # Chain the N escaped instantiations with short-circuiting `&&`.
    clauses = Expr(:&&)
    for k = 1:N
        push!(clauses.args, Expr(:escape, inlineanonymous(criterion, k)))
    end
    return clauses
end
"""
    @nany N expr

Check whether any of the expressions generated by the anonymous-function expression `expr`
evaluate to `true`.

`@nany 3 d->(i_d > 1)` would generate the expression `(i_1 > 1 || i_2 > 1 || i_3 > 1)`.

Throws `ArgumentError` (during macro expansion) if `expr` is not an anonymous function.
"""
macro nany(N::Int, criterion::Expr)
    # Validate the same way `@nall` does, with the same exception type and message.
    if criterion.head !== :->
        throw(ArgumentError("second argument must be an anonymous function expression yielding the criterion"))
    end
    conds = Any[ Expr(:escape, inlineanonymous(criterion, i)) for i = 1:N ]
    Expr(:||, conds...)
end
"""
    @ntuple N expr

Generates an `N`-tuple. `@ntuple 2 i` would generate `(i_1, i_2)`, and `@ntuple 2 k->k+1`
would generate `(2,3)`.
"""
macro ntuple(N::Int, ex)
    # Build the tuple expression element by element, then escape it as a whole.
    tup = Expr(:tuple)
    for k = 1:N
        push!(tup.args, inlineanonymous(ex, k))
    end
    return Expr(:escape, tup)
end
"""
    @nif N conditionexpr expr
    @nif N conditionexpr expr elseexpr

Generates a sequence of `if ... elseif ... else ... end` statements. For example:

    @nif 3 d->(i_d >= size(A,d)) d->(error("Dimension ", d, " too big")) d->println("All OK")

would generate:

    if i_1 >= size(A, 1)
        error("Dimension ", 1, " too big")
    elseif i_2 >= size(A, 2)
        error("Dimension ", 2, " too big")
    else
        println("All OK")
    end

Only dimensions `1` through `N-1` get a condition. The final `else` branch is `elseexpr`
evaluated at `d = N` when given; otherwise it is `expr` evaluated at `d = N`.
"""
macro nif(N, condition, operation...)
    # Handle the final "else": the explicit else-expression if supplied, else `expr` at d = N
    ex = esc(inlineanonymous(length(operation) > 1 ? operation[2] : operation[1], N))
    # Make the nested if statements, wrapping outward from dimension N-1 down to 1 so that
    # dimension 1 is tested first
    for i = N-1:-1:1
        ex = Expr(:if, esc(inlineanonymous(condition,i)), esc(inlineanonymous(operation[1],i)), ex)
    end
    ex
end
## Utilities
# Simplify expressions like :(d->3:size(A,d)-3) given an explicit value for d.
# Steps: substitute `val` for the argument symbol throughout the body (`lreplace`), unwrap a
# trivial `begin ... end` body (`poplinenum`), then constant-fold what can be folded
# (`exprresolve`). Throws `ArgumentError` unless `ex` is a one-argument `->` expression.
function inlineanonymous(ex::Expr, val)
    if ex.head != :->
        throw(ArgumentError("not an anonymous function"))
    end
    argsym = ex.args[1]
    argsym isa Symbol || throw(ArgumentError("not a single-argument anonymous function"))
    body = ex.args[2]
    substituted = lreplace(body, argsym, val)
    return exprresolve(poplinenum(substituted))
end
# Suffix form: join the base symbol and `ext` with an underscore, e.g. (:i, 3) -> :i_3.
function inlineanonymous(base::Symbol, ext)
    return Symbol(string(base, "_", ext))
end
# Replace a symbol by a value or a "coded" symbol.
# E.g., for d = 3,
#    lreplace(:d, :d, 3) -> 3
#    lreplace(:i_d, :d, 3) -> :i_3
#    lreplace(:i_{d-1}, :d, 3) -> :i_2
# This follows LaTeX notation.
struct LReplace{S<:AbstractString}
    pat_sym::Symbol   # the symbol being substituted (e.g. :d)
    pat_str::S        # its string form, matched against `_`-suffixes inside other names
    val::Int          # the integer substituted in its place
end
# Convenience constructor that derives the string form of the pattern from the symbol.
function LReplace(sym::Symbol, val::Integer)
    return LReplace(sym, string(sym), val)
end
# Non-mutating entry point: substitute into a copy so the caller's expression is untouched.
function lreplace(ex, sym::Symbol, val)
    source = copy(ex)
    return lreplace!(source, LReplace(sym, val))
end
# Symbol case: the pattern symbol itself becomes the value; any other symbol is rewritten
# via its string form, so suffixes such as `i_d` become `i_3`.
function lreplace!(sym::Symbol, r::LReplace)
    if sym == r.pat_sym
        return r.val
    end
    return Symbol(lreplace!(string(sym), r))
end
# String case: rewrite every `_<pat>` occurrence in `str` that is followed by end-of-string
# or by another `_` into `_<val>`. With pattern "d" and value 3, "i_d" -> "i_3" and
# "x_d_d" -> "x_3_3"; "i_dx" is left unchanged because the pattern is not terminated.
# Despite the `!`, `str` is not modified: the function returns a newly built string, or
# `str` itself when nothing matches. After a match it recurses on the remainder of the string,
# so all occurrences are replaced.
function lreplace!(str::AbstractString, r::LReplace)
    i = firstindex(str)        # scan position in `str`
    pat = r.pat_str
    j = firstindex(pat)        # index of the next pattern character to compare
    matching = false           # true while a candidate match after a `_` is in progress
    local istart::Int          # index just past the `_` that started the current candidate
    while i <= ncodeunits(str)
        cstr = str[i]
        i = nextind(str, i)
        if !matching
            # Only a `_` that is not the final character can start a candidate match.
            if cstr != '_' || i > ncodeunits(str)
                continue
            end
            istart = i
            cstr = str[i]
            i = nextind(str, i)
        end
        if j <= lastindex(pat)
            cr = pat[j]
            j = nextind(pat, j)
            if cstr == cr
                matching = true
            else
                # Mismatch: drop this candidate and resume scanning right after its `_`.
                matching = false
                j = firstindex(pat)
                i = istart
                continue
            end
        end
        if matching && j > lastindex(pat)
            # Whole pattern consumed; it only counts if followed by end-of-string or `_`.
            if i > lastindex(str) || str[i] == '_'
                # We have a match
                return string(str[1:prevind(str, istart)], r.val, lreplace!(str[i:end], r))
            end
            # Pattern was only a prefix of a longer name; back up and keep scanning.
            matching = false
            j = firstindex(pat)
            i = istart
        end
    end
    str
end
# Expression case: substitute recursively into every argument, mutating `ex` in place.
# The curly form `i_{expr}` acts like parentheses: its contents are substituted and
# constant-folded, and if they reduce to a number the whole thing collapses to one symbol
# (e.g. `i_{d-1}` with d = 3 -> `:i_2`).
function lreplace!(ex::Expr, r::LReplace)
    args = ex.args
    # Curly-brace notation, which acts like parentheses
    is_curly_suffix = ex.head == :curly && length(args) == 2 &&
                      args[1] isa Symbol && endswith(string(args[1]), "_")
    if is_curly_suffix
        inner = exprresolve(lreplace!(args[2], r))
        inner isa Number && return Symbol(args[1], inner)
        args[2] = inner
        return ex
    end
    for k in eachindex(args)
        args[k] = lreplace!(args[k], r)
    end
    return ex
end
# Literals and other leaf values have nothing to substitute.
lreplace!(arg, r::LReplace) = arg
# Strip a trivial `begin ... end` wrapper, such as the one usually found around a parsed
# anonymous-function body. A `:block` holding a single expression, or a line-number marker
# (`LineNumberNode` or legacy `:line` Expr) followed by a single expression, is replaced by
# that expression. Anything else is returned unchanged.
poplinenum(arg) = arg
function poplinenum(ex::Expr)
    ex.head === :block || return ex
    stmts = ex.args
    if length(stmts) == 1
        return stmts[1]
    end
    if length(stmts) == 2
        lead = stmts[1]
        is_line = lead isa LineNumberNode || (lead isa Expr && lead.head === :line)
        is_line && return stmts[2]
    end
    return ex
end
## Resolve expressions at parsing time ##
# Operators that may be evaluated at expansion time when every operand is a literal number,
# e.g. `:(3 - 1)` -> `2`.
const exprresolve_arith_dict = Dict{Symbol,Function}(:+ => +,
    :- => -, :* => *, :/ => /, :^ => ^, :div => div)
# Comparison operators that may be evaluated to decide an `if` with a literal condition.
const exprresolve_cond_dict = Dict{Symbol,Function}(:(==) => ==,
    :(<) => <, :(>) => >, :(<=) => <=, :(>=) => >=)

# Try to evaluate `ex` as literal arithmetic. Returns `(true, value)` when `ex` is a call to
# an operator in `exprresolve_arith_dict` with at least one operand and all operands
# are `Number`s; otherwise `(false, 0)`. Operand-less calls such as `+()` are rejected
# instead of raising a `MethodError` during macro expansion.
function exprresolve_arith(ex::Expr)
    if ex.head == :call && length(ex.args) >= 2 && haskey(exprresolve_arith_dict, ex.args[1]) &&
            all(isa(ex.args[i], Number) for i = 2:length(ex.args))
        return true, exprresolve_arith_dict[ex.args[1]](ex.args[2:end]...)
    end
    false, 0
end
exprresolve_arith(arg) = false, 0

# Try to evaluate `ex` as a literal condition. A `Bool` resolves to itself. A binary call to
# an operator in `exprresolve_cond_dict` with two literal `Number` operands is evaluated.
# Anything else, including a unary form such as `<(1)` that previously raised `BoundsError`,
# returns `(false, false)`.
exprresolve_conditional(b::Bool) = true, b
function exprresolve_conditional(ex::Expr)
    if ex.head == :call && length(ex.args) == 3 && haskey(exprresolve_cond_dict, ex.args[1]) &&
            isa(ex.args[2], Number) && isa(ex.args[3], Number)
        return true, exprresolve_cond_dict[ex.args[1]](ex.args[2], ex.args[3])
    end
    false, false
end
exprresolve_conditional(arg) = false, false
# Constant-fold an expression at macro-expansion time, bottom-up:
#   * literal arithmetic (see `exprresolve_arith`), plus `x + 0` / `x - 0` -> `x`;
#   * indexing a literal `Array` with literal `Real` indices;
#   * `if` statements whose condition is a literal (see `exprresolve_conditional`).
# Mutates `ex.args` in place and returns the folded result (which may not be an `Expr`).
exprresolve(arg) = arg
function exprresolve(ex::Expr)
    for i = 1:length(ex.args)
        ex.args[i] = exprresolve(ex.args[i])
    end
    # Handle simple arithmetic
    can_eval, result = exprresolve_arith(ex)
    if can_eval
        return result
    elseif ex.head == :call && (ex.args[1] == :+ || ex.args[1] == :-) && length(ex.args) == 3 && ex.args[3] == 0
        # simplify x+0 and x-0
        return ex.args[2]
    end
    # Resolve array references
    if ex.head == :ref && isa(ex.args[1], Array)
        for i = 2:length(ex.args)
            if !isa(ex.args[i], Real)
                return ex
            end
        end
        return ex.args[1][ex.args[2:end]...]
    end
    # Resolve conditionals
    if ex.head == :if
        can_eval, tf = exprresolve_conditional(ex.args[1])
        if can_eval
            if tf
                ex = ex.args[2]
            elseif length(ex.args) >= 3
                alt = ex.args[3]
                # An `elseif` node is only valid nested inside an `if`; once the enclosing
                # `if` is folded away, re-head it as a standalone `if`.
                ex = (alt isa Expr && alt.head === :elseif) ? Expr(:if, alt.args...) : alt
            else
                # `if false ... end` without an `else` evaluates to `nothing`; previously
                # this indexed the missing `ex.args[3]` and threw `BoundsError`.
                ex = nothing
            end
        end
    end
    ex
end
end
|