# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import copy
import json
import textwrap
from typing import List, Dict, Optional, Mapping, Any, Iterable, Union, Tuple
from collections import OrderedDict
from functools import partial
import numpy as np
from numpy.random import RandomState
from textworld import g_rng
from textworld.utils import encode_seeds
from textworld.generator.data import KnowledgeBase
from textworld.generator.text_grammar import Grammar, GrammarOptions
from textworld.generator.world import World
from textworld.logic import Action, Proposition, State
from textworld.generator.graph_networks import DIRECTIONS
from textworld.generator.chaining import ChainingOptions
from textworld.generator.dependency_tree import DependencyTree
from textworld.generator.dependency_tree import DependencyTreeElement
class UnderspecifiedEventError(NameError):
    """ Raised when an `Event` is given neither actions nor conditions. """

    def __init__(self):
        super().__init__(
            "Either the actions or the conditions is needed to create an event."
        )
class UnderspecifiedQuestError(NameError):
    """ Raised when a `Quest` has neither winning nor failing events. """

    def __init__(self):
        super().__init__(
            "At least one winning or failing event is needed to create a quest."
        )
def gen_commands_from_actions(actions: Iterable[Action], kb: Optional[KnowledgeBase] = None) -> List[str]:
    """ Turn actions into text commands using the knowledge base templates.

    Args:
        actions: Actions to convert. `None` entries are allowed.
        kb: Knowledge base providing `inform7_commands` templates and
            `rules`. Defaults to `KnowledgeBase.default()`.

    Returns:
        One command per action; a `None` action yields the string "None".
    """
    kb = kb or KnowledgeBase.default()

    def _to_command(action):
        if action is None:
            return "None"

        template = kb.inform7_commands[action.name]
        # Map each rule placeholder name to the name of the variable it matched.
        placeholders = kb.rules[action.name].match(action)
        return template.format(**{ph.name: var.name for ph, var in placeholders.items()})

    return [_to_command(action) for action in actions]
class Event:
    """
    Event happening in TextWorld.

    An event gets triggered when its set of conditions become all satisfied.
    """

    def __init__(self, actions: Iterable[Action] = (),
                 conditions: Iterable[Proposition] = (),
                 commands: Iterable[str] = ()) -> None:
        """
        Args:
            actions: The actions to be performed to trigger this event.
                     If an empty list, then `conditions` must be provided.
            conditions: Set of propositions which need to
                        be all true in order for this event
                        to get triggered.
            commands: Human readable version of the actions.

        Raises:
            UnderspecifiedEventError: If both `actions` and `conditions`
                                      are empty.
        """
        self.actions = actions
        self.commands = commands
        #: :py:class:`textworld.logic.Action`: Action that can only be applied
        #: when all conditions are satisfied.
        self.condition = self.set_conditions(conditions)

    @property
    def actions(self) -> Tuple[Action, ...]:
        """ Actions to perform to trigger this event. """
        return self._actions

    @actions.setter
    def actions(self, actions: Iterable[Action]) -> None:
        self._actions = tuple(actions)

    @property
    def commands(self) -> Tuple[str, ...]:
        """ Human readable version of the actions. """
        return self._commands

    @commands.setter
    def commands(self, commands: Iterable[str]) -> None:
        self._commands = tuple(commands)

    def is_triggering(self, state: State) -> bool:
        """ Check if this event would be triggered in a given state. """
        return state.is_applicable(self.condition)

    def set_conditions(self, conditions: Iterable[Proposition]) -> Action:
        """
        Set the triggering conditions for this event.

        Args:
            conditions: Set of propositions which need to
                        be all true in order for this event
                        to get triggered. When empty, the postconditions
                        of the last action are used instead.

        Returns:
            Action that can only be applied when all conditions are satisfied.

        Raises:
            UnderspecifiedEventError: If `conditions` is empty and this
                                      event has no actions.
        """
        # Materialize first: `conditions` is read several times below. A
        # generator is always truthy (so `not conditions` would never detect
        # an empty one) and would be exhausted after its first traversal.
        conditions = list(conditions)
        if not conditions:
            if len(self.actions) == 0:
                raise UnderspecifiedEventError()

            # The default winning conditions are the postconditions of the
            # last action in the quest.
            conditions = list(self.actions[-1].postconditions)

        variables = sorted({v for c in conditions for v in c.arguments})
        event = Proposition("event", arguments=variables)
        self.condition = Action("trigger", preconditions=conditions,
                                postconditions=conditions + [event])
        return self.condition

    def __hash__(self) -> int:
        return hash((self.actions, self.commands, self.condition))

    def __eq__(self, other: Any) -> bool:
        return (isinstance(other, Event)
                and self.actions == other.actions
                and self.commands == other.commands
                and self.condition == other.condition)

    @classmethod
    def deserialize(cls, data: Mapping) -> "Event":
        """ Creates an `Event` from serialized data.

        Args:
            data: Serialized data with the needed information to build a
                  `Event` object.
        """
        actions = [Action.deserialize(d) for d in data["actions"]]
        condition = Action.deserialize(data["condition"])
        event = cls(actions, condition.preconditions, data["commands"])
        return event

    def serialize(self) -> Mapping:
        """ Serialize this event.

        Results:
            `Event`'s data serialized to be JSON compatible.
        """
        data = {}
        data["commands"] = self.commands
        data["actions"] = [action.serialize() for action in self.actions]
        data["condition"] = self.condition.serialize()
        return data

    def copy(self) -> "Event":
        """ Copy this event (through a serialization round-trip). """
        return self.deserialize(self.serialize())
