File: integrate_by_parts.cc

package info (click to toggle)
cadabra2 2.4.3.2-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 78,732 kB
  • sloc: ansic: 133,450; cpp: 92,064; python: 1,530; javascript: 203; sh: 184; xml: 182; objc: 53; makefile: 51
file content (252 lines) | stat: -rw-r--r-- 6,975 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252

#include <iterator>
#include "algorithms/integrate_by_parts.hh"
#include "properties/Derivative.hh"
#include "Cleanup.hh"

using namespace cadabra;

// Construct the algorithm acting on `ex`. The expression `away` is kept in
// the away_from member; derivative_acting_on_arg compares derivative
// arguments against it to decide which derivatives get moved.
integrate_by_parts::integrate_by_parts(const Kernel& kern, Ex& ex, Ex& away)
	: Algorithm(kern, ex), away_from(away)
	{
	}

// Only integral nodes (name "\int") are candidates for integration by parts.
bool integrate_by_parts::can_apply(iterator st)
	{
	const bool is_integral = (*st->name=="\\int");
	return is_integral;
	}

// Apply integration by parts to the \int node at `it`.
//
// The integrand is the first child of the \int node whose parent_rel is
// p_none. If it is a \sum, every term is handed to handle_term separately and
// the sum is cleaned up afterwards; otherwise the integrand itself is handed
// to handle_term. Any term that was rewritten is cleaned up, and the \int node
// itself is always cleaned up before returning.
//
// Returns l_applied if at least one term was rewritten, l_no_action otherwise.
Algorithm::result_t integrate_by_parts::apply(iterator& it)
	{
	result_t outcome=result_t::l_no_action;

	// Locate the integrand (the child without a parent relation).
	sibling_iterator integrand=tr.begin(it);
	while(integrand!=tr.end(it) && integrand->fl.parent_rel!=str_node::p_none)
		++integrand;

	if(integrand!=tr.end(it)) {
		if(*integrand->name=="\\sum") {
			sibling_iterator term=tr.begin(integrand);
			while(term!=tr.end(integrand)) {
				iterator current(term);
				// Advance first: handle_term may replace the node at `current`.
				++term;
				if(handle_term(it, current)==result_t::l_applied) {
					outcome=result_t::l_applied;
					cleanup_dispatch(kernel, tr, current);
					}
				}
			// Cleanup nested sums
			iterator sum_node(integrand);
			cleanup_dispatch(kernel, tr, sum_node);
			}
		else {
			iterator single(integrand);
			outcome=handle_term(it, single);
			if(outcome==result_t::l_applied)
				cleanup_dispatch(kernel, tr, single);
			}
		}

	cleanup_dispatch(kernel, tr, it);
	return outcome;
	}

// Judging by its name, this is meant to check whether the derivative at the
// second argument is related to the \int node at the first argument
// (presumably: same integration variable — TODO confirm). Currently it is a
// stub that ignores both arguments and always reports true, so every
// Derivative is treated as related to the integral.
bool integrate_by_parts::int_and_derivative_related(iterator /*int_it*/, iterator /*der_it*/) const
	{
	const bool related=true;
	return related;
	}

// Decide whether the derivative at `der_it` acts on the expression stored in
// away_from.
//
// The derivative's argument is taken to be its first non-index child. It is
// compared against the top of away_from with Ex_comparator::equal_subtree.
// A result of subtree_match, match_index_less or match_index_greater counts
// as "acting on" (the latter two presumably mean a match up to index
// naming/ordering — confirm against Ex_comparator).
//
// Returns false if the derivative has no non-index child at all.
bool integrate_by_parts::derivative_acting_on_arg(iterator der_it) const
	{
	sibling_iterator arg=tr.begin(der_it);
	// Skip the indices; stop at the end so that a derivative carrying only
	// indices does not make us walk past its last child.
	while(arg!=tr.end(der_it) && arg->is_index())
		++arg;

	if(arg==tr.end(der_it))
		return false;

	Ex_comparator comp(kernel.properties);
	auto top=away_from.begin();
	return is_in( comp.equal_subtree(top, arg), {
		Ex_comparator::match_t::subtree_match,
		Ex_comparator::match_t::match_index_less,
		Ex_comparator::match_t::match_index_greater
		} );
	}

