# Source code for textworld.generator.vtypes

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.


from collections import deque, namedtuple
import numpy as np
import re
from typing import List

from textworld.logic import Placeholder, Variable


# Whitespace between tokens; the tokenizer skips it.
_WHITESPACE = re.compile(r"\s+")
# Identifiers: word characters plus `/` and `'` (a trailing `'` marks a
# primed type, which the type tree strips on lookup).
_ID = re.compile(r"[\w/']+")
# Punctuation tokens, tried in list order. Order matters: "::" must come
# before ":" so the two-character token wins over its one-character prefix.
_PUNCT = ["::", ":", "$", "(", ")", ",", "&", "->"]

# A lexical token: `type` is "id" for identifiers or the punctuation string
# itself; `value` is the matched text.
_Token = namedtuple("_Token", ("type", "value"))


def _tokenize(expr):
    """
    Split a logical expression into a deque of `_Token` objects.

    Whitespace is discarded. Runs matching `_ID` become "id" tokens. Each
    entry of `_PUNCT` becomes a token whose type is the punctuation string
    itself; entries are tried in `_PUNCT` order.

    Raises
    ------
    ValueError
        If a character starts neither whitespace, an identifier nor a
        known punctuation token.
    """
    result = deque()
    pos = 0
    length = len(expr)

    while pos < length:
        blank = _WHITESPACE.match(expr, pos)
        if blank:
            pos = blank.end()
            continue

        ident = _ID.match(expr, pos)
        if ident:
            result.append(_Token("id", ident.group()))
            pos = ident.end()
            continue

        # First punctuation (in `_PUNCT` order) that starts at `pos`.
        matched = next((p for p in _PUNCT if expr.startswith(p, pos)), None)
        if matched is None:
            raise ValueError("Unexpected character `{}`.".format(expr[pos]))
        result.append(_Token(matched, matched))
        pos += len(matched)

    return result


def _lookahead(tokens, type):
    """
    Return whether the next token in `tokens` has the given `type`.

    The token is not consumed. The result is always a `bool`: an empty
    `tokens` yields `False` rather than the (falsy) container itself.
    """
    return bool(tokens) and tokens[0].type == type


def _expect(tokens, type):
    """
    Consume and return the next token, requiring it to be of kind `type`.

    Raises
    ------
    ValueError
        If `tokens` is empty, or its first token has a different type.
    """
    human_type = "an identifier" if type == "id" else "`{}`".format(type)

    if len(tokens) == 0:
        raise ValueError("Expected {}; found end of input.".format(human_type))

    head = tokens[0]
    if head.type != type:
        raise ValueError("Expected {}; found `{}`.".format(human_type, head.value))

    return tokens.popleft()


class NotEnoughNounsError(NameError):
    """Raised by `get_new` when a type has used up its allowed number of ids."""
