File: collect_factors.cc

package info (click to toggle)
cadabra2 2.4.3.2-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 78,732 kB
  • sloc: ansic: 133,450; cpp: 92,064; python: 1,530; javascript: 203; sh: 184; xml: 182; objc: 53; makefile: 51
file content (179 lines) | stat: -rw-r--r-- 6,051 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179

#include "Props.hh"
#include "Compare.hh"
#include "Cleanup.hh"
#include "IndexIterator.hh"
#include "algorithms/collect_factors.hh"
#include "algorithms/collect_terms.hh"
#include "properties/Symbol.hh"
#include "properties/Coordinate.hh"

//#define DEBUG

using namespace cadabra;

// Build a collect_factors algorithm for kernel 'k' acting on expression 'e';
// both are simply forwarded to the Algorithm base constructor.
collect_factors::collect_factors(const Kernel& k, Ex& e) : Algorithm(k, e) {}

// Only product ('\prod') nodes are acted upon by this algorithm.
bool collect_factors::can_apply(iterator it)
	{
	const bool is_product = (*it->name=="\\prod");
	return is_product;
	}

// Fill 'factor_hash' with the collectable factors of the product node 'it',
// keyed by tr.calc_hash. The hash map is such that all objects which are
// equal sit in the same bin, but objects in the same bin are not necessarily
// equal, so callers must still compare entries within a bin.
//
// - For a '\pow' factor, the base (first child) is stored rather than the
//   '\pow' node itself; powers whose base is a rational number are skipped.
// - A factor is skipped entirely if any index reported by index_iterator is
//   a sub/superscript that is not a rational number and is declared neither
//   Symbol nor Coordinate (presumably so that factors carrying real tensor
//   indices are never merged into a power -- confirm intent).
void collect_factors::fill_hash_map(iterator it)
	{
	factor_hash.clear();

	// True if 'fac' carries an index that forbids collecting it. The cheap
	// structural checks run first; property lookups are only done for
	// non-rational sub/superscript indices.
	auto has_blocking_index = [this](sibling_iterator fac) {
		auto chsib=index_iterator::begin(kernel.properties, fac);
		auto chend=index_iterator::end(kernel.properties, fac);
		for(; chsib!=chend; ++chsib) {
			const bool is_index = (chsib->fl.parent_rel==str_node::p_sub ||
			                       chsib->fl.parent_rel==str_node::p_super);
			if(!is_index || chsib->is_rational())
				continue;
			if(kernel.properties.get<Symbol>(chsib, true)!=0)
				continue;
			if(kernel.properties.get<Coordinate>(chsib, true)!=0)
				continue;
			return true;
			}
		return false;
		};

	for(sibling_iterator sib=tr.begin(it); sib!=tr.end(it); ++sib) {
		if(has_blocking_index(sib))
			continue;
		if(*sib->name=="\\pow") {
			sibling_iterator base=tr.begin(sib);
			// Do not collect exponents of numbers.
			if(base->is_rational()==false)
				factor_hash.insert(std::pair<hashval_t, sibling_iterator>(tr.calc_hash(base), base));
			}
		else {
			factor_hash.insert(std::pair<hashval_t, sibling_iterator>(tr.calc_hash(sib), sib));
			}
		}
	}

