# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import sys
import textwrap
from io import StringIO
from typing import List, Optional, Dict, Any, Tuple, Union
import numpy as np
from termcolor import colored
import textworld
from textworld import EnvInfos
from textworld.envs.wrappers import Filter, GenericEnvironment, Limit
from textworld.envs.batch import AsyncBatchEnv, SyncBatchEnv
from textworld.gym.envs.utils import shuffled_cycle
from functools import partial
def _make_env(request_infos, max_episode_steps=None, wrappers=None):
    """ Build one environment of a batch.

    Arguments:
        request_infos:
            Information to request from the game; forwarded to `GenericEnvironment`.
        max_episode_steps:
            If truthy, the environment is wrapped in `Limit` with this step budget.
            `None` (or 0) means no `Limit` wrapper is added.
        wrappers:
            Wrappers (callables taking an environment and returning a wrapped one)
            applied in order on top of the base environment. Any iterable is
            accepted; defaults to no extra wrappers.

    Returns:
        The wrapped environment. `Filter` is always applied last (outermost).
    """
    env = GenericEnvironment(request_infos)
    if max_episode_steps:
        env = Limit(env, max_episode_steps=max_episode_steps)

    # Copy into a fresh list: avoids a shared mutable default and lets callers
    # pass tuples or other iterables (`tuple + list` would raise TypeError).
    extra_wrappers = list(wrappers) if wrappers is not None else []
    for wrapper in extra_wrappers + [Filter]:
        env = wrapper(env)

    return env