class VariableType:
    """
    A single entry of the variable type hierarchy.

    Parameters
    ----------
    type :
        Short type identifier (e.g. the `c` in `container: c -> t`).
    name :
        Human readable name of the type.
    parent :
        Type identifier of the parent type, or `None` for a root type.
    """

    def __init__(self, type, name, parent=None):
        self.type = type
        self.name = name
        self.parent = parent
        # Type identifiers of the direct children; filled in by the tree
        # that owns this type.
        self.children = []
        # If the type starts with an upper case letter, it is a constant.
        # (`isupper` makes non-letters such as digits or `_` non-constant,
        # matching this rule.)
        self.is_constant = self.type[0].isupper()

    @classmethod
    def parse(cls, expr: str) -> "VariableType":
        """
        Parse a variable type expression.

        Parameters
        ----------
        expr :
            The string to parse, in the form `name: type -> parent` or
            `name: type` for a root type. Only a single parent is supported.

        Returns
        -------
        The parsed `VariableType`.

        Raises
        ------
        ValueError
            If `expr` is malformed, including when unexpected tokens remain
            after the (optional) parent.
        """
        tokens = _tokenize(expr)
        name = _expect(tokens, "id").value
        _expect(tokens, ":")
        type = _expect(tokens, "id").value
        parent = None
        if _lookahead(tokens, "->"):
            tokens.popleft()
            parent = _expect(tokens, "id").value

        # Reject trailing input (e.g. `-> a & b`) instead of silently
        # dropping it and producing a wrong hierarchy.
        if tokens:
            raise ValueError(
                "Unexpected `{}` after type expression.".format(tokens[0].value))

        return cls(type, name, parent)

    def __eq__(self, other):
        return (isinstance(other, VariableType) and
                self.name == other.name and
                self.type == other.type and
                self.parent == other.parent)

    def __hash__(self):
        # Defining `__eq__` alone sets `__hash__` to None; hash on the same
        # fields `__eq__` compares so equal instances hash equally.
        return hash((self.name, self.type, self.parent))

    def __str__(self):
        signature = "{}: {}".format(self.name, self.type)
        if self.parent is not None:
            signature += " -> " + self.parent
        return signature

    def __repr__(self):
        return "VariableType(type={!r}, name={!r}, parent={!r})".format(
            self.type, self.name, self.parent)

    def serialize(self) -> str:
        """Return the textual `name: type [-> parent]` form of this type."""
        return str(self)

    @classmethod
    def deserialize(cls, data: str) -> "VariableType":
        """Inverse of `serialize`: parse `data` back into a `VariableType`."""
        return cls.parse(data)
def parse_variable_types(content: str):
    """
    Parse a list of VariableType expressions, one per line.

    Each line is stripped of surrounding whitespace. Empty lines and lines
    starting with `#` are skipped.

    Returns
    -------
    The list of parsed `VariableType` objects, in the order they appear.
    """
    stripped = (raw.strip() for raw in content.split("\n"))
    return [VariableType.parse(entry)
            for entry in stripped
            if entry and not entry.startswith("#")]
