# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import glob
import re
import warnings
from os.path import join as pjoin
from collections import OrderedDict, defaultdict
from typing import Any, Optional, Mapping, List, Tuple, Container, Union
from numpy.random import RandomState
import textworld
from textworld import g_rng
from textworld.utils import uniquify
from textworld.generator.data import KnowledgeBase
from textworld.textgen import TextGrammar
NB_EXPANSION_RETRIES = 20
[docs]def fix_determinant(var):
var = var.replace(" ", " ")
var = var.replace(" a a", " an a")
var = var.replace(" a e", " an e")
var = var.replace(" a i", " an i")
var = var.replace(" a o", " an o")
var = var.replace(" a u", " an u")
var = var.replace(" A a", " An a")
var = var.replace(" A e", " An e")
var = var.replace(" A i", " An i")
var = var.replace(" A o", " An o")
var = var.replace(" A u", " An u")
return var
[docs]class MissingTextGrammar(NameError):
def __init__(self, path):
msg = "Cannot find any theme files: {path}."
super().__init__(msg.format(path=path))
[docs]class GrammarOptions:
__slots__ = ['theme', 'names_to_exclude', 'include_adj', 'blend_descriptions',
'ambiguous_instructions', 'only_last_action',
'blend_instructions',
'allowed_variables_numbering', 'unique_expansion']
def __init__(self, options=None, **kwargs):
if isinstance(options, GrammarOptions):
options = options.serialize()
options = options or kwargs
#: str: Grammar theme's name. All `*.twg` files starting with that name will be loaded.
self.theme = options.get("theme", "house")
#: List[str]: List of names the text generation should not use.
self.names_to_exclude = list(options.get("names_to_exclude", []))
#: bool: Append numbers after an object name if there is not enough variation for it.
self.allowed_variables_numbering = options.get("allowed_variables_numbering", False)
#: bool: When True, #symbol# are force to be expanded to unique text.
self.unique_expansion = options.get("unique_expansion", False)
#: bool: When True, object names can be preceeded by an adjective.
self.include_adj = options.get("include_adj", False)
#: bool: When True, only the last action of a quest will be described
#: in the generated objective.
self.only_last_action = options.get("only_last_action", False)
#: bool: When True, consecutive actions to be accomplished might be
#: described in a single sentence rather than separate ones.
self.blend_instructions = options.get("blend_instructions", False)
#: bool: When True, objects sharing some properties might be described
#: in a single sentence rather than separate consecutive ones.
self.blend_descriptions = options.get("blend_descriptions", False)
#: bool: When True, in the game objective, objects of interest might
#: be refer to by their type or adjective rather than full name.
self.ambiguous_instructions = options.get("ambiguous_instructions", False)
[docs] def serialize(self) -> Mapping:
""" Serialize this object.
Results:
GrammarOptions's data serialized to be JSON compatible.
"""
return {slot: getattr(self, slot) for slot in self.__slots__}
[docs] @classmethod
def deserialize(cls, data: Mapping) -> "GrammarOptions":
""" Creates a `GrammarOptions` from serialized data.
Args:
data: Serialized data with the needed information to build a
`GrammarOptions` object.
"""
return cls(data)
[docs] def copy(self) -> "GrammarOptions":
return GrammarOptions.deserialize(self.serialize())
def __eq__(self, other) -> bool:
return (isinstance(other, GrammarOptions)
and all(getattr(self, slot) == getattr(other, slot) for slot in self.__slots__))
@property
def uuid(self) -> str:
""" Generate UUID for this set of grammar options. """
def _unsigned(n):
return n & 0xFFFFFFFFFFFFFFFF
# Skip theme and names_to_exclude.
values = [int(getattr(self, s)) for s in self.__slots__[2:]]
option = "".join(map(str, values))
from hashids import Hashids
hashids = Hashids(salt="TextWorld")
if len(self.names_to_exclude) > 0:
names_to_exclude_hash = _unsigned(hash(frozenset(self.names_to_exclude)))
return self.theme + "-" + hashids.encode(names_to_exclude_hash) + "-" + hashids.encode(int(option))
return self.theme + "-" + hashids.encode(int(option))
def __str__(self) -> str:
infos = []
for slot in self.__slots__:
infos.append("{}: {}".format(slot, getattr(self, slot)))
return "\n".join(infos)
[docs]class Grammar:
"""
Context-Free Grammar for text generation.
"""
_cache = {}
def __init__(self, options: Union[GrammarOptions, Mapping[str, Any]] = {}, rng: Optional[RandomState] = None,
kb: Optional[KnowledgeBase] = None):
"""
Arguments:
options:
For customizing text generation process (see
:py:class:`textworld.generator.GrammarOptions <textworld.generator.text_grammar.GrammarOptions>`
for the list of available options).
rng:
Random generator used for sampling tag expansions.
"""
self.options = GrammarOptions(options)
self.grammar = OrderedDict()
self.rng = g_rng.next() if rng is None else rng
self.allowed_variables_numbering = self.options.allowed_variables_numbering
self.unique_expansion = self.options.unique_expansion
self.all_expansions = defaultdict(list)
# The current used symbols
self.overflow_dict = OrderedDict()
self.used_names = set(self.options.names_to_exclude)
# Load the grammar associated to the provided theme.
self.theme = self.options.theme
# Load the object names file
path = pjoin(KnowledgeBase.default().text_grammars_path, glob.escape(self.theme) + "*.twg")
files = glob.glob(path)
if kb is not None:
path = pjoin(kb.text_grammars_path, glob.escape(self.theme) + "*.twg")
files += glob.glob(path)
if len(files) == 0:
raise MissingTextGrammar(path)
self.grammar_files = files
for filename in files:
self._parse(filename)
def __eq__(self, other):
return (isinstance(other, Grammar)
and self.overflow_dict == other.overflow_dict
and self.grammar == other.grammar
and self.options.uuid == other.options.uuid
and self.used_names == other.used_names)
def _parse(self, path: str):
"""
Parse lines and add them to the grammar.
"""
if path not in self._cache:
with open(path) as f:
self._cache[path] = TextGrammar.parse(f.read(), filename=path)
for name, rule in self._cache[path].rules.items():
self.grammar["#" + name + "#"] = rule
[docs] def has_tag(self, tag: str) -> bool:
"""
Check if the grammar has a given tag.
"""
return tag in self.grammar
[docs] def get_random_expansion(self, tag: str, rng: Optional[RandomState] = None) -> str:
"""
Return a randomly chosen expansion for the given tag.
Parameters
----------
tag :
Grammar tag to be expanded.
rng : optional
Random generator used to chose an expansion when there is many.
By default, it used the random generator of this grammar object.
Returns
-------
expansion :
An expansion chosen randomly for the provided tag.
"""
rng = rng or self.rng
if not self.has_tag(tag):
raise ValueError("Tag: {} does not exist!".format(tag))
for _ in range(NB_EXPANSION_RETRIES):
expansion = rng.choice(self.grammar[tag].alternatives)
expansion = expansion.full_form()
if not self.unique_expansion or expansion not in self.all_expansions[tag]:
break
self.all_expansions[tag].append(expansion)
return expansion
[docs] def expand(self, text: str, rng: Optional[RandomState] = None) -> str:
"""
Expand some text until there is no more tag to expand.
Parameters
----------
text :
Text potentially containing grammar tags to be expanded.
rng : optional
Random generator used to chose an expansion when there is many.
By default, it used the random generator of this grammar object.
Returns
-------
expanded_text :
Resulting text in which there is no grammar tag left to be expanded.
"""
rng = self.rng if rng is None else rng
while "#" in text:
to_replace = re.findall(r'[#][^#]*[#]', text)
tag = self.rng.choice(to_replace)
replacement = self.get_random_expansion(tag, rng)
text = text.replace(tag, replacement)
return text
[docs] def split_name_adj_noun(self, candidate: str, include_adj: bool) -> Optional[Tuple[str, str, str]]:
"""
Extract the full name, the adjective and the noun from a string.
Parameters
----------
candidate :
String that may contain one adjective-noun sperator '|'.
include_adj : optional
If True, the name can contain a generated adjective.
If False, any generated adjective will be discarded.
Returns
-------
name :
The whole name, i.e. `adj + " " + noun`.
adj :
The adjective part of the name.
noun :
The noun part of the name.
"""
parts = candidate.split("|")
noun = parts[-1].strip()
if len(parts) == 1 or not include_adj:
adj = None
elif len(parts) == 2:
adj = parts[0].strip()
else:
raise ValueError("Too many separators '|' in '{}'".format(candidate))
name = adj + " " + noun if adj is not None else noun
return name, adj, noun
[docs] def generate_name(self, obj_type: str, room_type: str = "",
include_adj: Optional[bool] = None, exclude: Container[str] = []) -> Tuple[str, str, str]:
"""
Generate a name given an object type and the type room it belongs to.
Parameters
----------
obj_type :
Type of the object for which we will generate a name.
room_type : optional
Type of the room the object belongs to.
include_adj : optional
If True, the name can contain a generated adjective.
If False, any generated adjective will be discarded.
Default: use value grammar.options.include_adj
exclude : optional
List of names we should avoid generating.
Returns
-------
name :
The whole name, i.e. `adj + " " + noun`.
adj :
The adjective part of the name.
noun :
The noun part of the name.
"""
if include_adj is None:
include_adj = self.options.include_adj
# Get room-specialized name, if possible.
symbol = "#{}_({})#".format(room_type, obj_type)
if not self.has_tag(symbol):
# Otherwise, fallback on the generic object names.
symbol = "#({})#".format(obj_type)
# We don't want to generate a name that is in `exclude`.
found_candidate = False
for i in range(50): # We default to fifty attempts
candidate = self.expand(symbol)
name, adj, noun = self.split_name_adj_noun(candidate, include_adj)
if name not in exclude:
found_candidate = True
break
if not found_candidate:
# Not enough variation for the object we want to name.
# Warn the user and fall back on adding an adjective if we can.
if not include_adj:
name, adj, noun = self.generate_name(obj_type, room_type, include_adj=True, exclude=exclude)
msg = ("Not enough variation for '{}'. Falling back on using adjective '{}'."
" To avoid this message you can add more variation in the '{}'"
" related grammar files located in '{}'.")
msg = msg.format(symbol, adj, self.theme, KnowledgeBase.default().text_grammars_path)
warnings.warn(msg, textworld.GenerationWarning)
return name, adj, noun
# Still not enough variation for the object we want to name.
if not self.allowed_variables_numbering:
msg = ("Not enough variation for '{}'. You can add more variation"
" in the '{}' related grammar files located in '{}'"
" or turn on the 'include_adj=True' grammar flag."
" In last resort, you could always turn on the"
" 'allowed_variables_numbering=True' grammar flag"
" to append unique number to object name.")
msg = msg.format(symbol, self.theme, KnowledgeBase.default().text_grammars_path)
raise ValueError(msg)
if obj_type not in self.overflow_dict:
self.overflow_dict[obj_type] = []
# Append unique (per type) number to the noun.
suffix = " {}".format(len(self.overflow_dict[obj_type]))
noun += suffix
name += suffix
self.overflow_dict[obj_type].append(name)
return name, adj, noun
[docs] def get_vocabulary(self) -> List[str]:
seen = set()
all_words = set()
pattern = re.compile(r'[#][^#]*[#]')
i7_pattern = re.compile(r'\[[^]]*\]')
tw_pattern = re.compile(r'\((obj|name[^)]*|action|list_of_actions)\)')
to_expand = list(self.grammar.keys())
while len(to_expand) > 0:
tag = to_expand.pop()
if tag in seen:
continue
seen.add(tag)
# Remove i7 code snippets.
tag = i7_pattern.sub(" ", tag)
# Remove all TW placeholders.
tag = tw_pattern.sub(" ", tag)
words = tag.split()
for word in words:
if pattern.search(word):
for to_replace in pattern.findall(word):
for alternative in self.grammar[to_replace].alternatives:
to_expand.append(word.replace(to_replace, alternative.full_form()))
else:
all_words.add(word)
return sorted(all_words)
[docs] def get_all_expansions_for_tag(self, tag: str, max_depth: int = 500) -> List[str]:
"""
Get all possible expansions for a grammar tag.
Parameters
----------
tag :
Grammar tag to be expanded.
max_depth : optional
Maximum recursion depth when expanding tag.
Returns
-------
expansions :
All possible expansions.
"""
if tag not in self.grammar:
return []
variants = []
# Recursively get all symbol possibilities
def _iterate(tag, depth):
if "#" in tag and depth < max_depth:
depth += 1
to_replace = re.findall(r'[#][^#]*[#]', tag)
for replace in to_replace:
for rhs in self.grammar[replace].alternatives:
_iterate(tag.replace(replace, rhs.full_form()), depth)
else:
variants.append(tag)
_iterate(tag, 0)
return variants
[docs] def get_all_expansions_for_type(self, type: str):
"""
Get all possible expansions for a given object type.
Parameters
----------
type :
Object type.
Returns
-------
names :
All possible names.
"""
expansions = self.get_all_expansions_for_tag("#({})#".format(type))
for room_type in self.grammar["#room_type#"].alternatives:
expansions += self.get_all_expansions_for_tag("#{}_({})#".format(room_type.full_form(), type))
return uniquify(expansions)
[docs] def get_all_names_for_type(self, type: str, include_adj: True):
"""
Get all possible names for a given object type.
Parameters
----------
type :
Object type.
include_adj : optional
If True, names can contain generated adjectives.
If False, any generated adjectives will be discarded.
Returns
-------
names :
All possible names sorted in alphabetical order.
"""
expansions = self.get_all_expansions_for_type(type)
names = [self.split_name_adj_noun(expansion, include_adj)[0] for expansion in expansions]
return sorted(set(names))
[docs] def get_all_adjective_for_type(self, type: str):
"""
Get all possible adjectives for a given object type.
Parameters
----------
type :
Object type.
Returns
-------
adjectives :
All possible adjectives sorted in alphabetical order.
"""
expansions = self.get_all_expansions_for_type(type)
adjectives = [self.split_name_adj_noun(expansion, include_adj=True)[1] for expansion in expansions]
return sorted(set(adjectives))
[docs] def get_all_nouns_for_type(self, type: str):
"""
Get all possible nouns for a given object type.
Parameters
----------
type :
Object type.
Returns
-------
nouns :
All possible nouns sorted in alphabetical order.
"""
expansions = self.get_all_expansions_for_type(type)
nouns = [self.split_name_adj_noun(expansion, include_adj=False)[2] for expansion in expansions]
return sorted(set(nouns))
[docs] def check(self) -> bool:
"""
Check if this grammar is valid.
TODO: use logging mechanism to report warnings and errors.
"""
errors_found = False
for symbol in self.grammar:
if len(self.grammar[symbol].alternatives) == 0:
print("[Warning] Symbol {} has empty tags".format(symbol))
for tag in self.grammar[symbol].alternatives:
tag = tag.full_form()
if tag == "":
print("[Warning] Symbol {} has empty tags".format(symbol))
for symb in re.findall(r'[#][^#]*[#]', tag):
if symb not in self.grammar:
print("[Error] Symbol {} not found in grammar (Occurs in expansion of {})".format(symb, symbol))
errors_found = True
return not errors_found