1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
import json
from itertools import cycle
from xml.etree import ElementTree as ET
from .api import RLAPI
class ZeroAD:
    """Client-side handle for a game session reached through ``RLAPI``.

    The most recent ``GameState`` is kept in ``current_state`` and fetched
    templates are kept in ``cache`` (template name -> ``EntityTemplate``).
    """

    def __init__(self, uri="http://localhost:6000"):
        """Create the ``RLAPI`` client for *uri*.

        ``current_state`` stays ``None`` until ``reset`` or ``step`` is called.
        """
        self.api = RLAPI(uri)
        self.current_state = None
        self.cache = {}
        # Player id paired with actions in ``step`` when no ``player`` is given.
        self.player_id = 1

    def step(self, actions=None, player=None):
        """Send *actions* through the API and return the resulting state.

        Args:
            actions: Iterable of actions; ``None`` entries are skipped.
                Defaults to no actions.
            player: Optional iterable of player ids. It is repeated cyclically
                and paired with *actions* in order. When omitted, every action
                is paired with ``self.player_id``.

        Returns:
            The new ``GameState``, which is also stored in ``current_state``.
        """
        if actions is None:
            actions = []
        issuers = cycle([self.player_id] if player is None else player)
        # Kept lazy (a generator) exactly as before; the API consumes it.
        cmds = (
            (issuer, action)
            for issuer, action in zip(issuers, actions, strict=False)
            if action is not None
        )
        state_json = self.api.step(cmds)
        self.current_state = GameState(json.loads(state_json), self)
        return self.current_state

    def reset(self, config="", save_replay=False, player_id=1):
        """Reset through the API and return the initial ``GameState``.

        *config*, *player_id* and *save_replay* are forwarded to
        ``RLAPI.reset``; the JSON reply becomes the new ``current_state``.
        """
        state_json = self.api.reset(config, player_id, save_replay)
        self.current_state = GameState(json.loads(state_json), self)
        return self.current_state

    def evaluate(self, code):
        """Forward *code* to ``RLAPI.evaluate`` and return its result."""
        return self.api.evaluate(code)

    def get_template(self, name):
        """Return the first ``(name, EntityTemplate)`` pair fetched for *name*.

        Raises:
            IndexError: If the API returns no template for *name*.
        """
        return self.get_templates([name])[0]

    def get_templates(self, names):
        """Fetch templates for *names* as a list of ``(name, EntityTemplate)`` pairs."""
        templates = self.api.get_templates(names)
        return [(name, EntityTemplate(content)) for (name, content) in templates]

    def update_templates(self, types=None):
        """Rebuild ``cache`` from the current units' types plus *types*.

        Args:
            types: Optional iterable of extra template names to fetch.

        Returns:
            The list of ``(name, EntityTemplate)`` pairs that now make up
            ``cache``.
        """
        extra = [] if types is None else list(types)
        state = self.current_state
        # Before the first reset/step there is no state, hence no units to
        # scan; only the explicitly requested types are fetched in that case.
        in_play = [] if state is None else [unit.type() for unit in state.units()]
        # Ordered de-duplication so each template is requested only once.
        names = list(dict.fromkeys(in_play + extra))
        template_pairs = self.get_templates(names)
        # Replace the cache only after the fetch succeeded.
        self.cache = dict(template_pairs)
        return template_pairs
class GameState:
    """Snapshot of the game built from the decoded JSON state.

    ``data`` is the decoded mapping; it is read for its ``"mapSize"`` and
    ``"entities"`` keys. ``game`` is handed on to every ``Entity`` created
    from this snapshot.
    """

    def __init__(self, data, game):
        self.data = data
        self.game = game
        self.mapSize = data["mapSize"]

    def units(self, owner=None, entity_type=None):
        """Return ``Entity`` wrappers for the records matching both filters.

        Args:
            owner: When given, keep only records whose ``"owner"`` equals it.
            entity_type: When given, keep only records whose ``"template"``
                value contains it.
        """
        selected = []
        for record in self.data["entities"].values():
            if owner is not None and record["owner"] != owner:
                continue
            if entity_type is not None and entity_type not in record["template"]:
                continue
            selected.append(Entity(record, self.game))
        return selected

    def unit(self, entity_id):
        """Return the ``Entity`` stored under *entity_id*, or ``None`` if absent.

        The id is converted to ``str`` before the lookup.
        """
        key = str(entity_id)
        records = self.data["entities"]
        if key not in records:
            return None
        return Entity(records[key], self.game)
class Entity:
    """Accessor wrapper around one entity record of a game state.

    ``data`` is the raw record (read for its ``"template"``, ``"id"``,
    ``"owner"``, ``"hitpoints"`` and ``"position"`` keys); ``game`` provides
    the template ``cache`` and ``update_templates``.
    """

    def __init__(self, data, game):
        self.data = data
        self.game = game
        # Stays None when the template is not cached yet; get_template()
        # fetches it on demand.
        self.template = self.game.cache.get(self.type(), None)

    def type(self):
        """Return the record's template name."""
        return self.data["template"]

    def id(self):
        """Return the record's id."""
        return self.data["id"]

    def owner(self):
        """Return the record's owner."""
        return self.data["owner"]

    def max_health(self):
        """Return the template's ``Health/Max`` value as a float."""
        raw_max = self.get_template().get("Health/Max")
        return float(raw_max)

    def health(self, ratio=False):
        """Return the current hitpoints, or hitpoints / max health if *ratio*."""
        hitpoints = self.data["hitpoints"]
        if not ratio:
            return hitpoints
        return hitpoints / self.max_health()

    def position(self):
        """Return the record's position."""
        return self.data["position"]

    def get_template(self):
        """Return this entity's template, refreshing the game cache if needed."""
        if self.template is not None:
            return self.template
        kind = self.type()
        self.game.update_templates([kind])
        self.template = self.game.cache[kind]
        return self.template
class EntityTemplate:
    """Parsed XML template with path-based read and write access.

    Paths use ``ElementTree.find`` syntax and are resolved relative to a
    synthetic ``<Entity>`` root wrapped around the given XML fragment.
    """

    def __init__(self, xml):
        """Parse the XML fragment *xml*, wrapped in an ``<Entity>`` root element."""
        self.data = ET.fromstring(f"<Entity>{xml}</Entity>")

    def get(self, path):
        """Return the text of the first node matching *path*, or ``None``."""
        node = self.data.find(path)
        return node.text if node is not None else None

    def set(self, path, value):
        """Set the text of the first node matching *path* to ``str(value)``.

        Returns:
            ``True`` if a node matched and was updated, ``False`` otherwise.
        """
        node = self.data.find(path)
        # Compare against None explicitly: an Element without children is
        # falsy, so a truthiness test would skip exactly the leaf nodes that
        # hold text.
        if node is None:
            return False
        node.text = str(value)
        return True

    def __str__(self):
        """Serialize the template, including the ``<Entity>`` root, as UTF-8 text."""
        return ET.tostring(self.data).decode("utf-8")
|