class VariableTypeTree:
    """
    Manages hierarchy of types defined in ./grammars/variables.txt.

    Used for extending the rules.
    """

    CHEST = 'c'
    SUPPORTER = 's'
    CLASS_HOLDER = [CHEST, SUPPORTER]

    def __init__(self, vtypes: List["VariableType"]):
        # Map type identifier -> VariableType.
        self.variables_types = {vtype.type: vtype for vtype in vtypes}

        # Make some convenient attributes.
        self.types = [vt.type for vt in vtypes]
        self.names = [vt.name for vt in vtypes]
        self.constants = [t for t in self if self.is_constant(t)]
        self.variables = [t for t in self if not self.is_constant(t)]
        self.constants_mapping = {Placeholder(c): Variable(c) for c in self.constants}

        # Adjust variable type's parent and children references.
        for vt in vtypes:
            if vt.parent is not None:
                vt_parent = self[vt.parent]
                # `children` lives on the shared VariableType objects; skip
                # types already recorded so building several trees from the
                # same objects does not duplicate children (which would skew
                # `descendants` and `sample`).
                if vt.type not in vt_parent.children:
                    vt_parent.children.append(vt.type)

    @classmethod
    def load(cls, path: str):
        """ Read variables from text file. """
        with open(path) as f:
            vtypes = parse_variable_types(f.read())

        return cls(vtypes)

    def __getitem__(self, vtype):
        """ Get VariableType object from its type string. """
        # A trailing "'" (primed type) is ignored for lookup.
        vtype = vtype.rstrip("'")
        return self.variables_types[vtype]

    def __contains__(self, vtype):
        vtype = vtype.rstrip("'")
        return vtype in self.variables_types

    def __iter__(self):
        return iter(self.variables_types)

    def __len__(self):
        return len(self.variables_types)

    def is_constant(self, vtype):
        """Return whether `vtype` (primes ignored) is a constant type."""
        return self[vtype].is_constant

    def descendants(self, vtype):
        """
        Given a variable type, return all possible descendants.

        Descendants are listed depth-first: each child is followed by its own
        descendants. Unknown types yield an empty list.
        """
        # NOTE(review): unlike `__contains__`, this check does not strip a
        # trailing "'", so primed types return [] — confirm this is intended.
        if vtype not in self.variables_types:
            return []

        descendants = []
        for child_type in self[vtype].children:
            descendants.append(child_type)
            descendants += self.descendants(child_type)

        return descendants

    def get_description(self, vtype):
        """Return the name of `vtype`, or `vtype` itself if it is unknown."""
        if vtype in self.types:
            return self.names[self.types.index(vtype)]
        else:
            return vtype

    def get_ancestors(self, vtype):
        """ List all ancestors of a type where the closest ancestors are first. """
        vtypes = []
        if self[vtype].parent is not None:
            vtypes.append(self[vtype].parent)
            vtypes.extend(self.get_ancestors(self[vtype].parent))

        return vtypes

    def is_descendant_of(self, child, parents):
        """
        Return whether `child` equals, or descends from, any of `parents`.

        `parents` may be a single type or a list of types.
        """
        if not isinstance(parents, list):
            parents = [parents]

        for parent in parents:
            if child == parent or child in self.descendants(parent):
                return True

        return False

    def sample(self, parent_type, rng, exceptions=None, include_parent=True, probs=None):
        """
        Sample an object type given the parent's type.

        Parameters
        ----------
        parent_type :
            Type whose descendants are the candidates.
        rng :
            Object providing `choice(candidates, p=...)` — presumably a
            `numpy.random.RandomState`-like generator.
        exceptions :
            Container of types that must not be sampled. `None` (default)
            excludes nothing.
        include_parent :
            Whether `parent_type` itself is a candidate.
        probs :
            Optional mapping type -> unnormalized weight. The weights of the
            remaining candidates are renormalized to sum to 1.

        Returns
        -------
        The sampled type, as returned by `rng.choice`.
        """
        # `None` instead of a shared mutable `[]` default.
        if exceptions is None:
            exceptions = ()

        types = self.descendants(parent_type)
        if include_parent:
            types = [parent_type] + types
        types = [t for t in types if t not in exceptions]

        if probs is not None:
            probs = np.array([probs[t] for t in types], dtype="float")
            probs /= np.sum(probs)

        return rng.choice(types, p=probs)

    def count(self, state):
        """
        Count how many objects there are of each type.

        For every non-constant variable of `state` whose name contains "_",
        the integer after the last "_" is read. The count of its type is one
        more than the largest such suffix. Types with no such variable stay 0.
        """
        types_counts = {t: 0 for t in self}

        for var in state.variables:
            if self.is_constant(var.type):
                continue

            if "_" not in var.name:
                continue

            cpt = int(var.name.split("_")[-1])
            var_type = var.type
            types_counts[var_type] = max(cpt + 1, types_counts[var_type])

        return types_counts

    def serialize(self) -> List:
        """Return the serialized form of every type, in insertion order."""
        return [vtype.serialize() for vtype in self.variables_types.values()]

    @classmethod
    def deserialize(cls, data: List) -> "VariableTypeTree":
        """Rebuild a tree from the output of `serialize`."""
        vtypes = [VariableType.deserialize(d) for d in data]
        return cls(vtypes)
def get_new(type, types_counts, max_types_counts=None):
    """
    Get the next available id for a given type.

    The id has the form "<type>_<n>", where n is the current count for
    `type`. The count in `types_counts` is then incremented in place.

    Raises
    ------
    NotEnoughNounsError
        If `max_types_counts` is given and the count for `type` has already
        reached its maximum.
    """
    current = types_counts[type]
    if max_types_counts is not None and current >= max_types_counts[type]:
        raise NotEnoughNounsError()

    types_counts[type] = current + 1
    return "{}_{}".format(type, current)