# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
from collections import deque, namedtuple
import numpy as np
import re
from typing import List
from textworld.logic import Placeholder, Variable
# Lexical tables used by `_tokenize`.

# Runs of whitespace, which are skipped.
_WHITESPACE = re.compile(r"\s+")
# Identifiers: word characters plus `/` and `'`.
_ID = re.compile(r"[\w/']+")
# Punctuation tokens, tried in this order; "::" is listed before ":" so the
# longer token is matched first.
_PUNCT = [
    "::", ":", "$", "(", ")", ",", "&", "->",
]
# A lexical token: `type` is "id" for identifiers, otherwise the punctuation
# text itself; `value` is the matched text.
_Token = namedtuple("_Token", ["type", "value"])
def _tokenize(expr):
    """
    Helper tokenizer for logical expressions.

    Splits `expr` into a deque of ``_Token`` values, left to right:
    whitespace is skipped, runs matching ``_ID`` become ``("id", text)``
    tokens, and each entry of ``_PUNCT`` (tried in list order) becomes a
    token whose type and value are both the punctuation text.

    Raises
    ------
    ValueError
        If a character starts neither whitespace, an identifier, nor any
        punctuation listed in ``_PUNCT``.
    """
    result = deque()
    pos = 0
    length = len(expr)

    while pos < length:
        blank = _WHITESPACE.match(expr, pos)
        if blank:
            pos = blank.end()
            continue

        ident = _ID.match(expr, pos)
        if ident:
            result.append(_Token("id", ident.group()))
            pos = ident.end()
            continue

        # First punctuation (in _PUNCT order) that occurs at `pos`.
        matched = next((p for p in _PUNCT if expr.startswith(p, pos)), None)
        if matched is None:
            raise ValueError("Unexpected character `{}`.".format(expr[pos]))
        result.append(_Token(matched, matched))
        pos += len(matched)

    return result
def _lookahead(tokens, type):
    """
    Return whether the next token in `tokens` has the given `type`.

    The token is not consumed. An empty (or ``None``) `tokens` yields
    ``False``.
    """
    # bool() so callers always get a real boolean rather than the (empty)
    # token container itself.
    return bool(tokens) and tokens[0].type == type
def _expect(tokens, type):
    """
    Pop and return the next token, requiring it to be of the given `type`.

    Raises
    ------
    ValueError
        If `tokens` is empty, or if its first token has a different type
        (in which case nothing is consumed).
    """
    expected = "an identifier" if type == "id" else "`{}`".format(type)

    if not tokens:
        raise ValueError("Expected {}; found end of input.".format(expected))

    head = tokens[0]
    if head.type != type:
        raise ValueError("Expected {}; found `{}`.".format(expected, head.value))

    return tokens.popleft()
class NotEnoughNounsError(NameError):
    """Raised by ``get_new`` when a type's counter has reached its allowed maximum."""