class Quest:
    """ Quest representation in TextWorld.

    A quest is defined by a mutually exclusive set of winning events and
    a mutually exclusive set of failing events.
    """

    def __init__(self,
                 win_events: Iterable[Event] = (),
                 fail_events: Iterable[Event] = (),
                 reward: Optional[int] = None,
                 desc: Optional[str] = None,
                 commands: Iterable[str] = (),
                 optional: bool = False,
                 repeatable: bool = False) -> None:
        r"""
        Args:
            win_events: Mutually exclusive set of winning events. That is,
                        only one such event needs to be triggered in order
                        to complete this quest.
            fail_events: Mutually exclusive set of failing events. That is,
                         only one such event needs to be triggered in order
                         to fail this quest.
            reward: Reward given for completing this quest. By default,
                    reward is set to 1 if there is at least one winning event
                    otherwise it is set to 0.
            desc: A text description of the quest.
            commands: List of text commands leading to this quest completion.
            optional: If True, this quest is optional to finish the game.
            repeatable: If True, this quest can be completed more than once.
                        Only optional quests can be repeatable.

        Raises:
            UnderspecifiedQuestError: If there are neither winning nor
                                      failing events.
        """
        self.win_events = tuple(win_events)
        self.fail_events = tuple(fail_events)
        #: str: A text description of the quest.
        self.desc = desc
        self.commands = tuple(commands)
        #: bool: Whether this quest is optional or not to finish the game.
        self.optional = optional
        #: bool: Whether this quest can be completed more than once.
        self.repeatable = repeatable
        if self.repeatable:
            assert self.optional, "Only optional quests can be repeatable."

        #: int: Reward given for completing this quest.
        # Unless explicitly provided, reward is set to 1 if there is at least
        # one winning event otherwise it is set to 0. Use the stored tuple:
        # `win_events` may be a one-shot iterator already consumed above.
        self.reward = int(len(self.win_events) > 0) if reward is None else reward

        if len(self.win_events) == 0 and len(self.fail_events) == 0:
            raise UnderspecifiedQuestError()

    @property
    def win_events(self) -> Tuple[Event, ...]:
        """ Mutually exclusive set of winning events. That is,
        only one such event needs to be triggered in order
        to complete this quest.
        """
        return self._win_events

    @win_events.setter
    def win_events(self, events: Iterable[Event]) -> None:
        self._win_events = tuple(events)

    @property
    def fail_events(self) -> Tuple[Event, ...]:
        """ Mutually exclusive set of failing events. That is,
        only one such event needs to be triggered in order
        to fail this quest.
        """
        return self._fail_events

    @fail_events.setter
    def fail_events(self, events: Iterable[Event]) -> None:
        self._fail_events = tuple(events)

    @property
    def commands(self) -> Tuple[str, ...]:
        """ List of text commands leading to this quest completion. """
        return self._commands

    @commands.setter
    def commands(self, commands: Iterable[str]) -> None:
        self._commands = tuple(commands)

    def is_winning(self, state: State) -> bool:
        """ Check if this quest is winning in that particular state. """
        return any(event.is_triggering(state) for event in self.win_events)

    def is_failing(self, state: State) -> bool:
        """ Check if this quest is failing in that particular state. """
        return any(event.is_triggering(state) for event in self.fail_events)

    def __hash__(self) -> int:
        return hash((self.win_events, self.fail_events, self.reward,
                     self.desc, self.commands, self.optional, self.repeatable))

    def __eq__(self, other: Any) -> bool:
        return (isinstance(other, Quest)
                and self.win_events == other.win_events
                and self.fail_events == other.fail_events
                and self.reward == other.reward
                and self.desc == other.desc
                and self.commands == other.commands
                and self.optional == other.optional
                and self.repeatable == other.repeatable)

    @classmethod
    def deserialize(cls, data: Mapping) -> "Quest":
        """ Creates a `Quest` from serialized data.

        Args:
            data: Serialized data with the needed information to build a
                  `Quest` object.
        """
        win_events = [Event.deserialize(d) for d in data["win_events"]]
        fail_events = [Event.deserialize(d) for d in data["fail_events"]]
        commands = data.get("commands", [])
        reward = data["reward"]
        desc = data["desc"]
        optional = data.get("optional", False)
        repeatable = data.get("repeatable", False)
        return cls(win_events, fail_events, reward, desc, commands, optional, repeatable)

    def serialize(self) -> Mapping:
        """ Serialize this quest.

        Returns:
            Quest's data serialized to be JSON compatible.
        """
        data = {}
        data["desc"] = self.desc
        data["reward"] = self.reward
        data["commands"] = self.commands
        data["win_events"] = [event.serialize() for event in self.win_events]
        data["fail_events"] = [event.serialize() for event in self.fail_events]
        data["optional"] = self.optional
        data["repeatable"] = self.repeatable
        return data

    def copy(self) -> "Quest":
        """ Copy this quest (through a serialization round-trip). """
        return self.deserialize(self.serialize())