// Split a multiple derivative into a single outer derivative acting on an
// inner derivative of the same name.
//
// If the derivative at `der_it` has two or more direct indices, its first
// child stays on the original node. A new node with the same name is wrapped
// around the argument (the first non-index child after it), and all remaining
// indices are moved into that new node in their original order. Schematically:
//     D_{m}_{n}_{p}{A}  ->  D_{m}{ D_{n}_{p}{A} }
// Derivatives with zero or one direct index are left untouched. The first
// parameter is unused.
//
// Assumes the indices precede the argument (NOTE(review): confirm for all
// Derivative notations).
// Throws ConsistencyException if no non-index argument is found.
void integrate_by_parts::split_off_single_derivative(iterator, iterator der_it)
	{
	auto ni=number_of_direct_indices(der_it);
	if(ni==0 || ni==1) return;

	// The first child (outer index) stays where it is.
	sibling_iterator sib=tr.begin(der_it);
	++sib;
	sibling_iterator arg=sib;
	while(arg!=tr.end(der_it) && arg->is_index())
		++arg;
	if(arg==tr.end(der_it))
		throw ConsistencyException("Derivative without argument encountered");

	auto inner=tr.wrap(arg, str_node(der_it->name));

	// Right after wrapping, the argument is the only child of `inner`. Take a
	// fresh iterator to it: `arg` still caches its former parent, which
	// move_before would consult when the target has no previous sibling.
	sibling_iterator inner_arg=tr.begin(inner);

	// Move the remaining indices into the new node, each directly before the
	// argument, so that their relative order is preserved (inserting before
	// the node's first child instead would reverse them).
	while(sib!=inner) {
		auto nxt=sib;
		++nxt;
		tr.move_before(inner_arg, sib);
		sib=nxt;
		}
	}