class TextworldBatchGymEnv:
    """ Plays a pool of text-based games, `batch_size` games at a time. """

    def __init__(self,
                 gamefiles: List[str],
                 request_infos: Optional[EnvInfos] = None,
                 batch_size: int = 1,
                 asynchronous: bool = True,
                 auto_reset: bool = False,
                 max_episode_steps: Optional[int] = None,
                 wrappers: Optional[List[textworld.core.Wrapper]] = None) -> None:
        """ Environment for playing text-based games in batch.

        Arguments:
            gamefiles:
                Paths of every game composing the pool (`*.ulx|*.z[1-8]|*.json`).
            request_infos:
                For customizing the information returned by this environment
                (see :py:class:`textworld.EnvInfos <textworld.core.EnvInfos>`
                for the list of available information).

                .. warning:: Only supported for TextWorld games (i.e., that have a corresponding `*.json` file).
            batch_size:
                If provided, it indicates the number of games to play at the same time.
                By default, a single game is played at once.

                .. warning:: When `batch_size` is provided (even for batch_size=1), `env.step` expects
                             a list of commands as input and outputs a list of states. `env.reset` also
                             outputs a list of states.
            asynchronous:
                If `True`, wraps the environments in an `AsyncBatchEnv` (which uses
                `multiprocessing` to run the environments in parallel). If `False`,
                wraps the environments in a `SyncBatchEnv`. Default: `True`.
                Only relevant when `batch_size > 1`; a single game always uses `SyncBatchEnv`.
            auto_reset:
                If `True`, each game *independently* resets once it is done (i.e., reset happens
                on the next `env.step` call).
                Otherwise, once a game is done, subsequent calls to `env.step` won't have any effects.
            max_episode_steps:
                Number of steps allocated to play each game. Once exhausted, the game is done.
            wrappers:
                Wrappers applied, in order, on top of every environment of the batch
                (the `Filter` wrapper is always applied last). Any iterable is accepted.
                Default: no extra wrappers.
        """
        self.gamefiles = gamefiles
        self.batch_size = batch_size
        self.request_infos = request_infos or EnvInfos()
        self.seed(1234)

        # Private copy (default: empty) so the factories never share the caller's list
        # and no mutable default is involved.
        wrappers = list(wrappers) if wrappers is not None else []
        self._env_fns = [partial(_make_env, self.request_infos, max_episode_steps, wrappers)
                         for _ in range(self.batch_size)]
        self._auto_reset = auto_reset
        # `AsyncBatchEnv` is only used when there is more than one game to run.
        self._batch_env_type = AsyncBatchEnv if self.batch_size > 1 and asynchronous else SyncBatchEnv
        self.batch_env = self._make_batch_env()

    def _make_batch_env(self):
        """ Instantiate a batch environment from the stored per-game factories. """
        return self._batch_env_type(self._env_fns, self._auto_reset)

    def seed(self, seed: Optional[int] = None) -> List[int]:
        """ Set the seed for this environment's random generator(s).

        This environment uses a random generator to shuffle the order in which
        the games are played.

        Arguments:
            seed: Number that will be used to seed the random generators.

        Returns:
            All the seeds used to set this environment's random generator(s).
        """
        # We shuffle the order in which the game will be seen.
        rng = np.random.RandomState(seed)
        gamefiles = list(self.gamefiles)  # Soft copy to avoid shuffling original list.
        rng.shuffle(gamefiles)

        # Prepare iterator used for looping through the games.
        self._gamefiles_iterator = shuffled_cycle(gamefiles, rng=rng)
        return [seed]

    def reset(self) -> Tuple[List[str], Dict[str, List[Any]]]:
        """ Resets the text-based environment.

        Resetting this environment means starting the next game in the pool.
        If the environment was previously closed with :py:meth:`close`, a new
        batch environment is created first.

        Returns:
            A tuple (observations, infos) where

            * observation: text observed in the initial state for each game in the batch;
            * infos: additional information as requested for each game in the batch.
        """
        if self.batch_env is None:
            # `close()` drops the batch environment; rebuild it instead of failing
            # with an AttributeError on `None.load(...)`.
            self.batch_env = self._make_batch_env()

        # As on every reset, close the batch environment before loading the next games.
        self.batch_env.close()

        gamefiles = [next(self._gamefiles_iterator) for _ in range(self.batch_size)]
        self.batch_env.load(gamefiles)

        self.last_commands = [None] * self.batch_size
        self.obs, infos = self.batch_env.reset()
        return self.obs, infos

    def skip(self, nb_games: int = 1) -> None:
        """ Skip games.

        Arguments:
            nb_games: Number of games to skip.
        """
        for _ in range(nb_games):
            next(self._gamefiles_iterator)

    def step(self, commands) -> Tuple[List[str], List[float], List[bool], Dict[str, List[Any]]]:
        """ Runs a command in each text-based environment of the batch.

        Arguments:
            commands: Text command to send to the game interpreter, one per game in the batch.

        Returns:
            A tuple (observations, scores, dones, infos) where

            * observations: text observed in the new state for each game in the batch;
            * scores: total number of points accumulated so far for each game in the batch;
            * dones: whether each game in the batch is finished or not;
            * infos: additional information as requested for each game in the batch.

        Raises:
            TypeError: If `commands` is not a list or tuple.
        """
        # Explicit check (not `assert`) so it still runs under `python -O`; a bare
        # string would otherwise be treated as a sequence of one-character commands.
        if not isinstance(commands, (list, tuple)):
            raise TypeError("Expected a list of commands.")

        self.last_commands = commands
        self.obs, scores, dones, infos = self.batch_env.step(self.last_commands)
        return self.obs, scores, dones, infos

    def close(self) -> None:
        """ Close this environment. """
        if self.batch_env is not None:
            self.batch_env.close()

        self.batch_env = None

    @staticmethod
    def _wrap_paragraphs(text: str, width: int = 80) -> str:
        """ Wrap every newline-separated paragraph of `text` at `width` characters.

        Empty paragraphs are kept, so blank lines (and a trailing newline) survive.
        """
        return "\n".join("\n".join(textwrap.wrap(paragraph, width=width))
                         for paragraph in text.split("\n"))

    def render(self, mode: str = 'human') -> Optional[Union[StringIO, str]]:
        """ Renders the current state of each environment in the batch.

        Each rendering is composed of the previous text command (if there's one) and
        the text describing the current observation.

        Arguments:
            mode:
                Controls where and how the text is rendered. Supported modes are:

                * human: Display text to the current display or terminal and
                  return nothing. Text is wrapped at 80 characters.
                * ansi: Return a `StringIO` containing a terminal-style
                  text representation. The text can include newlines and ANSI
                  escape sequences (e.g. for colors).
                * text: Return a string (`str`) containing the text without
                  any ANSI escape sequences.

        Returns:
            Depending on the `mode`, this method returns either nothing, a
            string, or a `StringIO` object.
        """
        outfile = StringIO() if mode in ['ansi', "text"] else sys.stdout

        renderings = []
        for last_command, ob in zip(self.last_commands, self.obs):
            msg = ob.rstrip() + "\n"
            if mode == "human":
                msg = self._wrap_paragraphs(msg)

            if last_command is not None:
                command = "> " + last_command
                if mode == "human":
                    # Wrap before colouring so the ANSI escape codes added by
                    # `colored` don't count towards the 80-character width.
                    command = self._wrap_paragraphs(command)
                if mode in ["ansi", "human"]:
                    command = colored(command, "yellow")
                msg = command + "\n" + msg

            renderings.append(msg)

        outfile.write("\n-----\n".join(renderings) + "\n")

        if mode == "text":
            outfile.seek(0)
            return outfile.read()

        if mode == 'ansi':
            return outfile

        return None