1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
|
#include "pythoncdb/py_ex.hh"
#include "pythoncdb/py_helpers.hh"
#include "pythoncdb/py_globals.hh"
#include "Bridge.hh"
#include "algorithms/collect_terms.hh"
using namespace cadabra;
namespace py = pybind11;
void pull_in(std::shared_ptr<Ex> ex, Kernel *kernel)
{
collect_terms rr(*kernel, *ex);
// bool acted=false;
Ex::iterator it=ex->begin();
while(it!=ex->end()) {
if(*it->name=="@") {
std::string pobj = *(Ex::begin(it)->name);
std::shared_ptr<Ex> pull_ex = fetch_from_python(pobj);
if(pull_ex) {
// acted=true;
multiplier_t mult = *(it->multiplier);
str_node::parent_rel_t prel = (it->fl.parent_rel);
auto topnode_it = pull_ex->begin();
auto at_arg = ex->begin(it);
ex->erase(at_arg); // erase argument of @
it=ex->replace(it, *(topnode_it)); // replace @ with head of ex
if(ex->number_of_children(topnode_it)>0) {
Ex::sibling_iterator walk=ex->end(topnode_it);
do {
--walk;
ex->prepend_child(it, (Ex::iterator)walk);
}
while(walk!=ex->begin(topnode_it));
}
// FIXME: prepend_children is broken!
// ex->prepend_children(it, ex->begin(topnode_it), ex->end(topnode_it)); // add children of ex
multiply(it->multiplier, mult);
it->fl.parent_rel=prel;
rr.rename_replacement_dummies(it, false);
}
else throw ArgumentException("Python object '"+pobj+"' does not exist.");
}
++it;
}
// if(acted)
// std::cerr << "pull_in done: " << *ex << std::endl;
return;
}
void run_python_functions(std::shared_ptr<Ex> ex, Kernel *kernel)
{
if(kernel->call_embedded_python_functions==false)
return;
Ex::post_order_iterator it = ex->begin_post();
auto locals=get_locals();
while(it!=ex->end_post()) {
auto nxt=it;
++nxt;
// Only call functions if the cadabra symbols have one or
// more child nodes which all have bracket_t::b_none.
Ex::sibling_iterator sib=ex->begin(it);
if(sib==ex->end(it)) {
it=nxt;
continue;
}
bool cancall=true;
while(sib!=ex->end(it)) {
if(sib->fl.parent_rel!=str_node::parent_rel_t::p_none) {
cancall=false;
break;
}
++sib;
}
if(!cancall) {
it=nxt;
continue;
}
if(scope_has(locals, *it->name)) {
//std::cerr << "can run function " << *it->name << std::endl;
py::object fun=locals[(*it->name).c_str()];
Ex::sibling_iterator sib=ex->begin(it);
py::object res;
if(sib!=ex->end(it)) {
Ex tmp1(sib);
++sib;
if(sib!=ex->end(it)) {
Ex tmp2(sib);
++sib;
if(sib!=ex->end(it)) {
Ex tmp3(sib);
res = fun(tmp1, tmp2, tmp3);
++sib;
if(sib!=ex->end(it)) {
throw RuntimeException("Cannot yet call inline functions with more than 3 arguments.");
}
}
else res = fun(tmp1, tmp2);
}
else res = fun(tmp1);
}
else res = fun();
Ex repl = res.cast<Ex>();
Ex::iterator tmpit=it;
rset_t::iterator mult=tmpit->multiplier;
ex->move_ontop(tmpit, repl.begin())->multiplier=mult;
}
it=nxt;
}
}
|