1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
|
from pomegranate import *
def test_json():
    """Round-trip a baked network's graph through ``FactorGraph.from_json``.

    Builds a two-node Bayesian network (a discrete prior feeding a
    conditional probability table), serializes its ``graph`` attribute
    with ``to_json``, rebuilds it with ``FactorGraph.from_json``, and
    checks that the JSON text and the edge endpoints are unchanged.
    """
    prior = DiscreteDistribution({"A": 0.1, "B": 0.9})
    cpt_rows = [
        ["A", "A", 0.1],
        ["A", "B", 0.9],
        ["B", "A", 0.7],
        ["B", "B", 0.3],
    ]
    conditional = ConditionalProbabilityTable(cpt_rows, [prior])

    network = BayesianNetwork("test")
    parent = Node(prior, name="d1")
    child = Node(conditional, name="d2")
    network.add_states(parent, child)
    network.add_edge(parent, child)
    network.bake()

    graph = network.graph
    serialized = graph.to_json()
    restored = FactorGraph.from_json(serialized)

    # Re-serializing the restored graph must reproduce the same JSON text.
    assert graph.to_json() == restored.to_json()

    # Same number of edges, and each edge keeps both of its endpoints.
    assert len(graph.edges) == len(restored.edges)
    for expected_edge, restored_edge in zip(graph.edges, restored.edges):
        assert expected_edge[0] == restored_edge[0]
        assert expected_edge[1] == restored_edge[1]
def test_robust_json():
    """Round-trip a baked network's graph through the generic ``from_json``.

    Same scenario as the ``FactorGraph.from_json`` round-trip, except the
    JSON is handed to the module-level ``from_json`` entry point rather
    than to a specific class. The rebuilt object must serialize to the
    same JSON and expose the same edges.
    """
    prior = DiscreteDistribution({"A": 0.1, "B": 0.9})
    cpt_rows = [
        ["A", "A", 0.1],
        ["A", "B", 0.9],
        ["B", "A", 0.7],
        ["B", "B", 0.3],
    ]
    conditional = ConditionalProbabilityTable(cpt_rows, [prior])

    network = BayesianNetwork("test")
    parent = Node(prior, name="d1")
    child = Node(conditional, name="d2")
    network.add_states(parent, child)
    network.add_edge(parent, child)
    network.bake()

    graph = network.graph
    serialized = graph.to_json()
    restored = from_json(serialized)

    # Re-serializing the restored graph must reproduce the same JSON text.
    assert graph.to_json() == restored.to_json()

    # Same number of edges, and each edge keeps both of its endpoints.
    assert len(graph.edges) == len(restored.edges)
    for expected_edge, restored_edge in zip(graph.edges, restored.edges):
        assert expected_edge[0] == restored_edge[0]
        assert expected_edge[1] == restored_edge[1]
|