File: FloatingPointArith.java

package info (click to toggle)
cvc5 1.3.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 87,260 kB
  • sloc: cpp: 383,850; java: 12,207; python: 12,090; sh: 5,679; ansic: 4,729; lisp: 763; perl: 208; makefile: 38
file content (124 lines) | stat: -rw-r--r-- 5,703 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/******************************************************************************
 * Top contributors (to current version):
 *   Aina Niemetz, Mudathir Mohamed, Andres Noetzli
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * An example of solving floating-point problems with cvc5's Java API
 *
 * This example shows to create floating-point types, variables and expressions,
 * and how to create rounding mode constants by solving toy problems. The
 * example also shows making special values (such as NaN and +oo) and converting
 * an IEEE 754-2008 bit-vector to a floating-point number.
 */

import static io.github.cvc5.Kind.*;

import io.github.cvc5.*;

public class FloatingPointArith
{
  public static void main(String[] args) throws CVC5ApiException
  {
    TermManager tm = new TermManager();
    Solver solver = new Solver(tm);
    {
      solver.setOption("incremental", "true");
      solver.setOption("produce-models", "true");

      // Make single precision floating-point variables
      Sort fpt32 = tm.mkFloatingPointSort(8, 24);
      Term a = tm.mkConst(fpt32, "a");
      Term b = tm.mkConst(fpt32, "b");
      Term c = tm.mkConst(fpt32, "c");
      Term d = tm.mkConst(fpt32, "d");
      Term e = tm.mkConst(fpt32, "e");
      // Rounding mode
      Term rm = tm.mkRoundingMode(RoundingMode.ROUND_NEAREST_TIES_TO_EVEN);

      System.out.println("Show that fused multiplication and addition `(fp.fma RM a b c)`");
      System.out.println("is different from `(fp.add RM (fp.mul a b) c)`:");
      solver.push(1);
      Term fma = tm.mkTerm(Kind.FLOATINGPOINT_FMA, new Term[] {rm, a, b, c});
      Term mul = tm.mkTerm(Kind.FLOATINGPOINT_MULT, rm, a, b);
      Term add = tm.mkTerm(Kind.FLOATINGPOINT_ADD, rm, mul, c);
      solver.assertFormula(tm.mkTerm(Kind.DISTINCT, fma, add));
      Result r = solver.checkSat(); // result is sat
      System.out.println("Expect sat: " + r);
      System.out.println("Value of `a`: " + solver.getValue(a));
      System.out.println("Value of `b`: " + solver.getValue(b));
      System.out.println("Value of `c`: " + solver.getValue(c));
      System.out.println("Value of `(fp.fma RNE a b c)`: " + solver.getValue(fma));
      System.out.println("Value of `(fp.add RNE (fp.mul a b) c)`: " + solver.getValue(add));
      System.out.println();
      solver.pop(1);

      System.out.println("Show that floating-point addition is not associative:");
      System.out.println("(a + (b + c)) != ((a + b) + c)");
      Term lhs =
          tm.mkTerm(Kind.FLOATINGPOINT_ADD, rm, a, tm.mkTerm(Kind.FLOATINGPOINT_ADD, rm, b, c));
      Term rhs =
          tm.mkTerm(Kind.FLOATINGPOINT_ADD, rm, tm.mkTerm(Kind.FLOATINGPOINT_ADD, rm, a, b), c);
      solver.assertFormula(tm.mkTerm(Kind.NOT, tm.mkTerm(Kind.EQUAL, a, b)));

      r = solver.checkSat(); // result is sat
      assert r.isSat();

      System.out.println("Value of `a`: " + solver.getValue(a));
      System.out.println("Value of `b`: " + solver.getValue(b));
      System.out.println("Value of `c`: " + solver.getValue(c));

      System.out.println("Now, restrict `a` to be either NaN or positive infinity:");
      Term nan = tm.mkFloatingPointNaN(8, 24);
      Term inf = tm.mkFloatingPointPosInf(8, 24);
      solver.assertFormula(
          tm.mkTerm(Kind.OR, tm.mkTerm(Kind.EQUAL, a, inf), tm.mkTerm(Kind.EQUAL, a, nan)));

      r = solver.checkSat(); // result is sat
      assert r.isSat();

      System.out.println("Value of `a`: " + solver.getValue(a));
      System.out.println("Value of `b`: " + solver.getValue(b));
      System.out.println("Value of `c`: " + solver.getValue(c));

      System.out.println("Now, try to find a (normal) floating-point number that rounds");
      System.out.println("to different integer values for different rounding modes:");
      Term rtp = tm.mkRoundingMode(RoundingMode.ROUND_TOWARD_POSITIVE);
      Term rtn = tm.mkRoundingMode(RoundingMode.ROUND_TOWARD_NEGATIVE);
      Op op = tm.mkOp(Kind.FLOATINGPOINT_TO_UBV, 16); // (_ fp.to_ubv 16)
      lhs = tm.mkTerm(op, rtp, d);
      rhs = tm.mkTerm(op, rtn, d);
      solver.assertFormula(tm.mkTerm(Kind.FLOATINGPOINT_IS_NORMAL, d));
      solver.assertFormula(tm.mkTerm(Kind.NOT, tm.mkTerm(Kind.EQUAL, lhs, rhs)));

      r = solver.checkSat(); // result is sat
      assert r.isSat();

      System.out.println("Get value of `d` as floating-point, bit-vector and real:");
      Term val = solver.getValue(d);
      System.out.println("Value of `d`: " + val);
      System.out.println("Value of `((_ fp.to_ubv 16) RTP d)`: " + solver.getValue(lhs));
      System.out.println("Value of `((_ fp.to_ubv 16) RTN d)`: " + solver.getValue(rhs));
      System.out.println("Value of `(fp.to_real d)`: "
          + solver.getValue(tm.mkTerm(Kind.FLOATINGPOINT_TO_REAL, val)));

      System.out.println("Finally, try to find a floating-point number between positive");
      System.out.println("zero and the smallest positive floating-point number:");
      Term zero = tm.mkFloatingPointPosZero(8, 24);
      Term smallest = tm.mkFloatingPoint(8, 24, tm.mkBitVector(32, 0b001));
      solver.assertFormula(tm.mkTerm(Kind.AND,
          tm.mkTerm(Kind.FLOATINGPOINT_LT, zero, e),
          tm.mkTerm(Kind.FLOATINGPOINT_LT, e, smallest)));

      r = solver.checkSat(); // result is unsat
      assert !r.isSat();
    }
    Context.deletePointers();
  }
}