1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
#include "Props.hh"
#include "Compare.hh"
#include "Cleanup.hh"
#include "IndexIterator.hh"
#include "algorithms/collect_factors.hh"
#include "algorithms/collect_terms.hh"
#include "properties/Symbol.hh"
#include "properties/Coordinate.hh"
//#define DEBUG
using namespace cadabra;
// Construct the algorithm acting on expression `e` using kernel `k`.
// No initialisation beyond the Algorithm base is required.
collect_factors::collect_factors(const Kernel& k, Ex& e) : Algorithm(k, e) {}
// Factor collection only makes sense on a product node (\prod).
bool collect_factors::can_apply(iterator it)
	{
	return *it->name=="\\prod";
	}
// Fill `factor_hash` with the factors of the product at `it` that are
// candidates for collection, keyed by their subtree hash (tr.calc_hash).
// For a \pow factor the *base* (first child) is stored, so that `x` and
// `x**a` end up in the same bin; any other factor is stored as itself.
//
// The hash map is such that all objects which are equal have to sit in the
// same bin, but objects in the same bin do not necessarily all have to be
// equal, so callers must still compare the stored subtrees.
//
// Factors are skipped when:
//  - they carry a sub- or superscript index which is not a rational number
//    and has neither a Symbol nor a Coordinate property, or
//  - they are a \pow whose base is a rational number (exponents of numbers
//    are not collected).
void collect_factors::fill_hash_map(iterator it)
	{
	factor_hash.clear();
	for(sibling_iterator sib=tr.begin(it); sib!=tr.end(it); ++sib) { // all factors of the product
		// Scan the indices of this factor for one that blocks collection.
		bool dontcollect=false;
		auto chsib=index_iterator::begin(kernel.properties, sib);
		auto chend=index_iterator::end(kernel.properties, sib);
		for(; chsib!=chend; ++chsib) {
			const Symbol     *smb=kernel.properties.get<Symbol>(chsib, true);
			const Coordinate *coo=kernel.properties.get<Coordinate>(chsib, true);
			const bool is_index = (chsib->fl.parent_rel==str_node::p_sub ||
			                       chsib->fl.parent_rel==str_node::p_super);
			if(is_index && chsib->is_rational()==false && smb==nullptr && coo==nullptr) {
				dontcollect=true;
				break;
				}
			}
		if(dontcollect)
			continue;

		if(*sib->name=="\\pow") {
			sibling_iterator base=tr.begin(sib);
			if(base->is_rational()==false) // do not collect exponents of numbers
				factor_hash.insert(std::pair<hashval_t, sibling_iterator>(tr.calc_hash(base), base));
			}
		else {
			factor_hash.insert(std::pair<hashval_t, sibling_iterator>(tr.calc_hash(sib), sib));
			}
		}
	}