class EntityInfo:
    """ Additional information about entities in the game. """
    __slots__ = ['id', 'type', 'name', 'noun', 'adj', 'desc', 'room_type', 'definite', 'indefinite', 'synonyms']

    def __init__(self, id: str, type: str) -> None:
        #: str: Unique identifier of this entity.
        self.id = id
        #: str: The type of this entity.
        self.type = type
        #: str: The name that will be displayed in-game to identify this entity.
        self.name = None
        #: str: The noun part of the name, if available.
        self.noun = None
        #: str: The adjective (i.e. descriptive) part of the name, if available.
        self.adj = None
        #: str: The definite article to use for this entity.
        self.definite = None
        #: str: The indefinite article to use for this entity.
        self.indefinite = None
        #: List[str]: Alternative names that can be used to refer to this entity.
        self.synonyms = None
        #: str: Text description displayed when examining this entity in the game.
        self.desc = None
        #: str: Type of the room this entity belongs to. It used to influence
        #: its `name` during text generation.
        self.room_type = None

    def __eq__(self, other: Any) -> bool:
        return (isinstance(other, EntityInfo)
                and all(getattr(self, slot) == getattr(other, slot)
                        for slot in self.__slots__))

    def __hash__(self) -> int:
        # Some slots may hold lists (e.g. `synonyms`, or any slot restored by
        # `deserialize` from JSON data). Lists are unhashable, so hash them as
        # tuples; objects equal under `__eq__` still get equal hashes.
        values = (getattr(self, slot) for slot in self.__slots__)
        return hash(tuple(tuple(v) if isinstance(v, list) else v for v in values))

    def __str__(self) -> str:
        return "Info({}: {} | {})".format(self.name, self.adj, self.noun)

    @classmethod
    def deserialize(cls, data: Mapping) -> "EntityInfo":
        """ Creates a `EntityInfo` from serialized data.

        Args:
            data: Serialized data with the needed information to build a
                  `EntityInfo` object. Missing slots are set to `None`.
        """
        info = cls(data["id"], data["type"])
        for slot in cls.__slots__:
            setattr(info, slot, data.get(slot))

        return info

    def serialize(self) -> Mapping:
        """ Serialize this object.

        Returns:
            EntityInfo's data serialized to be JSON compatible.
        """
        return {slot: getattr(self, slot) for slot in self.__slots__}