class VariableType:
    """
    One entry of the variable-type hierarchy.

    Parameters
    ----------
    type :
        Short type identifier (e.g. ``c``).
    name :
        Human-readable name of the type.
    parent :
        Type identifier of the parent type, or ``None`` for a root type.
    """

    def __init__(self, type, name, parent=None):
        self.type = type
        self.name = name
        self.parent = parent
        # Child type identifiers; appended to by VariableTypeTree.
        self.children = []
        # A type is a constant when its first character is unchanged by
        # upper-casing: an upper-case letter (or a caseless character such
        # as a digit).
        self.is_constant = self.type[0] == self.type.upper()[0]

    @classmethod
    def parse(cls, expr: str) -> "VariableType":
        """
        Parse a variable type expression.

        Parameters
        ----------
        expr :
            The string to parse, in the form `name: type -> parent`
            or `name: type` for a root node. Only a single parent is
            supported.

        Raises
        ------
        ValueError
            If `expr` is malformed, including when extra tokens (such as
            `& other_parent`) follow the expression.
        """
        tokens = _tokenize(expr)
        name = _expect(tokens, "id").value
        _expect(tokens, ":")
        type = _expect(tokens, "id").value

        parent = None
        if _lookahead(tokens, "->"):
            tokens.popleft()
            parent = _expect(tokens, "id").value

        # Reject trailing input instead of silently dropping it; e.g. a
        # second parent written as `-> a & b` would otherwise be lost.
        if tokens:
            raise ValueError(
                "Unexpected `{}` after variable type expression.".format(tokens[0].value))

        return cls(type, name, parent)

    def __eq__(self, other):
        if not isinstance(other, VariableType):
            return NotImplemented
        return (self.name, self.type, self.parent) == (other.name, other.type, other.parent)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares so equal objects hash alike
        # (defining __eq__ alone would make instances unhashable).
        return hash((self.name, self.type, self.parent))

    def __repr__(self):
        return "VariableType({!r}, {!r}, parent={!r})".format(self.type, self.name, self.parent)

    def __str__(self):
        signature = "{}: {}".format(self.name, self.type)
        if self.parent is not None:
            signature += " -> " + self.parent

        return signature

    def serialize(self) -> str:
        """Return the `name: type[ -> parent]` text form accepted by `parse`."""
        return str(self)

    @classmethod
    def deserialize(cls, data: str) -> "VariableType":
        """Inverse of `serialize`: parse the text form back into a VariableType."""
        return cls.parse(data)
def parse_variable_types(content: str):
    """
    Parse one VariableType expression per line of `content`.

    Each line is stripped of surrounding whitespace; empty lines and lines
    starting with ``#`` are skipped.

    Returns
    -------
    A list of VariableType objects, in the order their lines appear.
    """
    stripped_lines = (raw.strip() for raw in content.split("\n"))
    return [
        VariableType.parse(line)
        for line in stripped_lines
        if line and not line.startswith("#")
    ]