// Try to integrate a single term of the integrand by parts.
//
// `int_it` points to the enclosing \int node, `it` to the term.
// - "\ldots" terms are ignored.
// - If the term is itself a Derivative, it is a total derivative. Its
//   multiplier is set to zero; the boundary term is dropped.
// - Otherwise the term is wrapped in a \prod (if it is not one already). The
//   factors are scanned for the first Derivative whose argument matches
//   away_from (see derivative_acting_on_arg).
//   - If that derivative is the first or last factor, it is moved onto the
//     (product of the) other factors, producing one term.
//   - If it sits in the middle, the term is replaced by a \sum of two terms:
//     the derivative acting on the factors before it, and the derivative
//     acting on the factors after it.
//   Every generated term gets an extra factor -1 in its multiplier.
//
// Returns l_applied if the term was rewritten (`it` may then point to a new
// node). Otherwise returns l_no_action, after removing the temporary \prod
// wrapper again.
// Throws ConsistencyException for a derivative without argument.
Algorithm::result_t integrate_by_parts::handle_term(iterator int_it, iterator& it)
	{
	if(*it->name=="\\ldots") return result_t::l_no_action;

	const Derivative *dtop=kernel.properties.get<Derivative>(it);
	if(dtop) {
		if(int_and_derivative_related(int_it, it)) {
			// Total derivative: integrates to a boundary term, which is dropped.
			zero(it->multiplier);
			return result_t::l_applied;
			}
		}

	prod_wrap_single_term(it);

	assert(*it->name=="\\prod");
	auto fac=tr.begin(it);
	int pos=0;
	while(fac!=tr.end(it)) {
		const Derivative *der=kernel.properties.get<Derivative>(fac);
		if(der) {
			// If this is a multiple partial derivative, we split off the
			// outermost derivative and then look at the remaining argument.
			split_off_single_derivative(int_it, fac);
			if(int_and_derivative_related(int_it, fac) && derivative_acting_on_arg(fac) ) {
				// FIXME: this does not yet take anti-commutativity of the
				// derivative itself into account.

				const bool is_first = (fac==tr.begin(it));
				const bool is_last  = (std::next(fac)==tr.end(it));

				if(is_first && is_last) {
					// The derivative is the only factor of the product, so the
					// whole term is a total derivative. Handling it here also
					// avoids forming std::next() of an end iterator below.
					zero(it->multiplier);
					return result_t::l_applied;
					}

				if(is_first || is_last) {
					// Derivative is first or last factor in product; generate one term only.
					sibling_iterator from, to;
					if(is_first) {
						from=fac;
						++from;
						to=tr.end(it);
						}
					else {
						from=tr.begin(it);
						to=fac;
						}
					// Collect several remaining factors into a single \prod.
					if(std::next(from)!=to)
						from = tr.wrap(from, to, str_node("\\prod"));

					// Find the argument of the derivative. Needs to handle derivatives with
					// a single index, and derivatives with no index at all.
					auto der_arg = tr.begin(fac);
					if(der_arg==tr.end(fac))
						throw ConsistencyException("integrate_by_parts: Derivative without argument encountered.");

					if(der_arg->is_index()) {
						++der_arg;
						// This _has_ to be the argument because we have peeled off a single derivative.
						if(der_arg==tr.end(fac))
							throw ConsistencyException("integrate_by_parts: Derivative without argument encountered.");
						}

					// Exchange the derivative's argument with the other factors, then
					// exchange the derivative node with that argument. The derivative
					// now occupies the other factors' slot and acts on them; the old
					// argument occupies the derivative's slot.
					tr.swap(der_arg, from);
					tr.swap(fac, der_arg);
					multiply(it->multiplier, -1);
					iterator tmp(fac);
					cleanup_dispatch(kernel, tr, tmp);
					return result_t::l_applied;
					}
				else {
					// Derivative sits in the middle: two terms needed. Both start
					// as copies of the full term.
					Ex sum("\\sum");
					auto t1prod = sum.append_child(sum.begin(), it);
					auto t2prod = sum.append_child(sum.begin(), it);

					// First term: derivative acts on the factors before it.
					sibling_iterator from=sum.begin(t1prod);
					sibling_iterator to  =from;
					to+=pos;
					if(std::next(from)!=to)
						from = tr.wrap(from, to, str_node("\\prod"));

					auto der_arg = tr.begin(to);
					while(der_arg!=tr.end(to) && der_arg->is_index())
						++der_arg;

					if(der_arg==tr.end(to))
						throw ConsistencyException("integrate_by_parts: Derivative without argument encountered.");

					tr.swap(der_arg, from);
					tr.swap(to, der_arg);
					multiply(t1prod->multiplier, -1);
					iterator tmp(to);
					cleanup_dispatch(kernel, tr, tmp);

					// Second term: derivative acts on the factors after it.
					from=sum.begin(t2prod);
					from+=pos;
					auto der=from;
					++from;
					to  =sum.end(t2prod);
					if(std::next(from)!=to)
						from = tr.wrap(from, to, str_node("\\prod"));

					der_arg = tr.begin(der);
					while(der_arg!=tr.end(der) && der_arg->is_index())
						++der_arg;

					if(der_arg==tr.end(der))
						throw ConsistencyException("integrate_by_parts: Derivative without argument encountered.");

					tr.swap(der_arg, from);
					tr.swap(der, der_arg);
					multiply(t2prod->multiplier, -1);
					tmp=der;
					cleanup_dispatch(kernel, tr, tmp);

					// Replace the original with the sum.
					it=tr.move_ontop(it, sum.begin());

					return result_t::l_applied;
					}
				}
			else {
				// Undo the split-off. Re-seat `fac` on whatever node cleanup left
				// at this position, so that the increment below stays valid.
				iterator tmp=fac;
				cleanup_dispatch(kernel, tr, tmp);
				fac=tmp;
				}
			}
		++fac;
		++pos;
		}

	// Nothing could be integrated by parts: remove the temporary \prod
	// wrapper added above, so that no structural change is left behind.
	prod_unwrap_single_term(it);
	return result_t::l_no_action;
	}

// Stub: ignores all of its arguments and always returns Ex("").
// NOTE(review): no caller of this member is visible in this file; confirm
// whether it is still needed.
Ex integrate_by_parts::wrap(iterator /*it*/, sibling_iterator /*from*/, sibling_iterator /*to*/) const
	{
	Ex placeholder("");
	return placeholder;
	}