class Game:
    """ Game representation in TextWorld.

    A `Game` is defined by a world and it can have quest(s) or not.
    Additionally, a grammar can be provided to control the text generation.
    """

    _SERIAL_VERSION = 1

    def __init__(self, world: World, grammar: Optional[Grammar] = None,
                 quests: Iterable[Quest] = ()) -> None:
        """
        Args:
            world: The world to use for the game.
            grammar: The grammar to control the text generation.
            quests: The quests to be done in the game.
        """
        self.world = world
        self.quests = tuple(quests)
        self.metadata = {}
        self._objective = None
        self._infos = self._build_infos()
        self.kb = world.kb
        self.change_grammar(grammar)

    @property
    def infos(self) -> Dict[str, EntityInfo]:
        """ Information about the entities in the game. """
        return self._infos

    def _build_infos(self) -> Dict[str, EntityInfo]:
        """ Create a default `EntityInfo` for each world entity, keyed by entity id. """
        mapping = OrderedDict()
        for entity in self.world.entities:
            # The mapping is keyed by id, so membership must be tested on the id.
            if entity.id not in mapping:
                mapping[entity.id] = EntityInfo(entity.id, entity.type)

        return mapping

    def copy(self) -> "Game":
        """ Make a shallow copy of this game. """
        game = Game(self.world, None, self.quests)
        game._infos = dict(self.infos)
        game._objective = self._objective
        game.metadata = dict(self.metadata)
        return game

    def change_grammar(self, grammar: Optional[Grammar]) -> None:
        """ Changes the grammar used and regenerate all text.

        Args:
            grammar: The new grammar. When `None`, only the quest commands
                     are regenerated, using the knowledge base templates.
        """
        self.grammar = grammar
        _gen_commands = partial(gen_commands_from_actions, kb=self.kb)
        if self.grammar:
            from textworld.generator.inform7 import Inform7Game
            from textworld.generator.text_generation import generate_text_from_grammar
            inform7 = Inform7Game(self)
            _gen_commands = inform7.gen_commands_from_actions
            generate_text_from_grammar(self, self.grammar)

        for quest in self.quests:
            for event in quest.win_events:
                event.commands = _gen_commands(event.actions)

            if quest.win_events:
                quest.commands = quest.win_events[0].commands

        # Check if we can derive a global winning policy from the quests.
        if self.grammar:
            from textworld.generator.text_generation import describe_event
            policy = GameProgression(self).winning_policy
            if policy:
                mapping = {k: info.name for k, info in self._infos.items()}
                commands = [a.format_command(mapping) for a in policy]
                self.metadata["walkthrough"] = commands
                self.objective = describe_event(Event(policy), self, self.grammar)

    def save(self, filename: str) -> None:
        """ Saves the serialized data of this game to a file. """
        with open(filename, 'w') as f:
            json.dump(self.serialize(), f)

    @classmethod
    def load(cls, filename: str) -> "Game":
        """ Creates `Game` from serialized data saved in a file. """
        with open(filename, 'r') as f:
            return cls.deserialize(json.load(f))

    @classmethod
    def deserialize(cls, data: Mapping) -> "Game":
        """ Creates a `Game` from serialized data.

        Args:
            data: Serialized data with the needed information to build a
                  `Game` object.

        Raises:
            ValueError: If the data was serialized with another format version.
        """
        version = data.get("version", cls._SERIAL_VERSION)
        if version != cls._SERIAL_VERSION:
            msg = "Cannot deserialize a TextWorld version {} game, expected version {}"
            raise ValueError(msg.format(version, cls._SERIAL_VERSION))

        kb = KnowledgeBase.deserialize(data["KB"])
        world = World.deserialize(data["world"], kb=kb)
        game = cls(world)
        # NOTE(review): only the grammar options are restored here;
        # `game.grammar` stays None since `cls(world)` passed no grammar.
        game.grammar_options = GrammarOptions(data["grammar"])
        game.quests = tuple([Quest.deserialize(d) for d in data["quests"]])
        # `infos` is serialized as a list of (id, info) pairs.
        game._infos = {k: EntityInfo.deserialize(v) for k, v in data["infos"]}
        game.metadata = data.get("metadata", {})
        game._objective = data.get("objective", None)
        return game

    def serialize(self) -> Mapping:
        """ Serialize this object.

        Results:
            Game's data serialized to be JSON compatible
        """
        data = {}
        data["version"] = self._SERIAL_VERSION
        data["world"] = self.world.serialize()
        data["grammar"] = self.grammar.options.serialize() if self.grammar else {}
        data["quests"] = [quest.serialize() for quest in self.quests]
        data["infos"] = [(k, v.serialize()) for k, v in self._infos.items()]
        data["KB"] = self.kb.serialize()
        data["metadata"] = self.metadata
        data["objective"] = self._objective
        return data

    def __eq__(self, other: Any) -> bool:
        return (isinstance(other, Game)
                and self.world == other.world
                and self.infos == other.infos
                and self.quests == other.quests
                and self.metadata == other.metadata
                and self._objective == other._objective)

    def __hash__(self) -> int:
        state = (self.world,
                 frozenset(self.quests),
                 frozenset(self.infos.items()),
                 self._objective)

        return hash(state)

    @property
    def max_score(self) -> float:
        """ Sum of the reward of all quests (infinite if a rewarding quest is repeatable). """
        if any(quest.repeatable and quest.reward > 0 for quest in self.quests):
            return np.inf

        return sum(quest.reward for quest in self.quests if not quest.optional or quest.reward > 0)

    @property
    def command_templates(self) -> List[str]:
        """ All command templates understood in this game. """
        return sorted(set(cmd for cmd in self.kb.inform7_commands.values()))

    @property
    def directions_names(self) -> List[str]:
        """ Direction names (`DIRECTIONS` from `textworld.generator.graph_networks`). """
        return DIRECTIONS

    @property
    def objects_types(self) -> List[str]:
        """ All types of objects in this game. """
        return sorted(self.kb.types.types)

    def _named_non_room_infos(self) -> Iterable[EntityInfo]:
        """ Infos of entities that have a name and are not rooms (type "r"). """
        return (info for info in self.infos.values() if info.name and info.type != "r")

    @property
    def objects_names(self) -> List[str]:
        """ The names of all relevant objects in this game. """
        return [info.name for info in self._named_non_room_infos()]

    @property
    def entity_names(self) -> List[str]:
        """ Object names followed by direction names. """
        return self.objects_names + self.directions_names

    @property
    def objects_names_and_types(self) -> List[Tuple[str, str]]:
        """ The names of all non-player objects along with their type in this game. """
        return [(info.name, info.type) for info in self._named_non_room_infos()]

    @property
    def verbs(self) -> List[str]:
        """ Verbs that should be recognized in this game. """
        # Retrieve commands templates for every rule.
        return sorted(set(cmd.split()[0] for cmd in self.command_templates))

    @property
    def objective(self) -> str:
        """ Text objective; defaults to the quest descriptions joined by "AND". """
        if self._objective is not None:
            return self._objective

        # TODO: Find a better way of describing the objective of the game with several quests.
        self._objective = "\nAND\n".join(quest.desc for quest in self.quests if quest.desc)
        return self._objective

    @objective.setter
    def objective(self, value: str):
        self._objective = value

    @property
    def walkthrough(self) -> Optional[List[str]]:
        """ Commands solving the game, cached in `metadata["walkthrough"]` once derived. """
        walkthrough = self.metadata.get("walkthrough")
        if walkthrough:
            return walkthrough

        # Check if we can derive a walkthrough from the quests.
        policy = GameProgression(self).winning_policy
        if policy:
            mapping = {k: info.name for k, info in self._infos.items()}
            walkthrough = [a.format_command(mapping) for a in policy]
            self.metadata["walkthrough"] = walkthrough

        return walkthrough