class VariableTypeTree:
    """
    Manages hierarchy of types defined in ./grammars/variables.txt.
    Used for extending the rules.

    Types are indexed by their type identifier; trailing ``'`` characters
    on a lookup key are ignored by ``__getitem__`` and ``__contains__``.
    """

    CHEST = 'c'
    SUPPORTER = 's'
    CLASS_HOLDER = [CHEST, SUPPORTER]

    def __init__(self, vtypes: List["VariableType"]):
        # Map each type identifier to its VariableType (a later entry with
        # the same identifier replaces an earlier one).
        self.variables_types = {vtype.type: vtype for vtype in vtypes}

        # Make some convenient attributes.
        self.types = [vt.type for vt in vtypes]
        self.names = [vt.name for vt in vtypes]
        self.constants = [t for t in self if self.is_constant(t)]
        self.variables = [t for t in self if not self.is_constant(t)]
        self.constants_mapping = {Placeholder(c): Variable(c) for c in self.constants}

        # Adjust variable type's parent and children references.
        # NOTE: this mutates the given VariableType objects, so building
        # several trees from the same `vtypes` duplicates their children.
        for vt in vtypes:
            if vt.parent is not None:
                vt_parent = self[vt.parent]
                vt_parent.children.append(vt.type)

    @classmethod
    def load(cls, path: str):
        """
        Read variable types from a text file (one expression per line) and
        build a tree from them.
        """
        with open(path) as f:
            vtypes = parse_variable_types(f.read())

        return cls(vtypes)

    def __getitem__(self, vtype):
        """ Get VariableType object from its type string. """
        vtype = vtype.rstrip("'")
        return self.variables_types[vtype]

    def __contains__(self, vtype):
        vtype = vtype.rstrip("'")
        return vtype in self.variables_types

    def __iter__(self):
        return iter(self.variables_types)

    def __len__(self):
        return len(self.variables_types)

    def is_constant(self, vtype):
        """Return whether `vtype` is a constant type (see VariableType.is_constant)."""
        return self[vtype].is_constant

    def descendants(self, vtype):
        """
        Given a variable type, return all possible descendants.

        The result is in depth-first pre-order. Unknown types yield ``[]``.
        """
        if vtype not in self.variables_types:
            return []

        descendants = []
        for child_type in self[vtype].children:
            descendants.append(child_type)
            descendants += self.descendants(child_type)

        return descendants

    def get_description(self, vtype):
        """Return the human-readable name of `vtype`, or `vtype` itself if unknown."""
        if vtype in self.types:
            return self.names[self.types.index(vtype)]
        else:
            return vtype

    def get_ancestors(self, vtype):
        """ List all ancestors of a type, closest ancestor first. """
        vtypes = []
        if self[vtype].parent is not None:
            vtypes.append(self[vtype].parent)
            vtypes.extend(self.get_ancestors(self[vtype].parent))

        return vtypes

    def is_descendant_of(self, child, parents):
        """
        Return whether `child` equals, or descends from, any of `parents`.

        `parents` may be a single type or a list of types.
        """
        if not isinstance(parents, list):
            parents = [parents]

        for parent in parents:
            if child == parent or child in self.descendants(parent):
                return True

        return False

    def sample(self, parent_type, rng, exceptions=None, include_parent=True, probs=None):
        """
        Sample an object type given the parent's type.

        Parameters
        ----------
        parent_type :
            Type whose descendants are candidates.
        rng :
            Random generator whose ``choice(candidates, p=...)`` is used
            (presumably a ``numpy.random.RandomState``).
        exceptions :
            Types that must not be sampled. Defaults to no exceptions.
        include_parent :
            Whether `parent_type` itself is a candidate.
        probs :
            Optional mapping from every candidate type to a weight; the
            weights are normalized to sum to one before sampling.
        """
        # None instead of a shared mutable `[]` default.
        if exceptions is None:
            exceptions = ()

        types = self.descendants(parent_type)
        if include_parent:
            types = [parent_type] + types

        types = [t for t in types if t not in exceptions]
        if probs is not None:
            probs = np.array([probs[t] for t in types], dtype="float")
            probs /= np.sum(probs)

        return rng.choice(types, p=probs)

    def count(self, state):
        """
        Counts how many objects there are of each type.

        For every type, the count is one more than the largest integer suffix
        among non-constant variable names of the form ``<prefix>_<n>`` in
        `state.variables` (0 if there are none), i.e. the next free index.
        Names without ``_`` are ignored. ``int()`` raises ValueError if the
        text after the last ``_`` is not an integer.
        """
        types_counts = {t: 0 for t in self}
        for var in state.variables:
            if self.is_constant(var.type):
                continue

            if "_" not in var.name:
                continue

            cpt = int(var.name.split("_")[-1])
            var_type = var.type
            types_counts[var_type] = max(cpt + 1, types_counts[var_type])

        return types_counts

    def serialize(self) -> List:
        """Return the text form of every VariableType in the tree."""
        return [vtype.serialize() for vtype in self.variables_types.values()]

    @classmethod
    def deserialize(cls, data: List) -> "VariableTypeTree":
        """Rebuild a tree from the output of `serialize`."""
        vtypes = [VariableType.deserialize(d) for d in data]
        return cls(vtypes)
def get_new(type, types_counts, max_types_counts=None):
    """
    Get the next available id for a given type.

    Returns ``"<type>_<n>"``, where ``n`` is the current value of
    ``types_counts[type]``, and increments that counter in place.

    Raises
    ------
    NotEnoughNounsError
        If `max_types_counts` is given and the counter has already reached
        ``max_types_counts[type]``.
    """
    current = types_counts[type]
    if max_types_counts is not None and current >= max_types_counts[type]:
        raise NotEnoughNounsError()

    types_counts[type] = current + 1
    return "{}_{}".format(type, current)