// Collect equal factors of the product 'st' into powers.
//
// Factors are bucketed by hash (see fill_hash_map). Within a bucket, each
// surviving entry is compared with the later entries. A later factor that is
// subtree_exact_equal to it, and that Ex_comparator::can_move_adjacent allows
// to be brought next to it, is handled as follows:
// - It is erased from the tree (together with its '\pow' wrapper, if any).
// - Its exponent is added to a running sum of exponents. A bare factor
//   contributes 1; a power base contributes the second child of its '\pow'.
//
// If anything was merged, the first factor is then rewritten as
// base**(sum of exponents):
// - A sum that works out to exactly 1 drops the '\pow'.
// - A sum of 0 replaces the factor by 1, keeping its multiplier.
//
// Returns l_applied if at least one factor was merged, l_no_action otherwise.
// The product is passed through cleanup_dispatch at the end in either case.
Algorithm::result_t collect_factors::apply(iterator& st)
	{
	assert(tr.is_valid(st));
	assert(*st->name=="\\prod");
	result_t res=result_t::l_no_action;

	Ex_comparator comp(kernel.properties);

	// True when the parent of 'factor' is a '\pow' node. For hash entries
	// this means 'factor' is the base of a power, since fill_hash_map stores
	// the first child of each '\pow' factor.
	auto is_pow_base = [this](iterator factor) {
		return *(tr.parent(factor)->name)=="\\pow";
		};

	// Append the exponent belonging to 'factor' as a new child of 'sumtop' in
	// 'sum'. For a power this is a copy of the second child of the enclosing
	// '\pow'; for a bare factor it is a literal 1. Each term gets round
	// brackets.
	auto append_exponent = [&](Ex& sum, iterator sumtop, iterator factor) {
		if(is_pow_base(factor)) {
			sibling_iterator powch=tr.parent(factor).begin();
			++powch;
			iterator newch=sum.append_child(sumtop, iterator(powch));
			newch->fl.bracket=str_node::b_round;
			}
		else {
			sum.append_child(sumtop, str_node("1", str_node::b_round));
			}
		};

	fill_hash_map(st);
	factor_hash_iterator_t ht=factor_hash.begin();
	while(ht!=factor_hash.end()) {
		hashval_t curr=ht->first;  // hash value of the current bin
		factor_hash_iterator_t thisbin1=ht, thisbin2;
		while(thisbin1!=factor_hash.end() && thisbin1->first==curr) {
			thisbin2=thisbin1;
			++thisbin2;
			// Sum of the exponents of thisbin1 and every factor merged into it.
			Ex expsum;
			iterator expsumit=expsum.set_head(str_node("\\sum"));
			append_exponent(expsum, expsumit, thisbin1->second);
			// FIXME: if the multiplier of this factor is not one, we have
			// (pure number)**(exp), which needs to be handled separately.
			// For now collecting such factors is disabled (fill_hash_map
			// skips powers of rational numbers); sympy can do it anyway.
			assert(*(thisbin1->second->multiplier)==1);
			// Find the other, identical factors in this bin.
			while(thisbin2!=factor_hash.end() && thisbin2->first==curr) {
				if(subtree_exact_equal(&kernel.properties, thisbin1->second, thisbin2->second)) {
					// Only merge if this factor can be moved next to the first one.
					iterator objnode1=thisbin1->second;
					iterator objnode2=thisbin2->second;
					if(is_pow_base(objnode1)) objnode1=tr.parent(objnode1);
					if(is_pow_base(objnode2)) objnode2=tr.parent(objnode2);
					if(comp.can_move_adjacent(st, objnode1, objnode2)) {
						assert(*(thisbin2->second->multiplier)==1);
						res=result_t::l_applied;
						append_exponent(expsum, expsumit, thisbin2->second);
						// Remove the merged factor (with its '\pow', if any) from
						// the tree and from the hash map. Erasing from the multimap
						// only invalidates the erased entry, so step past it first.
						factor_hash_iterator_t tmp=thisbin2;
						++tmp;
						if(is_pow_base(thisbin2->second))
							tr.erase(tr.parent(thisbin2->second));
						else
							tr.erase(thisbin2->second);
						factor_hash.erase(thisbin2);
						thisbin2=tmp;
						}
					else ++thisbin2;
					}
				else ++thisbin2;
				}
			// More than one exponent means at least one factor was merged:
			// rewrite the first factor in the tree.
			if(expsum.number_of_children(expsum.begin())>1) {
				iterator top=expsum.begin();
				cleanup_dispatch(kernel,expsum, top);
				// NOTE(review): if cleanup_dispatch could ever reduce the sum to 1
				// here, the merged factors have already been erased while the
				// first factor keeps its original exponent. Confirm that sum
				// cleanup never combines numerical terms (collect_terms below is
				// what does that).
				if(! (expsum.begin()->is_identity()) ) {
					collect_terms ct(kernel, expsum);
					iterator tp=expsum.begin();
					ct.apply(tp);

					iterator inserthere=thisbin1->second;
					if(is_pow_base(inserthere))
						inserthere=tr.parent(inserthere);
					if(expsum.begin()->is_rational() && (expsum.begin()->is_identity() ||
					                                     expsum.begin()->is_zero() ) ) {
						// Exponent 1 or 0: strip the '\pow' wrapper, keeping only
						// the base. flatten() lifts base and exponent up as
						// siblings; the '\pow' node and the exponent are then erased.
						// NOTE(review): the '\pow' node's own multiplier is not
						// carried over; presumably product cleanup has already
						// moved it to the '\prod' -- confirm.
						if(*(inserthere->name)=="\\pow") {
							tr.flatten(inserthere);
							inserthere=tr.erase(inserthere);
							sibling_iterator nxt=inserthere;
							++nxt;
							tr.erase(nxt);
							}
						// Exponent 0: the factor becomes 1, keeping its multiplier.
						if(expsum.begin()->is_zero()) {
							rset_t::iterator rem=inserthere->multiplier;
							node_one(inserthere);
							inserthere->multiplier=rem;
							}
						}
					else {
						// General case: replace the factor (or its '\pow') by
						// base**(summed exponent).
						Ex repl;
						repl.set_head(str_node("\\pow"));
						repl.append_child(repl.begin(), iterator(thisbin1->second));
						repl.append_child(repl.begin(), expsum.begin());
						if(*(inserthere->name)!="\\pow") {
							inserthere=thisbin1->second;
							}
						tr.insert_subtree(inserthere, repl.begin());
						tr.erase(inserthere);
						}
					}
				}
			++thisbin1;
			}
		ht=thisbin1;
		}
	cleanup_dispatch(kernel, tr, st);
	return res;
	}