class ActionDependencyTreeElement(DependencyTreeElement):
    """ Representation of an `Action` in the dependency tree.

    The notion of dependency and ordering is defined as follows:

    * action1 depends on action2 if action1 needs the propositions
      added by action2;
    * action1 should be performed before action2 if action2 removes
      propositions needed by action1.
    """

    @property
    def action(self) -> Action:
        """ The `Action` wrapped by this tree element. """
        return self.value

    def depends_on(self, other: "ActionDependencyTreeElement") -> bool:
        """ Check whether this action depends on the `other`.

        Action1 depends on action2 when the intersection between
        the propositions added by action2 and the preconditions
        of the action1 is not empty, i.e. action1 needs the
        propositions added by action2.
        """
        needed_from_other = other.action.added & self.action._pre_set
        return len(needed_from_other) > 0

    def is_distinct_from(self, others: List["ActionDependencyTreeElement"]) -> bool:
        """
        Check whether this element is distinct from `others`.

        It is distinct when `self.action` adds at least one fact that
        none of the `others` actions add, i.e. those nodes do not already
        bring all the information `self.action` would bring.
        """
        unique_facts = set(self.action.added)
        unique_facts.difference_update(*(element.action.added for element in others))
        return len(unique_facts) != 0

    def __lt__(self, other: "ActionDependencyTreeElement") -> bool:
        """ Order ActionDependencyTreeElement elements.

        Actions that remove information needed by other actions
        should be sorted further in the list.

        Notes:
            This is not a proper ordering, i.e. two actions
            can mutually remove information needed by each other.
        """
        def _facts_needed_by(element):
            # Accumulate the preconditions of `element` and of its ancestors,
            # discarding at each step up the chain the facts added by the
            # node just visited.
            needed = set(element.action._pre_set)
            current = element
            while current.parent is not None:
                needed |= current.parent.action._pre_set
                needed -= current.action.added
                current = current.parent

            return needed

        other_breaks_self = len(other.action.removed & _facts_needed_by(self))
        self_breaks_other = len(self.action.removed & _facts_needed_by(other))
        return other_breaks_self > self_breaks_other

    def __str__(self) -> str:
        arguments = ", ".join(str(variable) for variable in self.action.variables)
        return "{}({})".format(self.action.name, arguments)
