1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
|
/*
* enforce.cpp: deferred range checks.
*
* This file implements spigot_enforce(), which takes two spigot
* inputs and returns the first of them effectively unchanged. The
* point of it is that it also continually checks which side of the
* second input (the bound) the first one falls on. If it ever finds
* out that it is on the _wrong_ side, it throws an exception which
* the caller provided in advance (with an appropriate error message)
* in case it's needed.
*/
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include "spigot.h"
#include "funcs.h"
#include "error.h"
#define STRING(a,b) b
static const char *const relation_strings[] = { ENFORCE_RELATIONS(STRING) };
#undef STRING
/*
 * Enforcer passes the value of a spigot x through unchanged, but every
 * time it produces a new interval for x it also brackets 'bound' to at
 * least the same precision. If the two brackets show that x is
 * definitely on the wrong side of bound for the requested relation, the
 * caller-supplied error 'err' is thrown.
 */
class Enforcer : public BinaryIntervalSource {
    Spigot *x_orig;       // private copy of x, used by clone() and is_rational()
    EnforceRelation rel;  // required relation: x rel bound
    Spigot *bound_orig;   // private copy of bound, used by clone()
    spigot_error err;     // thrown as soon as the relation is seen to fail
    BracketingGenerator bg_x, bg_bound;

  public:
    /*
     * x and bound are handed to the two bracketing generators
     * (presumably taking ownership — the destructor below only deletes
     * the clones kept in x_orig and bound_orig).
     */
    Enforcer(Spigot *x, EnforceRelation arel, Spigot *bound, spigot_error aerr)
        : x_orig(x->clone()), rel(arel), bound_orig(bound->clone()),
          err(aerr), bg_x(x), bg_bound(bound)
    {
        dprint("hello Enforcer %p %s %p", x, relation_strings[rel], bound);
    }

    virtual ~Enforcer()
    {
        delete x_orig;
        delete bound_orig;
    }

    virtual Enforcer *clone()
    {
        return new Enforcer(x_orig->clone(), rel, bound_orig->clone(), err);
    }

    // The output is numerically x itself, so rationality is x's.
    virtual bool is_rational(bigint *n, bigint *d)
    {
        return x_orig->is_rational(n, d);
    }

    /*
     * Returns x's next bracket (*ret_lo, *ret_hi) / 2^*ret_bits unchanged,
     * after checking it against a bracket for bound.
     *
     * Throws err if the brackets prove the relation cannot hold. While
     * the brackets overlap nothing is decided and no error is raised.
     */
    virtual void gen_bin_interval(bigint *ret_lo, bigint *ret_hi,
                                  unsigned *ret_bits)
    {
        bg_x.get_bracket_shift(ret_lo, ret_hi, ret_bits);
        dprint("got x bracket (%b,%b) / 2^%d", ret_lo, ret_hi, (int)*ret_bits);

        // Ask for a bound bracket whose denominator is at least as large
        // as x's, so x's endpoints can be scaled up to match it exactly.
        bg_bound.set_denominator_lower_bound_shift(*ret_bits);
        bigint cmp_lo, cmp_hi;
        unsigned cmp_bits;
        bg_bound.get_bracket_shift(&cmp_lo, &cmp_hi, &cmp_bits);
        dprint("got bound bracket (%b,%b) / 2^%d", &cmp_lo, &cmp_hi,
               (int)cmp_bits);
        assert(cmp_bits >= *ret_bits);
        unsigned ret_shift = cmp_bits - *ret_bits;

        bool ok;
        if (rel == ENFORCE_GT || rel == ENFORCE_GE) {
            // x must not be below bound: only a definite failure if x's
            // upper end is strictly below bound's lower end.
            ok = (*ret_hi << ret_shift) >= cmp_lo;
        } else /* if (rel == ENFORCE_LT || rel == ENFORCE_LE) */ {
            // x must not be above bound: only a definite failure if x's
            // lower end is strictly above bound's upper end. (Comparing
            // x's upper end with bound's lower end here would wrongly
            // fail whenever the two brackets merely overlap.)
            ok = (*ret_lo << ret_shift) <= cmp_hi;
        }
        if (!ok)
            throw err;
    }
};
/*
 * Returns a spigot numerically equal to x which throws err if x is ever
 * found to violate "x rel bound".
 *
 * Obvious violations are reported immediately:
 *  - exactly, when both x and bound are rational;
 *  - otherwise, if a single cheap bounding of x - bound already shows
 *    that the relation fails.
 * Anything subtler is left to the Enforcer, which keeps checking as
 * more digits of x are requested.
 */
Spigot *spigot_enforce(Spigot *x, EnforceRelation rel, Spigot *bound,
                       spigot_error err)
{
    bigint xn, xd, bn, bd;
    bool violated = false;

    if (x->is_rational(&xn, &xd) && bound->is_rational(&bn, &bd)) {
        // Cross-multiply to compare xn/xd with bn/bd. This relies on
        // is_rational always returning a positive denominator.
        bigint lhs = xn * bd;
        bigint rhs = bn * xd;
        switch (rel) {
          case ENFORCE_GT: violated = lhs <= rhs; break;
          case ENFORCE_GE: violated = lhs < rhs;  break;
          case ENFORCE_LT: violated = lhs >= rhs; break;
          case ENFORCE_LE: violated = lhs > rhs;  break;
          default:         violated = false;      break;
        }
    } else {
        // Bound the difference x - bound once, cheaply.
        StaticGenerator diffgen(spigot_sub(x->clone(), bound->clone()));
        bigint diff_lo, diff_hi;
        bool diff_lo_open, diff_hi_open;
        diffgen.iterate_to_bounds(&diff_lo, &diff_hi,
                                  &diff_lo_open, &diff_hi_open,
                                  0, NULL, true);
        // A closed endpoint of exactly 0 still permits equality, which
        // is acceptable only for the non-strict relations.
        switch (rel) {
          case ENFORCE_GT:
            violated = diff_hi <= 0;
            break;
          case ENFORCE_GE:
            violated = diff_hi <= 0 && !(diff_hi == 0 && !diff_hi_open);
            break;
          case ENFORCE_LT:
            violated = diff_lo >= 0;
            break;
          case ENFORCE_LE:
            violated = diff_lo >= 0 && !(diff_lo == 0 && !diff_lo_open);
            break;
          default:
            violated = false;
            break;
        }
    }

    if (violated)
        throw err;

    // Nothing conclusive yet: defer the check to the Enforcer.
    return new Enforcer(x, rel, bound, err);
}
|