From b79429c010488c835a3a31ad0e8951101c2dd02d Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 11:44:37 +0200 Subject: [PATCH 1/8] base agent allowing parallel simultaneous environment interaction --- netsecgame/agents/parallel_base_agent.py | 462 +++++++++++++++++++++++ 1 file changed, 462 insertions(+) create mode 100644 netsecgame/agents/parallel_base_agent.py diff --git a/netsecgame/agents/parallel_base_agent.py b/netsecgame/agents/parallel_base_agent.py new file mode 100644 index 00000000..484bfcf8 --- /dev/null +++ b/netsecgame/agents/parallel_base_agent.py @@ -0,0 +1,462 @@ +# Author: Ondrej Lukas, ondrej.lukas@aic.cvut.cz +# Parallel agent class that manages connections to multiple game server instances simultaneously. +import logging +import socket +import json +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, Tuple, Dict, Any, List, Union, overload + +from netsecgame.game_components import ( + Action, GameState, Observation, ActionType, + GameStatus, AgentInfo, ProtocolConfig, AgentRole, +) + + +class ParallelBaseAgent: + """ + Agent that manages connections to multiple NetSecGame server instances + simultaneously, enabling parallel environment interaction. + + Unlike BaseAgent (which manages a single socket), this class maintains + one TCP socket per environment and exposes vectorized versions of + ``register()``, ``make_step()``, and ``request_game_reset()`` that + operate on lists of actions/observations. + + Args: + game_hosts: Host address(es). A single string is broadcast to all + ``game_ports``. A list must either have length 1 (broadcast) + or match the length of ``game_ports``. + game_ports: Port(s) — one per environment instance. A single int + creates a one-environment agent. + role: The agent role shared across all environments. + max_workers: Maximum threads in the pool. Defaults to ``len(game_ports)``. 
+ """ + + def __init__( + self, + game_hosts: str | List[str], + game_ports: int | List[int], + role: AgentRole, + max_workers: Optional[int] = None, + ) -> None: + # ------------------------------------------------------------------ + # Normalize scalars to lists + # ------------------------------------------------------------------ + if isinstance(game_hosts, str): + game_hosts = [game_hosts] + if isinstance(game_ports, int): + game_ports = [game_ports] + + # ------------------------------------------------------------------ + # Validate & broadcast hosts + # ------------------------------------------------------------------ + if len(game_hosts) == 1 and len(game_ports) > 1: + game_hosts = game_hosts * len(game_ports) + elif len(game_hosts) != len(game_ports): + raise ValueError( + f"game_hosts length ({len(game_hosts)}) must be 1 " + f"(broadcast) or match game_ports length ({len(game_ports)})" + ) + # make sure port numbers are integers + game_ports = [int(port) for port in game_ports] + + + self._envs: List[Tuple[str, int]] = list(zip(game_hosts, game_ports)) + self._num_envs: int = len(self._envs) + self._single_env: bool = self._num_envs == 1 + self._role: AgentRole = role + self._logger: logging.Logger = logging.getLogger(self.__class__.__name__) + + # ------------------------------------------------------------------ + # Open one TCP socket per environment + # ------------------------------------------------------------------ + self._sockets: List[Optional[socket.socket]] = [] + for host, port in self._envs: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((host, port)) + self._sockets.append(sock) + self._logger.info(f"Connected to {host}:{port}") + except socket.error as e: + self._logger.warning(f"Failed to connect to {host}:{port}: {e}") + self._sockets.append(None) + + # ------------------------------------------------------------------ + # Done mask & thread pool + # 
------------------------------------------------------------------ + self._done_mask: List[bool] = [s is None for s in self._sockets] + self._executor: ThreadPoolExecutor = ThreadPoolExecutor( + max_workers=max_workers or self._num_envs + ) + self._logger.info( + f"ParallelBaseAgent created with {self._num_envs} environments " + f"({sum(self.connected)} connected)" + ) + + # ====================================================================== + # Lifecycle + # ====================================================================== + + def __del__(self) -> None: + """Close any remaining sockets and shut down the thread pool when the + object is garbage-collected.""" + self.terminate_connection() + + def terminate_connection(self) -> None: + """Close all sockets and shut down the thread pool.""" + for i, sock in enumerate(self._sockets): + if sock is not None: + try: + sock.close() + self._logger.info(f"Socket for env {i} ({self._envs[i]}) closed") + except socket.error as e: + self._logger.error( + f"Error closing socket for env {i} ({self._envs[i]}): {e}" + ) + self._sockets[i] = None + try: + self._executor.shutdown(wait=False) + except Exception: + pass + + # ====================================================================== + # Properties + # ====================================================================== + + @property + def num_envs(self) -> int: + """Number of environments (always equals ``len(game_ports)``).""" + return self._num_envs + + @property + def done_mask(self) -> "bool | List[bool]": + """Current done mask. 
A single ``bool`` in single-env mode, + otherwise a list of bools (one per env).""" + if self._single_env: + return self._done_mask[0] + return list(self._done_mask) + + @property + def all_done(self) -> bool: + """``True`` when every episode has ended.""" + return all(self._done_mask) + + @property + def connected(self) -> List[bool]: + """Per-environment connection status.""" + return [s is not None for s in self._sockets] + + @property + def role(self) -> AgentRole: + return self._role + + @property + def logger(self) -> logging.Logger: + """Returns the logger instance for this agent.""" + return self._logger + + # ====================================================================== + # Private: per-environment communication + # ====================================================================== + + @staticmethod + def _send_data(sock: socket.socket, msg: str, logger: logging.Logger) -> None: + """Send a JSON-encoded message over *sock*.""" + logger.debug(f"Sending: {msg}") + sock.sendall(msg.encode()) + + @staticmethod + def _receive_data( + sock: socket.socket, logger: logging.Logger + ) -> Tuple[GameStatus, Dict[str, Any], Optional[str]]: + """Block until a full response is received from *sock*.""" + data = b"" + while True: + chunk = sock.recv(ProtocolConfig.BUFFER_SIZE) + if not chunk: + break + data += chunk + if ProtocolConfig.END_OF_MESSAGE in data: + break + if ProtocolConfig.END_OF_MESSAGE not in data: + raise ConnectionError("Unfinished connection.") + data = data.replace(ProtocolConfig.END_OF_MESSAGE, b"").decode() + logger.debug(f"Received: {data}") + data_dict = json.loads(data) + status = data_dict.get("status", "") + observation = data_dict.get("observation", {}) + message = data_dict.get("message", None) + return GameStatus.from_string(str(status)), observation, message + + def _communicate_single( + self, env_idx: int, action: Action + ) -> Tuple[GameStatus, Dict[str, Any], Optional[str]]: + """Send *action* to environment *env_idx* and 
return the response.""" + sock = self._sockets[env_idx] + if sock is None: + raise ConnectionError(f"Socket for env {env_idx} is not connected.") + if not isinstance(action, Action): + raise ValueError("Data should be ONLY of type Action") + self._send_data(sock, action.to_json(), self._logger) + return self._receive_data(sock, self._logger) + + # ------------------------------------------------------------------ + # Single-env operations (run inside worker threads) + # ------------------------------------------------------------------ + + def _register_single(self, env_idx: int) -> Optional[Observation]: + """Register on a single environment.""" + try: + status, obs_dict, message = self._communicate_single( + env_idx, + Action( + ActionType.JoinGame, + parameters={ + "agent_info": AgentInfo( + self.__class__.__name__, self._role.value + ) + }, + ), + ) + if status is GameStatus.CREATED: + self._logger.info( + f"Env {env_idx}: registration successful! {message}" + ) + return Observation( + GameState.from_dict(obs_dict["state"]), + obs_dict["reward"], + obs_dict["end"], + message, + ) + else: + self._logger.error( + f"Env {env_idx}: registration failed " + f"(status: {status}, msg: {message})" + ) + return None + except Exception as e: + self._logger.error(f"Env {env_idx}: exception in register: {e}") + return None + + def _make_step_single( + self, env_idx: int, action: Action + ) -> Optional[Observation]: + """Execute a single step on one environment.""" + try: + _, obs_dict, _ = self._communicate_single(env_idx, action) + if obs_dict: + return Observation( + GameState.from_dict(obs_dict["state"]), + obs_dict["reward"], + obs_dict["end"], + obs_dict["info"], + ) + return None + except Exception as e: + self._logger.error(f"Env {env_idx}: exception in make_step: {e}") + return None + + def _request_game_reset_single( + self, + env_idx: int, + request_trajectory: bool, + randomize_topology: bool, + seed: Optional[int], + ) -> Optional[Observation]: + """Reset a 
single environment.""" + try: + status, obs_dict, message = self._communicate_single( + env_idx, + Action( + ActionType.ResetGame, + parameters={ + "request_trajectory": request_trajectory, + "randomize_topology": randomize_topology, + "seed": seed, + }, + ), + ) + if status is GameStatus.RESET_DONE: + self._logger.debug(f"Env {env_idx}: reset successful") + return Observation( + GameState.from_dict(obs_dict["state"]), + obs_dict["reward"], + obs_dict["end"], + message, + ) + else: + self._logger.error( + f"Env {env_idx}: reset failed " + f"(status: {status}, msg: {message})" + ) + return None + except Exception as e: + self._logger.error( + f"Env {env_idx}: exception in request_game_reset: {e}" + ) + return None + + # ====================================================================== + # Private: parallel dispatch helper + # ====================================================================== + + def _run_parallel( + self, + fn, + env_indices: Optional[List[int]] = None, + *, + args_per_env: Optional[Dict[int, tuple]] = None, + ) -> List[Any]: + """Submit *fn(env_idx, ...)* for each index and collect results in + order. + + Args: + fn: Callable whose first argument is the environment index. + env_indices: Which envs to dispatch to. Defaults to all connected + envs. + args_per_env: Optional extra positional args per env index. + + Returns: + List of results, one per ``self._num_envs``. Indices not included + in *env_indices* get ``None``. 
+ """ + if env_indices is None: + env_indices = [i for i in range(self._num_envs) if self._sockets[i] is not None] + if args_per_env is None: + args_per_env = {} + + results: List[Any] = [None] * self._num_envs + + # Fast path: skip the thread pool when there is only one env to call + if len(env_indices) == 1: + i = env_indices[0] + extra = args_per_env.get(i, ()) + try: + results[i] = fn(i, *extra) + except Exception as e: + self._logger.error(f"Env {i}: unhandled exception: {e}") + results[i] = None + return results + + future_to_idx = {} + for i in env_indices: + extra = args_per_env.get(i, ()) + future = self._executor.submit(fn, i, *extra) + future_to_idx[future] = i + + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + try: + results[idx] = future.result() + except Exception as e: + self._logger.error(f"Env {idx}: unhandled exception: {e}") + results[idx] = None + + return results + + # ====================================================================== + # Public API + # ====================================================================== + + def register(self) -> "Observation | None | List[Optional[Observation]]": + """Register in all connected environments in parallel. + + Returns: + In single-env mode: the initial ``Observation`` (or ``None``). + In multi-env mode: list of initial observations, positionally + aligned with ``game_ports``. + """ + results = self._run_parallel(self._register_single) + # Re-initialise done mask: failed envs stay done + self._done_mask = [r is None for r in results] + if self._single_env: + return results[0] + return results + + def make_step( + self, actions: "Action | List[Action]" + ) -> "Tuple[Observation | None, bool] | Tuple[List[Optional[Observation]], List[bool]]": + """Execute one step in every **active** environment in parallel. + + Args: + actions: In single-env mode a single ``Action``; in multi-env + mode a list of ``Action`` objects (one per environment). 
+ Actions at indices where the done mask is ``True`` are + **ignored** (no message is sent to that env). + + Returns: + In single-env mode: ``(observation, done)`` — a single + ``Observation | None`` and a ``bool``. + In multi-env mode: ``(observations, done_mask)`` — lists + positionally aligned with ``game_ports``. + + Raises: + ValueError: If the number of actions doesn't match ``num_envs``. + """ + # Normalise single action to list + if self._single_env and isinstance(actions, Action): + actions = [actions] + if len(actions) != self._num_envs: + raise ValueError( + f"Expected {self._num_envs} actions, got {len(actions)}" + ) + + # Only dispatch to envs that are still active + active = [i for i in range(self._num_envs) if not self._done_mask[i]] + args = {i: (actions[i],) for i in active} + + results = self._run_parallel( + self._make_step_single, active, args_per_env=args + ) + + # Update done mask + for i in active: + obs = results[i] + if obs is None: + # Communication failure → mark as done + self._done_mask[i] = True + elif obs.end: + self._done_mask[i] = True + + if self._single_env: + return results[0]#, self._done_mask[0] + return results#, list(self._done_mask) + + def request_game_reset( + self, + request_trajectory: bool = False, + randomize_topology: bool = False, + seed: Optional[int] = None, + ) -> "Observation | None | List[Optional[Observation]]": + """Reset all environments in parallel. + + Args: + request_trajectory: If True, request episode trajectory from + the server. + randomize_topology: If True, randomize the network topology. + seed: RNG seed. Required when ``randomize_topology`` is True. + + Returns: + In single-env mode: the initial ``Observation`` (or ``None``). + In multi-env mode: list of initial observations, positionally + aligned with ``game_ports``. + """ + if seed is None and randomize_topology: + raise ValueError( + "Topology randomization without seed is not supported." 
+ ) + + # Reset every env that still has a socket (even done ones) + connected = [i for i in range(self._num_envs) if self._sockets[i] is not None] + args = { + i: (request_trajectory, randomize_topology, seed) for i in connected + } + results = self._run_parallel( + self._request_game_reset_single, connected, args_per_env=args + ) + + # Clear the done mask for successfully reset envs + self._done_mask = [results[i] is None for i in range(self._num_envs)] + if self._single_env: + return results[0] + return results From 2c4971e4f4b6a6eeaeefa561b37d93ee5e7a71d6 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 13:44:16 +0200 Subject: [PATCH 2/8] Add example implenetation of random attacker --- examples/agents/random_attacker.py | 235 +++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 examples/agents/random_attacker.py diff --git a/examples/agents/random_attacker.py b/examples/agents/random_attacker.py new file mode 100644 index 00000000..7343a933 --- /dev/null +++ b/examples/agents/random_attacker.py @@ -0,0 +1,235 @@ +# Author: Ondrej Lukas, ondrej.lukas@aic.cvut.cz +# This agent is an example of an attacker agent that plays the NetSecGame +# using random actions. It is intended to be used as a baseline for +# evaluating the performance of other agents. +# It can be used both with single environment, and with multiple simultaneous environments. + +import logging +import argparse +import numpy as np +from os import path, makedirs +import random +from netsecgame import Action, Observation, generate_valid_actions, AgentRole +from netsecgame.agents.parallel_base_agent import ParallelBaseAgent +from netsecgame.game_components import AgentStatus + +class RandomAttackerAgent(ParallelBaseAgent): + """ + An attacker agent that selects actions randomly without learning. + Inherits from ParallelBaseAgent. + """ + + def __init__(self, host, port, role, seed) -> None: + """ + Initialize the RandomAttackerAgent. 
+ + Args: + host (str): Host address to connect to. + port (int | list[int]): Port number(s) to connect to. + role (AgentRole): The role of the agent (e.g., AgentRole.Attacker). + seed (int): Seed for random number generation for the agent's decisions. + """ + super().__init__(host, port, role) + self.rng = random.Random(seed) + + def select_action(self, observation: Observation) -> Action: + """ + Selects a random action from the set of valid actions in the current state. + + Args: + observation (Observation): The current observation including the game state. + + Returns: + Action: The randomly selected action. + """ + valid_actions = sorted(generate_valid_actions(observation.state), key=lambda x: str(x)) + # randomly choose with the seeded rng + action = self.rng.choice(valid_actions) + return action + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", help="Host(s) where the game server is", default="127.0.0.1", + type=str, nargs="+", required=False) + parser.add_argument("--port", help="Port(s) where the game server is", + default=[9000], type=int, nargs='+', required=False) + parser.add_argument("--episodes", help="Sets number of episodes to play", default=100, type=int) + parser.add_argument("--seed", help="Sets random seed for agent's decisions", default=42, type=int) + parser.add_argument("--logdir", help="Folder to store logs", default=path.join(path.dirname(path.abspath(__file__)), "logs")) + args = parser.parse_args() + + if not path.exists(args.logdir): + makedirs(args.logdir) + logging.basicConfig(filename=path.join(args.logdir, "random_agent.log"), filemode='w', format='%(asctime)s %(name)s %(levelname)s %(message)s', datefmt='%H:%M:%S',level=logging.INFO) + + # Create agent + agent = RandomAttackerAgent(args.host, args.port, AgentRole.Attacker, seed=args.seed) + num_envs = agent.num_envs + + # Register in the game + observations = agent.register() + observations = agent.request_game_reset(randomize_topology=False) + + # 
Ensure observations is always a list for uniform handling + if num_envs == 1: + observations = [observations] + + # To keep statistics of each episode + wins = 0 + detected = 0 + max_steps = 0 + num_win_steps = [] + num_detected_steps = [] + num_max_steps_steps = [] + num_detected_returns = [] + num_win_returns = [] + num_max_steps_returns = [] + + # To keep statistics per env + wins_env = [0] * num_envs + detected_env = [0] * num_envs + max_steps_env = [0] * num_envs + num_win_steps_env = [[] for _ in range(num_envs)] + num_detected_steps_env = [[] for _ in range(num_envs)] + num_max_steps_steps_env = [[] for _ in range(num_envs)] + num_detected_returns_env = [[] for _ in range(num_envs)] + num_win_returns_env = [[] for _ in range(num_envs)] + num_max_steps_returns_env = [[] for _ in range(num_envs)] + + for episode in range(1, args.episodes + 1): + agent.logger.info(f'Starting episode {episode}') + print(f'Starting episode {episode}') + + # Per-env tracking for this episode + episodic_returns = [[] for _ in range(num_envs)] + num_steps = [0] * num_envs + + # Play the game until all envs are done + while not agent.all_done: + # Select an action for each active env + actions = [] + done_mask = agent.done_mask if num_envs > 1 else [agent.done_mask] + for i in range(num_envs): + if not done_mask[i]: + obs = observations[i] + episodic_returns[i].append(obs.reward) + num_steps[i] += 1 + actions.append(agent.select_action(obs)) + else: + # Placeholder action for done envs (will be ignored) + actions.append(Action(action_type=agent.select_action(observations[0]).action_type)) + + # Step all envs + new_observations = agent.make_step(actions) + if num_envs == 1: + new_observations = [new_observations] + + # Update observations for active envs + for i in range(num_envs): + if new_observations[i] is not None: + observations[i] = new_observations[i] + + # Episode finished — collect stats per env + for i in range(num_envs): + obs = observations[i] + agent.logger.debug(f'Env 
{i} final observation: {obs}') + current_return = np.sum(episodic_returns[i]) + reward = current_return + + agent.logger.info(f"Episode {episode}, Env {i}: ended with return {current_return}.") + + if obs.info and obs.info.get('end_reason') == AgentStatus.Fail: + detected += 1 + detected_env[i] += 1 + num_detected_steps.append(num_steps[i]) + num_detected_steps_env[i].append(num_steps[i]) + num_detected_returns.append(reward) + num_detected_returns_env[i].append(reward) + elif obs.info and obs.info.get('end_reason') == AgentStatus.Success: + wins += 1 + wins_env[i] += 1 + num_win_steps.append(num_steps[i]) + num_win_steps_env[i].append(num_steps[i]) + num_win_returns.append(reward) + num_win_returns_env[i].append(reward) + elif obs.info and obs.info.get('end_reason') == AgentStatus.TimeoutReached: + max_steps += 1 + max_steps_env[i] += 1 + num_max_steps_steps.append(num_steps[i]) + num_max_steps_steps_env[i].append(num_steps[i]) + num_max_steps_returns.append(reward) + num_max_steps_returns_env[i].append(reward) + + # Reset all envs for next episode + if episode < args.episodes: + if episode % 10 == 0: + observations = agent.request_game_reset(randomize_topology=True, seed=episode) + else: + observations = agent.request_game_reset(randomize_topology=False, seed=episode) + if num_envs == 1: + observations = [observations] + + # Calculate stats for logging (counts are across all envs) + total_episodes_played = episode * num_envs + eval_win_rate = (wins / total_episodes_played) * 100 + eval_detection_rate = (detected / total_episodes_played) * 100 + + all_returns = num_detected_returns + num_win_returns + num_max_steps_returns + eval_average_returns = np.mean(all_returns) if all_returns else 0 + eval_std_returns = np.std(all_returns) if all_returns else 0 + + all_steps = num_win_steps + num_detected_steps + num_max_steps_steps + eval_average_episode_steps = np.mean(all_steps) if all_steps else 0 + eval_std_episode_steps = np.std(all_steps) if all_steps else 0 + + # 
Calculate stats per env for logging + eval_win_rate_env = [(wins_env[i] / episode) * 100 for i in range(num_envs)] + eval_detection_rate_env = [(detected_env[i] / episode) * 100 for i in range(num_envs)] + + eval_average_returns_env = [] + eval_std_returns_env = [] + eval_average_episode_steps_env = [] + eval_std_episode_steps_env = [] + + for i in range(num_envs): + all_returns_i = num_detected_returns_env[i] + num_win_returns_env[i] + num_max_steps_returns_env[i] + eval_average_returns_env.append(np.mean(all_returns_i) if all_returns_i else 0) + eval_std_returns_env.append(np.std(all_returns_i) if all_returns_i else 0) + + all_steps_i = num_win_steps_env[i] + num_detected_steps_env[i] + num_max_steps_steps_env[i] + eval_average_episode_steps_env.append(np.mean(all_steps_i) if all_steps_i else 0) + eval_std_episode_steps_env.append(np.std(all_steps_i) if all_steps_i else 0) + + # Format table text + header = f"| {'Scope':<8} | {'Episodes':<8} | {'Wins':<6} | {'Dets':<6} | {'MaxStp':<6} | {'Win%':<7} | {'Det%':<7} | {'Avg Returns':<20} | {'Avg Steps':<20} |" + separator = "-" * len(header) + + table_lines = [ + f"\nFinal results for {args.episodes} episodes x {num_envs} envs ({total_episodes_played} total):", + separator, + header, + separator + ] + + global_returns_str = f"{eval_average_returns:.2f} +/- {eval_std_returns:.2f}" + global_steps_str = f"{eval_average_episode_steps:.2f} +/- {eval_std_episode_steps:.2f}" + global_row = f"| {'Global':<8} | {total_episodes_played:<8} | {wins:<6} | {detected:<6} | {max_steps:<6} | {eval_win_rate:>6.2f}% | {eval_detection_rate:>6.2f}% | {global_returns_str:<20} | {global_steps_str:<20} |" + table_lines.append(global_row) + + for i in range(num_envs): + env_returns_str = f"{eval_average_returns_env[i]:.2f} +/- {eval_std_returns_env[i]:.2f}" + env_steps_str = f"{eval_average_episode_steps_env[i]:.2f} +/- {eval_std_episode_steps_env[i]:.2f}" + env_row = f"| {f'Env {i}':<8} | {args.episodes:<8} | {wins_env[i]:<6} | 
{detected_env[i]:<6} | {max_steps_env[i]:<6} | {eval_win_rate_env[i]:>6.2f}% | {eval_detection_rate_env[i]:>6.2f}% | {env_returns_str:<20} | {env_steps_str:<20} |" + table_lines.append(env_row) + + table_lines.append(separator) + table_text = "\n".join(table_lines) + + agent.logger.info(table_text) + print(table_text) + agent._logger.info("Terminating interaction") + # stop gracefully + agent.terminate_connection() + +if __name__ == '__main__': + main() \ No newline at end of file From 13668c55f8829b132e8451eeb1fafc233d4a181f Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 14:12:53 +0200 Subject: [PATCH 3/8] verify that the attributes were created --- netsecgame/agents/parallel_base_agent.py | 30 +++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/netsecgame/agents/parallel_base_agent.py b/netsecgame/agents/parallel_base_agent.py index 484bfcf8..ed5b3966 100644 --- a/netsecgame/agents/parallel_base_agent.py +++ b/netsecgame/agents/parallel_base_agent.py @@ -104,20 +104,22 @@ def __del__(self) -> None: def terminate_connection(self) -> None: """Close all sockets and shut down the thread pool.""" - for i, sock in enumerate(self._sockets): - if sock is not None: - try: - sock.close() - self._logger.info(f"Socket for env {i} ({self._envs[i]}) closed") - except socket.error as e: - self._logger.error( - f"Error closing socket for env {i} ({self._envs[i]}): {e}" - ) - self._sockets[i] = None - try: - self._executor.shutdown(wait=False) - except Exception: - pass + if hasattr(self, '_sockets'): + for i, sock in enumerate(self._sockets): + if sock is not None: + try: + sock.close() + self._logger.info(f"Socket for env {i} ({self._envs[i]}) closed") + except socket.error as e: + self._logger.error( + f"Error closing socket for env {i} ({self._envs[i]}): {e}" + ) + self._sockets[i] = None + if hasattr(self, '_executor'): + try: + self._executor.shutdown(wait=False) + except Exception: + pass # 
====================================================================== # Properties From fe3712c01fc13384aea348bdc234b87e0f467b15 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 14:13:14 +0200 Subject: [PATCH 4/8] Add tests for parallel base_agent --- tests/agents/test_parallel_base_agent.py | 199 +++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 tests/agents/test_parallel_base_agent.py diff --git a/tests/agents/test_parallel_base_agent.py b/tests/agents/test_parallel_base_agent.py new file mode 100644 index 00000000..232d96db --- /dev/null +++ b/tests/agents/test_parallel_base_agent.py @@ -0,0 +1,199 @@ +import pytest +from unittest.mock import patch, MagicMock +import socket +import json + +from netsecgame.agents.parallel_base_agent import ParallelBaseAgent +from netsecgame.game_components import ( + Action, Observation, ActionType, GameStatus, AgentRole +) + +# Helper for empty valid state +VALID_STATE_DICT = { + "known_networks": [], + "known_hosts": [], + "controlled_hosts": [], + "known_services": {}, + "known_data": {} +} + + +@pytest.fixture +def mock_socket(): + with patch('netsecgame.agents.parallel_base_agent.socket.socket') as mock_sock: + yield mock_sock + + +class TestParallelBaseAgent: + + def test_single_env_init(self, mock_socket): + agent = ParallelBaseAgent("127.0.0.1", 9000, AgentRole.Attacker) + assert agent.num_envs == 1 + assert agent._single_env is True + assert len(agent._sockets) == 1 + assert agent.connected == [True] + + def test_broadcast_host_init(self, mock_socket): + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001, 9002], AgentRole.Attacker) + assert agent.num_envs == 3 + assert agent._single_env is False + assert len(agent._sockets) == 3 + assert agent.connected == [True, True, True] + + def test_multi_host_init(self, mock_socket): + agent = ParallelBaseAgent(["127.0.0.1", "127.0.0.2"], [9000, 9001], AgentRole.Attacker) + assert agent.num_envs == 2 + assert agent._single_env is False + 
+ def test_mismatched_lengths(self, mock_socket): + with pytest.raises(ValueError, match="must be 1 .* or match game_ports length"): + ParallelBaseAgent(["127.0.0.1", "127.0.0.2"], [9000, 9001, 9002], AgentRole.Attacker) + + def test_connection_failures(self, mock_socket): + mock_instance = MagicMock() + def connect_side_effect(addr): + if addr[1] == 9001: + raise socket.error("Mocked connection error") + mock_instance.connect.side_effect = connect_side_effect + mock_socket.return_value = mock_instance + + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001, 9002], AgentRole.Attacker) + assert agent.num_envs == 3 + assert agent.connected == [True, False, True] + assert agent._sockets[1] is None + assert agent.done_mask == [False, True, False] + + def test_terminate_connection(self, mock_socket): + mock_socket.side_effect = lambda *args, **kwargs: MagicMock() + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001], AgentRole.Attacker) + sockets = list(agent._sockets) + agent.terminate_connection() + + for sock in sockets: + sock.close.assert_called_once() + assert agent._sockets == [None, None] + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_register_single_env(self, mock_comm, mock_socket): + mock_comm.return_value = ( + GameStatus.CREATED, + {"state": VALID_STATE_DICT, "reward": 0, "end": False}, + "Registered" + ) + agent = ParallelBaseAgent("127.0.0.1", 9000, AgentRole.Attacker) + obs = agent.register() + + assert isinstance(obs, Observation) + assert agent.done_mask is False + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_register_multi_env(self, mock_comm, mock_socket): + mock_comm.side_effect = [ + (GameStatus.CREATED, {"state": VALID_STATE_DICT, "reward": 0, "end": False}, "Registered"), + (GameStatus.CREATED, {"state": VALID_STATE_DICT, "reward": 1, "end": False}, "Registered") + ] + + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001], 
AgentRole.Attacker) + obs_list = agent.register() + + assert len(obs_list) == 2 + assert isinstance(obs_list[0], Observation) + assert isinstance(obs_list[1], Observation) + assert agent.done_mask == [False, False] + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_register_partial_failure(self, mock_comm, mock_socket): + mock_comm.side_effect = [ + (GameStatus.CREATED, {"state": VALID_STATE_DICT, "reward": 0, "end": False}, "Registered"), + (GameStatus.BAD_REQUEST, {}, "Failed") + ] + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001], AgentRole.Attacker) + obs_list = agent.register() + + assert obs_list[0] is not None + assert obs_list[1] is None + assert agent.done_mask == [False, True] + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_make_step_single_env(self, mock_comm, mock_socket): + mock_comm.return_value = ( + GameStatus.OK, + {"state": VALID_STATE_DICT, "reward": 1, "end": False, "info": {}}, + "Step 1" + ) + agent = ParallelBaseAgent("127.0.0.1", 9000, AgentRole.Attacker) + agent._done_mask = [False] + obs = agent.make_step(Action(ActionType.ScanNetwork)) + + assert isinstance(obs, Observation) + assert agent.done_mask is False + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_make_step_multi_env(self, mock_comm, mock_socket): + mock_comm.side_effect = [ + (GameStatus.OK, {"state": VALID_STATE_DICT, "reward": 0, "end": False, "info": {}}, "Step 1"), + (GameStatus.OK, {"state": VALID_STATE_DICT, "reward": 1, "end": True, "info": {}}, "Step 2") + ] + + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001], AgentRole.Attacker) + agent._done_mask = [False, False] + actions = [Action(ActionType.ScanNetwork), Action(ActionType.ScanNetwork)] + + obs_list = agent.make_step(actions) + + assert len(obs_list) == 2 + assert obs_list[0].end is False + assert obs_list[1].end is True + + assert agent.done_mask == 
[False, True] + assert agent.all_done is False + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_make_step_skips_done(self, mock_comm, mock_socket): + mock_comm.side_effect = [ + (GameStatus.OK, {"state": VALID_STATE_DICT, "reward": 0, "end": False, "info": {}}, "Step 1"), + ] + + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001], AgentRole.Attacker) + agent._done_mask = [False, True] + + actions = [Action(ActionType.ScanNetwork), Action(ActionType.ScanNetwork)] + obs_list = agent.make_step(actions) + + assert len(obs_list) == 2 + assert isinstance(obs_list[0], Observation) + assert obs_list[1] is None + assert mock_comm.call_count == 1 + + @patch('netsecgame.agents.parallel_base_agent.ParallelBaseAgent._communicate_single') + def test_request_game_reset(self, mock_comm, mock_socket): + mock_comm.side_effect = [ + (GameStatus.RESET_DONE, {"state": VALID_STATE_DICT, "reward": 0, "end": False}, "Reset"), + (GameStatus.RESET_DONE, {"state": VALID_STATE_DICT, "reward": 0, "end": False}, "Reset") + ] + + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001], AgentRole.Attacker) + agent._done_mask = [True, True] + + obs_list = agent.request_game_reset() + + assert len(obs_list) == 2 + assert agent.done_mask == [False, False] + + def test_request_game_reset_topology_no_seed(self, mock_socket): + agent = ParallelBaseAgent("127.0.0.1", 9000, AgentRole.Attacker) + with pytest.raises(ValueError, match="Topology randomization without seed is not supported."): + agent.request_game_reset(randomize_topology=True, seed=None) + + def test_run_parallel_exception_isolation(self, mock_socket): + agent = ParallelBaseAgent("127.0.0.1", [9000, 9001, 9002], AgentRole.Attacker) + agent._done_mask = [False, False, False] + + def faulty_fn(env_idx): + if env_idx == 1: + raise ValueError("Thread crash") + return f"Success {env_idx}" + + results = agent._run_parallel(faulty_fn) + + assert results == ["Success 0", None, "Success 2"] From 
abcbe7d3b46144dcb570d50f005055cb466c4a8b Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 14:16:41 +0200 Subject: [PATCH 5/8] Fix ruff errors --- netsecgame/agents/parallel_base_agent.py | 2 +- tests/agents/test_parallel_base_agent.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/netsecgame/agents/parallel_base_agent.py b/netsecgame/agents/parallel_base_agent.py index ed5b3966..63214a8b 100644 --- a/netsecgame/agents/parallel_base_agent.py +++ b/netsecgame/agents/parallel_base_agent.py @@ -4,7 +4,7 @@ import socket import json from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Optional, Tuple, Dict, Any, List, Union, overload +from typing import Optional, Tuple, Dict, Any, List from netsecgame.game_components import ( Action, GameState, Observation, ActionType, diff --git a/tests/agents/test_parallel_base_agent.py b/tests/agents/test_parallel_base_agent.py index 232d96db..194f1782 100644 --- a/tests/agents/test_parallel_base_agent.py +++ b/tests/agents/test_parallel_base_agent.py @@ -1,8 +1,6 @@ import pytest from unittest.mock import patch, MagicMock import socket -import json - from netsecgame.agents.parallel_base_agent import ParallelBaseAgent from netsecgame.game_components import ( Action, Observation, ActionType, GameStatus, AgentRole From 8e7cc248457bf92ad52de968f6253cace5c7cc66 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 14:21:07 +0200 Subject: [PATCH 6/8] Add typing --- netsecgame/agents/parallel_base_agent.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/netsecgame/agents/parallel_base_agent.py b/netsecgame/agents/parallel_base_agent.py index 63214a8b..7462da8b 100644 --- a/netsecgame/agents/parallel_base_agent.py +++ b/netsecgame/agents/parallel_base_agent.py @@ -4,7 +4,7 @@ import socket import json from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import
Optional, Tuple, Dict, Any, List, Callable from netsecgame.game_components import ( Action, GameState, Observation, ActionType, @@ -22,6 +22,9 @@ class ParallelBaseAgent: ``register()``, ``make_step()``, and ``request_game_reset()`` that operate on lists of actions/observations. + For a concrete example of extending and using this class, + see ``examples/agents/random_attacker.py``. + Args: game_hosts: Host address(es). A single string is broadcast to all ``game_ports``. A list must either have length 1 (broadcast) @@ -304,7 +307,7 @@ def _request_game_reset_single( def _run_parallel( self, - fn, + fn: Callable[[int, ...], Any], env_indices: Optional[List[int]] = None, *, args_per_env: Optional[Dict[int, tuple]] = None, From 40f1f422515be7adf9b5a8dd3f76ff7d48d730ed Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 14:21:22 +0200 Subject: [PATCH 7/8] Add to docs --- docs/architecture.md | 2 ++ docs/base_agent.md | 1 + 2 files changed, 3 insertions(+) diff --git a/docs/architecture.md b/docs/architecture.md index a297f987..dd3e514e 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -126,6 +126,7 @@ After submitting Action `a` to the environment, agents receive an `Observation` ├── netsecgame/ │ ├── agents/ │ │ ├── base_agent.py # Base agent class — API for agent-server communication +│ │ ├── parallel_base_agent.py # Agent class for multi-environment parallel execution │ ├── game/ │ │ ├── scenarios/ │ │ │ ├── one_net.py # Single network scenario @@ -160,5 +161,6 @@ After submitting Action `a` to the environment, agents receive an `Observation` - **[`game_components.py`](game_components.md)** — Library of core objects used throughout the environment. - **[`global_defender.py`](global_defender.md)** — Stochastic omnipresent defender simulating a SIEM system. - **[`base_agent.py`](base_agent.md)** — Base class for all agents. Implements the TCP communication protocol. 
+- **`parallel_base_agent.py`** — Base class for parallel multi-environment agents. See `examples/agents/random_attacker.py` for a reference implementation. The [scenarios](#) define the **topology** of a network (hosts, connections, networks, services, data, firewall rules) while the [task configuration](configuration.md) defines the exact task for agents within a given topology. \ No newline at end of file diff --git a/docs/base_agent.md b/docs/base_agent.md index 7146b137..68fdad5c 100644 --- a/docs/base_agent.md +++ b/docs/base_agent.md @@ -4,3 +4,4 @@ The `BaseAgent` class provides the foundational interface for all agents interac All custom agents should extend this class and implement their decision-making logic by overriding a method like `choose_action` (see [Getting Started](getting_started.md#creating-your-first-agent) for an example). ::: netsecgame.agents.base_agent.BaseAgent +::: netsecgame.agents.parallel_base_agent.ParallelBaseAgent \ No newline at end of file From aff55cc9e81ccf68eff36a175d02fd3d795458e5 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Tue, 5 May 2026 14:26:19 +0200 Subject: [PATCH 8/8] Do not return the done mask --- netsecgame/agents/parallel_base_agent.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/netsecgame/agents/parallel_base_agent.py b/netsecgame/agents/parallel_base_agent.py index 7462da8b..067ca3eb 100644 --- a/netsecgame/agents/parallel_base_agent.py +++ b/netsecgame/agents/parallel_base_agent.py @@ -380,7 +380,7 @@ def register(self) -> "Observation | None | List[Optional[Observation]]": def make_step( self, actions: "Action | List[Action]" - ) -> "Tuple[Observation | None, bool] | Tuple[List[Optional[Observation]], List[bool]]": + ) -> "Observation | None | List[Optional[Observation]]": """Execute one step in every **active** environment in parallel. Args: @@ -390,10 +390,12 @@ def make_step( **ignored** (no message is sent to that env). 
Returns: - In single-env mode: ``(observation, done)`` — a single - ``Observation | None`` and a ``bool``. - In multi-env mode: ``(observations, done_mask)`` — lists - positionally aligned with ``game_ports``. + In single-env mode: a single ``Observation | None``. + In multi-env mode: list of observations positionally aligned + with ``game_ports``. + + *Note: If you need to access the boolean done statuses across + all environments, you can use the `self.done_mask` property.* Raises: ValueError: If the number of actions doesn't match ``num_envs``. @@ -424,8 +426,8 @@ def make_step( self._done_mask[i] = True if self._single_env: - return results[0]#, self._done_mask[0] - return results#, list(self._done_mask) + return results[0] + return results def request_game_reset( self,