class ActionDependencyTree(DependencyTree):
    """ Dependency tree of actions, using a knowledge base to find reverse actions. """

    def __init__(self, *args, kb: Optional[KnowledgeBase] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back on the default knowledge base when none is provided.
        self._kb = kb or KnowledgeBase.default()

    def remove(self, action: Action) -> Tuple[bool, Optional[Action]]:
        """ Remove `action` from the tree and account for its side effects.

        When the tree is not empty afterwards, the reverse action known to
        the knowledge base is pushed (or, lacking one, `action.inverse()`).

        Returns:
            A pair `(changed, reverse_action)` where `reverse_action` is
            the knowledge base's reverse of `action` (possibly `None`).
        """
        changed = super().remove(action)

        # The last action might have impacted one of the subquests.
        reverse_action = self._kb.get_reverse_action(action)
        if not self.empty:
            if reverse_action is not None:
                changed = self.push(reverse_action)
            elif self.push(action.inverse()):
                # The last action did impact one of the subquests
                # but there's no reverse action to recover from it.
                changed = True

        return changed, reverse_action

    def flatten(self) -> Iterable[Action]:
        """
        Generates a flatten representation of this dependency tree.

        Actions are greedily yielded by iteratively popping leaves from
        a copy of the dependency tree.
        """
        work_tree = self.copy()  # Never mutate the tree being flattened.
        previous_reverse = None
        modified = False
        while len(work_tree.roots) > 0:
            # Sorted order tries first the leaves that don't affect the others.
            # If every leaf would undo the last removal, the loop ends without
            # `break` and the last sorted leaf is used.
            for candidate in sorted(work_tree.leaves_elements):
                if not modified or candidate.action != previous_reverse:
                    break  # This choice avoids cycling back and forth.

            yield candidate.action
            modified, previous_reverse = work_tree.remove(candidate.action)

            # Roots left without children are emitted and pruned right away.
            for root in list(work_tree.roots):
                if len(root.children) == 0:
                    yield root.element.action
                    work_tree.remove(root.element.action)

    def copy(self) -> "ActionDependencyTree":
        """ Copy this tree, sharing the same knowledge base. """
        clone = super().copy()
        clone._kb = self._kb
        return clone
class EventProgression:
    """ EventProgression monitors a particular event.

    Internally, the event is represented as a dependency tree of
    relevant actions to be performed.
    """

    def __init__(self, event: Event, kb: Optional[KnowledgeBase] = None) -> None:
        """
        Args:
            event: The event to keep track of.
            kb: Knowledge base used by the dependency tree. Defaults to
                `KnowledgeBase.default()` when not provided.
        """
        self._kb = kb or KnowledgeBase.default()
        self.event = event
        self._triggered = False
        self._untriggerable = False
        self._policy = ()

        # Build a tree representation of the event.
        self._tree = self._empty_tree()
        if len(event.actions) > 0:
            self._tree.push(event.condition)

            for action in event.actions[::-1]:
                self._tree.push(action)

            self._policy = event.actions + (event.condition,)

    def _empty_tree(self) -> ActionDependencyTree:
        """ Create an empty action dependency tree bound to this progression's KB. """
        return ActionDependencyTree(kb=self._kb,
                                    element_type=ActionDependencyTreeElement)

    def copy(self) -> "EventProgression":
        """ Return a soft copy (the dependency tree is copied, the event is shared). """
        ep = EventProgression(self.event, self._kb)
        ep._triggered = self._triggered
        ep._untriggerable = self._untriggerable
        ep._policy = self._policy
        ep._tree = self._tree.copy()
        return ep

    @property
    def triggering_policy(self) -> Tuple[Action, ...]:
        """ Actions to be performed in order to trigger the event. """
        if self.done:
            return ()

        # Discard all "trigger" actions.
        return tuple(a for a in self._policy if a.name != "trigger")

    @property
    def done(self) -> bool:
        """ Check if the event is done (i.e. triggered or untriggerable). """
        return self.triggered or self.untriggerable

    @property
    def triggered(self) -> bool:
        """ Check whether the event has been triggered. """
        return self._triggered

    @property
    def untriggerable(self) -> bool:
        """ Check whether the event is in an untriggerable state. """
        return self._untriggerable

    def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None:
        """ Update event progression given available information.

        Args:
            action: Action potentially affecting the event progression.
            state: Current game state.
        """
        if self.done:
            return  # Nothing to do, the event is already done.

        if state is not None:
            # Check if event is triggered.
            self._triggered = self.event.is_triggering(state)

            # Try compressing the winning policy given the new game state.
            if self.compress_policy(state):
                return  # A shorter winning policy has been found.

        if action is not None and not self._tree.empty:
            # Determine if we moved away from the goal or closer to it.
            changed, reverse_action = self._tree.remove(action)
            if changed and reverse_action is None:  # Irreversible action.
                self._untriggerable = True  # Can't track the event anymore.

            if changed and reverse_action is not None:
                # Rebuild policy.
                self._policy = tuple(self._tree.flatten())

    def compress_policy(self, state: State) -> bool:
        """ Compress the policy given a game state.

        Repeatedly looks for a shorter action sequence (the current one with
        a contiguous slice removed) that is applicable from `state`; each hit
        replaces both the policy and the dependency tree.

        Args:
            state: Current game state.

        Returns:
            Whether the policy was compressed or not.
        """
        def _find_shorter_policy(policy):
            for j in range(0, len(policy)):
                for i in range(j + 1, len(policy))[::-1]:
                    shorter_policy = policy[:j] + policy[i:]
                    if state.is_sequence_applicable(shorter_policy):
                        self._tree = self._empty_tree()
                        for action in shorter_policy[::-1]:
                            self._tree.push(action)

                        return shorter_policy

            return None

        compressed = False
        policy = _find_shorter_policy(tuple(a for a in self._tree.flatten()))
        while policy is not None:
            compressed = True
            self._policy = policy
            policy = _find_shorter_policy(policy)

        return compressed
class QuestProgression:
    """ QuestProgression keeps track of the completion of a quest.

    Internally, the quest is represented as a dependency tree of
    relevant actions to be performed.
    """

    def __init__(self, quest: Quest, kb: KnowledgeBase) -> None:
        """
        Args:
            quest: The quest to keep track of its completion.
            kb: Knowledge base forwarded to each `EventProgression`.
        """
        self.quest = quest
        self.kb = kb
        self.nb_completions = 0
        self.win_events = [EventProgression(event, kb) for event in quest.win_events]
        self.fail_events = [EventProgression(event, kb) for event in quest.fail_events]

    def copy(self) -> "QuestProgression":
        """ Return a soft copy (event progressions are copied, the quest is shared). """
        qp = QuestProgression(self.quest, self.kb)
        qp.win_events = [event_progression.copy() for event_progression in self.win_events]
        qp.fail_events = [event_progression.copy() for event_progression in self.fail_events]
        qp.nb_completions = self.nb_completions
        return qp

    @property
    def _tree(self) -> Optional[ActionDependencyTree]:
        """ Tree of the winning event with the shortest non-empty triggering policy, if any. """
        events = [event for event in self.win_events if len(event.triggering_policy) > 0]
        if len(events) == 0:
            return None

        event = min(events, key=lambda event: len(event.triggering_policy))
        return event._tree

    @property
    def winning_policy(self) -> Optional[Tuple[Action, ...]]:
        """ Actions to be performed in order to complete the quest. """
        if self.done:
            return None

        winning_policies = [event.triggering_policy for event in self.win_events
                            if len(event.triggering_policy) > 0]
        if len(winning_policies) == 0:
            return None

        return min(winning_policies, key=len)

    @property
    def completable(self) -> bool:
        """ Check if the quest has winning events. """
        return len(self.win_events) > 0

    @property
    def done(self) -> bool:
        """ Check if the quest is done (i.e. completed, failed or unfinishable). """
        return self.completed or self.failed or self.unfinishable

    @property
    def completed(self) -> bool:
        """ Check whether the quest is completed. """
        return any(event.triggered for event in self.win_events)

    @property
    def failed(self) -> bool:
        """ Check whether the quest has failed. """
        return any(event.triggered for event in self.fail_events)

    @property
    def unfinishable(self) -> bool:
        """ Check whether the quest is in an unfinishable state. """
        return any(event.untriggerable for event in self.win_events)

    def update(self, action: Optional[Action] = None, state: Optional[State] = None) -> None:
        """ Update quest progression given available information.

        Args:
            action: Action potentially affecting the quest progression.
            state: Current game state.
        """
        if self.done:
            return  # Nothing to do, the quest is already done.

        for event in (self.win_events + self.fail_events):
            event.update(action, state)

        if self.completed:
            self.nb_completions += 1

            # If repeatable quest is completed, reset its win_events' triggered state.
            if self.quest.repeatable:
                for event in self.win_events:
                    event._triggered = False

                assert not self.completed  # TODO make a unit test for this.
class GameProgression:
    """ GameProgression keeps track of the progression of a game.

    If `tracking_quests` is True, then `winning_policy` will be the list
    of Action that need to be applied in order to complete the game.
    """

    def __init__(self, game: Game, track_quests: bool = True) -> None:
        """
        Args:
            game: The game for which to track progression.
            track_quests: whether quest progressions are being tracked.
        """
        self.game = game
        self.state = game.world.state.copy()
        self._valid_actions = self._compute_valid_actions()

        self.quest_progressions = []
        if track_quests:
            self.quest_progressions = [QuestProgression(quest, game.kb) for quest in game.quests]
            for quest_progression in self.quest_progressions:
                quest_progression.update(action=None, state=self.state)

    def _compute_valid_actions(self) -> List[Action]:
        """ List the actions applicable in the current state given the game's KB. """
        kb = self.game.kb
        return list(self.state.all_applicable_actions(kb.rules.values(),
                                                      kb.types.constants_mapping))

    def copy(self) -> "GameProgression":
        """ Return a soft copy (state and quest progressions are copied). """
        gp = GameProgression(self.game, track_quests=False)
        gp.state = self.state.copy()
        gp._valid_actions = self._valid_actions
        if self.tracking_quests:
            gp.quest_progressions = [quest_progression.copy() for quest_progression in self.quest_progressions]

        return gp

    @property
    def done(self) -> bool:
        """ Whether all non-optional quests are completed or at least one has failed or is unfinishable. """
        return self.completed or self.failed

    @property
    def completed(self) -> bool:
        """ Whether all non-optional quests are completed. """
        if not self.tracking_quests:
            return False  # There is nothing to be "completed".

        return all(qp.completed for qp in self.quest_progressions if qp.completable and not qp.quest.optional)

    @property
    def failed(self) -> bool:
        """ Whether at least one non-optional quest has failed or is unfinishable. """
        if not self.tracking_quests:
            return False  # There is nothing to be "failed".

        return any((qp.failed or qp.unfinishable) for qp in self.quest_progressions if not qp.quest.optional)

    @property
    def score(self) -> int:
        """ Sum of the reward of all completed quests (counting repeated completions). """
        return sum(qp.quest.reward * qp.nb_completions for qp in self.quest_progressions)

    @property
    def tracking_quests(self) -> bool:
        """ Whether quests are being tracked or not. """
        return len(self.quest_progressions) > 0

    @property
    def valid_actions(self) -> List[Action]:
        """ Actions that are valid at the current state. """
        return self._valid_actions

    @property
    def winning_policy(self) -> Optional[Tuple[Action, ...]]:
        """ Actions to be performed in order to complete the game.

        Returns:
            A policy that leads to winning the game. It can be `None`
            if `tracking_quests` is `False`, the game is done, or some
            remaining quest has no triggering policy.
        """
        if not self.tracking_quests:
            return None

        if self.done:
            return None

        # Greedily build a new winning policy by merging all quest trees.
        trees = [quest._tree for quest in self.quest_progressions
                 if quest.completable and not quest.done and not quest.quest.optional]
        if None in trees:
            # Some quests don't have triggering policy.
            return None

        main_quest_tree = ActionDependencyTree(kb=self.game.kb,
                                               element_type=ActionDependencyTreeElement,
                                               trees=trees)

        # Discard all "trigger" actions.
        return tuple(a for a in main_quest_tree.flatten() if a.name != "trigger")

    def update(self, action: Action) -> None:
        """ Update the state of the game given the provided action.

        Args:
            action: Action affecting the state of the game.
        """
        # Update world facts.
        self.state.apply(action)

        # Get valid actions.
        self._valid_actions = self._compute_valid_actions()

        # Update all quest progressions given the last action and new state.
        for quest_progression in self.quest_progressions:
            quest_progression.update(action, self.state)
class GameOptions:
    """
    Options for customizing the game generation.

    Attributes:
        nb_rooms (int):
            Number of rooms in the game.
        nb_objects (int):
            Number of objects in the game.
        nb_parallel_quests (int):
            Number of parallel quests, i.e. not sharing a common goal.
        quest_length (int):
            Number of actions that need to be performed to complete the game.
            Setting it fixes `chaining.min_length`, `chaining.max_length` and
            `chaining.max_depth` to that value.
        quest_breadth (int):
            Number of subquests per independent quest (1 means linear).
            Setting it fixes `chaining.min_breadth` and `chaining.max_breadth`.
        path (str):
            Path of the compiled game (.ulx or .z8). Also, the source (.ni)
            and metadata (.json) files will be saved along with it.
        force_recompile (bool):
            If `True`, recompile game even if it already exists.
        file_ext (str):
            Type of the generated game file. Either .z8 (Z-Machine) or .ulx (Glulx).
            If `path` already has an extension, this is ignored.
        seeds (dict):
            Seeds for the different generation processes (see the `seeds`
            property for the accepted values).
        chaining (ChainingOptions):
            For customizing the quest generation (see
            :py:class:`textworld.generator.ChainingOptions <textworld.generator.chaining.ChainingOptions>`
            for the list of available options).
        grammar (GrammarOptions):
            For customizing the text generation (see
            :py:class:`textworld.generator.GrammarOptions <textworld.generator.text_grammar.GrammarOptions>`
            for the list of available options).
        kb (KnowledgeBase):
            Knowledge base used for generation; loaded with
            `KnowledgeBase.load()` the first time it is accessed if not set.
    """

    def __init__(self):
        self.chaining = ChainingOptions()
        self.grammar = GrammarOptions()
        self._kb = None     # Loaded lazily by the `kb` property.
        self._seeds = None  # Generated lazily by the `seeds` property.
        self.nb_parallel_quests = 1
        self.nb_rooms = 1
        self.nb_objects = 1
        self.force_recompile = False
        self.file_ext = ".ulx"
        self.path = "./tw_games/"

    @property
    def quest_length(self) -> int:
        """ Number of actions that need to be performed to complete the game.

        Only well-defined when `chaining.min_length == chaining.max_length`
        (asserted).
        """
        assert self.chaining.min_length == self.chaining.max_length
        return self.chaining.min_length

    @quest_length.setter
    def quest_length(self, value: int) -> None:
        self.chaining.min_length = value
        self.chaining.max_length = value
        self.chaining.max_depth = value

    @property
    def quest_breadth(self) -> int:
        """ Number of subquests per independent quest. It controls how nonlinear
        a quest can be (1 means linear).

        Only well-defined when `chaining.min_breadth == chaining.max_breadth`
        (asserted).
        """
        assert self.chaining.min_breadth == self.chaining.max_breadth
        return self.chaining.min_breadth

    @quest_breadth.setter
    def quest_breadth(self, value: int) -> None:
        self.chaining.min_breadth = value
        self.chaining.max_breadth = value

    @property
    def seeds(self):
        """ Seeds for the different generation processes.

        * If `None`, seeds will be sampled from
          :py:data:`textworld.g_rng <textworld.utils.g_rng>`.
        * If `int` (including numpy integers), it acts as a seed for a random
          generator that will be used to sample the other seeds.
        * If dict, the following keys can be set:

          * `'map'`: control the map generation;
          * `'objects'`: control the type of objects and their
            location;
          * `'quest'`: control the quest generation;
          * `'grammar'`: control the text generation.

          For any key missing, a random number gets assigned (sampled
          from :py:data:`textworld.g_rng <textworld.utils.g_rng>`).
        """
        if self._seeds is None:
            self.seeds = {}  # Generate seeds from g_rng.

        return self._seeds

    @seeds.setter
    def seeds(self, value: Union[None, int, Mapping[str, int]]) -> None:
        keys = ['map', 'objects', 'quest', 'grammar']

        if value is None:
            # Documented as valid: reset so that seeds get sampled from
            # g_rng on the next access of the `seeds` property.
            self._seeds = None
            return

        rng = None
        if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
            # A single integer seeds a generator that produces every seed.
            rng = RandomState(value)
            seeds = {}
        else:
            seeds = value
            if not set(seeds.keys()).issuperset(keys):
                # Some seeds are missing: draw them from the global generator.
                rng = g_rng.next()

        self._seeds = {}
        for key in keys:
            if key in seeds:
                self._seeds[key] = seeds[key]
            else:
                # NOTE: the bound 65635 (rather than 2**16 = 65536) is kept
                # on purpose so that previously generated seeds/uuids stay
                # reproducible.
                self._seeds[key] = rng.randint(65635)

    @property
    def rngs(self) -> Dict[str, RandomState]:
        """ One fresh `RandomState` per seed, keyed like `seeds`. """
        return {key: RandomState(seed) for key, seed in self.seeds.items()}

    @property
    def kb(self) -> KnowledgeBase:
        """ The knowledge base containing the logic and the text grammars (see
        :py:class:`textworld.generator.KnowledgeBase <textworld.generator.data.KnowledgeBase>`
        for more information).
        """
        if self._kb is None:
            self.kb = KnowledgeBase.load()

        return self._kb

    @kb.setter
    def kb(self, value: KnowledgeBase) -> None:
        self._kb = value
        # Keep the chaining options in sync with the knowledge base.
        self.chaining.kb = self._kb

    def copy(self) -> "GameOptions":
        """ Return a copy of these options.

        The chaining options, grammar options and seeds dictionary are copied,
        so modifying the copy (e.g. through `quest_length` or `kb`) does not
        alter the original. The knowledge base itself is shared, not duplicated.
        """
        options = copy.copy(self)
        options.chaining = copy.copy(self.chaining)
        options.grammar = copy.copy(self.grammar)
        if self._seeds is not None:
            options._seeds = dict(self._seeds)

        return options

    @property
    def uuid(self) -> str:
        """ Identifier summarizing the game specs, grammar options and seeds. """
        # TODO: generate uuid from chaining options?
        specs = encode_seeds((self.nb_rooms, self.nb_objects, self.nb_parallel_quests,
                              self.chaining.min_length, self.chaining.max_length,
                              self.chaining.min_depth, self.chaining.max_depth,
                              self.chaining.min_breadth, self.chaining.max_breadth))
        grammar = self.grammar.uuid
        # Go through the `seeds` property: `_seeds` is still None on fresh
        # options, and iterating it directly would raise a TypeError.
        seeds = self.seeds
        uuid = "tw-{specs}-{grammar}-{seeds}"
        return uuid.format(specs=specs,
                           grammar=grammar,
                           seeds=encode_seeds([seeds[k] for k in sorted(seeds)]))

    def __str__(self) -> str:
        infos = ["-= Game options =-"]
        slots = ["nb_rooms", "nb_objects", "nb_parallel_quests", "path", "force_recompile", "file_ext", "seeds"]
        for slot in slots:
            infos.append("{}: {}".format(slot, getattr(self, slot)))

        text = "\n  ".join(infos)
        text += "\n  chaining options:\n"
        text += textwrap.indent(str(self.chaining), "    ")

        text += "\n  grammar options:\n"
        text += textwrap.indent(str(self.grammar), "    ")

        text += "\n  KB:\n"
        text += textwrap.indent(str(self.kb), "    ")
        return text