// Merge identical factors of the product at `st` into a single power whose
// exponent is the sum of the individual exponents (a bare factor counts as
// exponent 1), e.g. x * x**a  ->  x**(1+a).
//
// Candidates come from fill_hash_map(). Within one hash bin, factors are
// compared with subtree_exact_equal(). They are only merged when
// Ex_comparator::can_move_adjacent() allows the two factors (or their
// enclosing \pow nodes) to be brought next to each other. The first
// occurrence is rewritten and the others are erased from the tree.
//  - If the summed exponent simplifies to 1, the \pow wrapper is removed.
//  - If it simplifies to 0, the factor is replaced by 1 (keeping its
//    multiplier).
//
// @param st  iterator to a \prod node; the product is passed through
//            cleanup_dispatch() before returning.
// @return    l_applied if at least one factor was merged, l_no_action otherwise.
Algorithm::result_t collect_factors::apply(iterator& st)
	{
	assert(tr.is_valid(st));
	assert(*st->name=="\\prod");
	result_t res=result_t::l_no_action;
	Ex_comparator comp(kernel.properties);
	fill_hash_map(st);
	factor_hash_iterator_t ht=factor_hash.begin();
	while(ht!=factor_hash.end()) {
		hashval_t curr=ht->first; // hash value of the current set of factors
		factor_hash_iterator_t thisbin1=ht, thisbin2;
		while(thisbin1!=factor_hash.end() && thisbin1->first==curr) {
			thisbin2=thisbin1;
			++thisbin2;
			// Sum of the exponents of all factors merged into thisbin1.
			Ex expsum;
			iterator expsumit=expsum.set_head(str_node("\\sum"));
			// add the exponent of the first element in this hash bin
			if(*(tr.parent((*thisbin1).second)->name)=="\\pow") {
				sibling_iterator powch=tr.parent((*thisbin1).second).begin();
				++powch; // second child of \pow is the exponent
				iterator newch= expsum.append_child(expsumit, iterator(powch));
				newch->fl.bracket=str_node::b_round;
				}
			else {
				expsum.append_child(expsumit, str_node("1", str_node::b_round));
				}
			// FIXME: If the multiplier of this factor differs from 1, we
			// have (pure number)**(exp). We need to catch this
			// separately.
			// For now, we have disabled collecting such factors; sympy
			// can do it anyway.
			assert(*((*thisbin1).second->multiplier)==1);
			// find the other, identical factors
			while(thisbin2!=factor_hash.end() && thisbin2->first==curr) {
				if(subtree_exact_equal(&kernel.properties, (*thisbin1).second, (*thisbin2).second)) {
					// only do something if this factor can be moved to the other one
					iterator objnode1=(*thisbin1).second;
					iterator objnode2=(*thisbin2).second;
					if(*tr.parent(objnode1)->name=="\\pow") objnode1=tr.parent(objnode1);
					if(*tr.parent(objnode2)->name=="\\pow") objnode2=tr.parent(objnode2);
					if(comp.can_move_adjacent(st, objnode1, objnode2)) {
						// all clear
						assert(*((*thisbin2).second->multiplier)==1);
						res=result_t::l_applied;
						// Record the exponent of this occurrence ...
						if(*(tr.parent((*thisbin2).second)->name)=="\\pow") {
							sibling_iterator powch=tr.parent((*thisbin2).second).begin();
							++powch;
							iterator newch=expsum.append_child(expsumit, iterator(powch));
							newch->fl.bracket=str_node::b_round;
							}
						else {
							expsum.append_child(expsumit, str_node("1", str_node::b_round));
							}
						// ... then remove it from the tree and from the hash map.
						factor_hash_iterator_t tmp=thisbin2;
						++tmp;
						if(*(tr.parent((*thisbin2).second)->name)=="\\pow")
							tr.erase(tr.parent((*thisbin2).second));
						else
							tr.erase((*thisbin2).second);
						factor_hash.erase(thisbin2);
						thisbin2=tmp;
						}
					else ++thisbin2;
					}
				else ++thisbin2;
				}
			// make the modification to the tree
			if(expsum.number_of_children(expsum.begin())>1) {
				iterator top=expsum.begin();
				cleanup_dispatch(kernel,expsum, top);
				// cleanup_nests_below(expsum, expsum.begin());
				// Simplify the exponent sum, unless cleanup already reduced it to 1.
				if(! (expsum.begin()->is_identity()) ) {
					collect_terms ct(kernel, expsum);
					iterator tp=expsum.begin();
					ct.apply(tp);
					}
				// The other occurrences have already been erased from `tr`, so the
				// first occurrence must be rewritten in every case, including when
				// the exponent sum came out as exactly 1 before collect_terms ran.
				iterator inserthere=thisbin1->second;
				if(*(tr.parent(inserthere)->name)=="\\pow")
					inserthere=tr.parent(inserthere);
				if(expsum.begin()->is_rational() && (expsum.begin()->is_identity() ||
				                                     expsum.begin()->is_zero() ) ) {
					// Exponent 1 or 0: unwrap base**exp to the bare base ...
					if(*(inserthere->name)=="\\pow") {
						tr.flatten(inserthere);
						inserthere=tr.erase(inserthere);
						sibling_iterator nxt=inserthere;
						++nxt;
						tr.erase(nxt); // drop the former exponent
						}
					// ... and for exponent 0 replace the base by 1, keeping its multiplier.
					if(expsum.begin()->is_zero()) {
						rset_t::iterator rem=inserthere->multiplier;
						node_one(inserthere);
						inserthere->multiplier=rem;
						}
					}
				else {
					// General exponent: replace the first occurrence by base**(expsum).
					Ex repl;
					repl.set_head(str_node("\\pow"));
					repl.append_child(repl.begin(), iterator((*thisbin1).second));
					repl.append_child(repl.begin(), expsum.begin());
					if(*(inserthere->name)!="\\pow") {
						inserthere=(*thisbin1).second;
						}
					tr.insert_subtree(inserthere, repl.begin());
					tr.erase(inserthere);
					}
				}
			// else: only one occurrence in this bin, nothing to merge.
			++thisbin1;
			}
		ht=thisbin1;
		}
	cleanup_dispatch(kernel, tr, st);
	return res;
	}
|