backend.server

Backend server for the SI2 - Maze simulation environment. Handles WebSocket connections, map management, and simulation state.

  1"""
  2Backend server for the SI2 - Maze simulation environment.
  3Handles WebSocket connections, map management, and simulation state.
  4"""
  5
  6import asyncio
  7import json
  8import logging
  9import os
 10from typing import Any, Dict, List, Optional, Tuple
 11
 12# Configure standard logging
 13logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 14
 15
 16class SimulationServer:
 17    """
 18    Main simulation engine.
 19    Handles map loading, agent movement, and state broadcasting.
 20    """
 21
 22    def __init__(self) -> None:
 23        """Initialize the SimulationServer with default states."""
 24        self.frontend_ws: Optional[Any] = None
 25        self.agent_ws: Optional[Any] = None
 26        self.maps_dir: str = "maps"
 27        self.current_map: Optional[Dict[str, Any]] = None
 28        self.reachable_tiles: int = 0
 29        self.sim_state: Dict[str, Any] = {}
 30        self.running: bool = False
 31
 32        if not os.path.exists(self.maps_dir):
 33            os.makedirs(self.maps_dir)
 34            logging.info(f"Created maps directory at: {os.path.abspath(self.maps_dir)}")
 35
 36    def calculate_reachable_tiles(self) -> int:
 37        r"""Uses BFS to count floor tiles reachable from the start position.
 38
 39        The number of reachable tiles $R$ is defined as:
 40        $R = |\{ (x, y) \in \text{Grid} \mid \text{path}(\text{start}, (x, y)) \}|$
 41
 42        Returns:
 43            int: Number of reachable floor tiles.
 44        """
 45        if self.current_map is None:
 46            return 0
 47
 48        start_pos: Tuple[int, int] = tuple(self.current_map.get("start", [0, 0]))  # type: ignore
 49        width: int = self.current_map["width"]
 50        height: int = self.current_map["height"]
 51        is_teleport: bool = self.current_map.get("teleport", False)
 52
 53        queue: List[Tuple[int, int]] = [start_pos]
 54        visited: set[Tuple[int, int]] = {start_pos}
 55        reachable_count: int = 0
 56
 57        while queue:
 58            cx, cy = queue.pop(0)
 59            if self.current_map["grid"][cy][cx] != "obstacle":
 60                reachable_count += 1
 61
 62                # Check 4-way neighbors
 63                for dx, dy in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
 64                    nx, ny = cx + dx, cy + dy
 65
 66                    if is_teleport:
 67                        nx %= width
 68                        ny %= height
 69                    elif nx < 0 or nx >= width or ny < 0 or ny >= height:
 70                        continue
 71
 72                    if (nx, ny) not in visited and self.current_map["grid"][ny][nx] != "obstacle":
 73                        visited.add((nx, ny))
 74                        queue.append((nx, ny))
 75
 76        return reachable_count
 77
 78    async def start(self, host: str = "0.0.0.0", port: int = 8765) -> None:
 79        """Start the WebSocket server.
 80
 81        Args:
 82            host (str): Host address to bind to.
 83            port (int): Port number to listen on.
 84        """
 85        import websockets
 86
 87        logging.info(f"Starting websocket server on ws://{host}:{port}")
 88        async with websockets.serve(self.handle_client, host, port):
 89            await asyncio.Future()
 90
 91    async def handle_client(self, websocket: Any) -> None:
 92        """Handle incoming WebSocket connections.
 93
 94        Args:
 95            websocket (Any): The WebSocket connection object.
 96        """
 97        client_type: str = "Unknown"
 98        try:
 99            init_msg = await websocket.recv()
100            try:
101                data = json.loads(init_msg)
102            except json.JSONDecodeError:
103                logging.warning("Received malformed initial message.")
104                return
105
106            client_type = data.get("client", "Unknown")
107
108            if client_type == "frontend":
109                if self.frontend_ws is not None:
110                    logging.warning("Frontend already connected. Rejecting new connection.")
111                    await websocket.send(json.dumps({"type": "error", "message": "Frontend already connected."}))
112                    await websocket.close()
113                    return
114                logging.info("Frontend connected.")
115                self.frontend_ws = websocket
116                await self.send_map_list()
117                await self.frontend_loop(websocket)
118            elif client_type == "agent":
119                if self.agent_ws is not None:
120                    logging.warning("Agent already connected. Rejecting new connection.")
121                    await websocket.send(json.dumps({"type": "error", "message": "Agent already connected."}))
122                    await websocket.close()
123                    return
124                logging.info("Agent connected.")
125                self.agent_ws = websocket
126                if self.running and self.current_map:
127                    await self.send_agent_state()
128                await self.agent_loop(websocket)
129            else:
130                logging.warning(f"Unknown client type attempted connection: {client_type}")
131
132        except Exception as e:
133            logging.error(f"Error handling client {client_type}: {e}")
134        finally:
135            if websocket == self.frontend_ws:
136                self.frontend_ws = None
137                logging.info("Frontend session cleared.")
138            elif websocket == self.agent_ws:
139                self.agent_ws = None
140                logging.info("Agent session cleared.")
141
142    async def frontend_loop(self, websocket: Any) -> None:
143        """Main loop for handling frontend messages.
144
145        Args:
146            websocket (Any): The frontend WebSocket connection.
147        """
148        async for message in websocket:
149            try:
150                data = json.loads(message)
151                action = data.get("action")
152
153                if action == "load_map":
154                    self.load_map(data.get("filename"))
155                    await self.update_frontend()
156                    if self.agent_ws and self.current_map:
157                        await self.agent_ws.send(json.dumps({"type": "reset"}))
158                        await self.send_agent_state()
159                elif action == "save_map":
160                    filename = data.get("filename")
161                    map_data = data.get("map_data")
162                    success, error = self.save_map(filename, map_data)
163
164                    await websocket.send(
165                        json.dumps(
166                            {
167                                "type": "save_response",
168                                "success": success,
169                                "error": error,
170                            }
171                        )
172                    )
173
174                    if success:
175                        await self.send_map_list()
176                elif action == "start_sim":
177                    if self.current_map:
178                        self.reset_sim()
179                        self.running = True
180                        logging.info("Simulation started via frontend.")
181                        await self.update_frontend()
182                        if self.agent_ws:
183                            await self.send_agent_state()
184                    else:
185                        logging.warning("Attempted to start simulation without a map.")
186                elif action == "stop_sim":
187                    self.running = False
188                    logging.info("Simulation stopped via frontend.")
189                    await self.update_frontend()
190                elif action == "reset_sim":
191                    self.reset_sim()
192                    await self.update_frontend()
193                    if self.agent_ws and self.current_map:
194                        await self.agent_ws.send(json.dumps({"type": "reset"}))
195                        await self.send_agent_state()
196            except Exception as e:
197                logging.error(f"Error processing frontend message: {e}")
198
199    async def agent_loop(self, websocket: Any) -> None:
200        """Main loop for handling agent messages.
201
202        Args:
203            websocket (Any): The agent WebSocket connection.
204        """
205        async for message in websocket:
206            if not self.running or not self.current_map:
207                continue
208            try:
209                data = json.loads(message)
210                if data.get("action") == "move":
211                    direction = data.get("direction")
212                    self.process_move(direction)
213                    self.check_objective()
214                    await self.update_frontend()
215                    await self.send_agent_state()
216                elif data.get("action") == "telemetry":
217                    if self.frontend_ws:
218                        await self.frontend_ws.send(json.dumps({"type": "agent_telemetry", "data": data.get("data")}))
219            except Exception as e:
220                logging.error(f"Error processing agent message: {e}")
221
222    def process_move(self, direction: str) -> None:
223        """Process an agent movement request.
224
225        Args:
226            direction (str): Direction to move ('N', 'S', 'E', 'W').
227        """
228        if self.current_map is None:
229            return
230
231        x, y = self.sim_state["agent_pos"]
232        nx, ny = x, y
233
234        if direction == "N":
235            ny -= 1
236        elif direction == "S":
237            ny += 1
238        elif direction == "E":
239            nx += 1
240        elif direction == "W":
241            nx -= 1
242
243        width = self.current_map["width"]
244        height = self.current_map["height"]
245
246        if self.current_map.get("teleport", False):
247            nx = nx % width
248            ny = ny % height
249        else:
250            if nx < 0 or nx >= width or ny < 0 or ny >= height:
251                return
252
253        cell = self.current_map["grid"][ny][nx]
254        if cell == "obstacle":
255            key = f"{nx},{ny}"
256            self.sim_state["hits"][key] = self.sim_state["hits"].get(key, 0) + 1
257        else:
258            self.sim_state["agent_pos"] = [nx, ny]
259            key = f"{nx},{ny}"
260            self.sim_state["visits"][key] = self.sim_state["visits"].get(key, 0) + 1
261
262    def get_valid_actions(self) -> List[str]:
263        """Get a list of valid actions for the agent at its current position.
264
265        Returns:
266            List[str]: List of valid cardinal directions.
267        """
268        if self.current_map is None:
269            return []
270
271        x, y = self.sim_state["agent_pos"]
272        width = self.current_map["width"]
273        height = self.current_map["height"]
274        actions: List[str] = []
275
276        is_teleport: bool = self.current_map.get("teleport", False)
277
278        # Check North
279        ny = (y - 1) % height if is_teleport else y - 1
280        if (ny >= 0 or is_teleport) and self.current_map["grid"][ny][x] != "obstacle":
281            actions.append("N")
282
283        # Check South
284        ny = (y + 1) % height if is_teleport else y + 1
285        if (ny < height or is_teleport) and self.current_map["grid"][ny][x] != "obstacle":
286            actions.append("S")
287
288        # Check East
289        nx = (x + 1) % width if is_teleport else x + 1
290        if (nx < width or is_teleport) and self.current_map["grid"][y][nx] != "obstacle":
291            actions.append("E")
292
293        # Check West
294        nx = (x - 1) % width if is_teleport else x - 1
295        if (nx >= 0 or is_teleport) and self.current_map["grid"][y][nx] != "obstacle":
296            actions.append("W")
297
298        return actions
299
300    def reset_sim(self) -> None:
301        """Resets the map state and heatmaps to their initial conditions."""
302        if self.current_map:
303            start_pos = self.current_map.get("start", [0, 0])
304            self.sim_state = {
305                "agent_pos": start_pos,
306                "visits": {f"{start_pos[0]},{start_pos[1]}": 1},
307                "hits": {},
308            }
309            self.running = False
310            logging.info("Simulation reset to start state.")
311
312    def check_objective(self) -> None:
313        """Checks if the simulation objective has been reached."""
314        if self.current_map is None:
315            return
316
317        if self.current_map["type"] == "maze":
318            if self.sim_state["agent_pos"] == self.current_map.get("target"):
319                self.running = False
320                logging.info("Objective Reached: Maze target found!")
321        elif self.current_map["type"] == "room":
322            if len(self.sim_state["visits"]) >= self.reachable_tiles:
323                self.running = False
324                logging.info("Objective Reached: Room fully explored!")
325
326    async def send_agent_state(self) -> None:
327        """Sends the current simulation state to the agent."""
328        if self.agent_ws and self.current_map:
329            payload = {
330                "type": "state",
331                "position": self.sim_state["agent_pos"],
332                "valid_actions": self.get_valid_actions(),
333                "objective_reached": not self.running,
334                "target": self.current_map.get("target") if self.current_map["type"] == "maze" else None,
335                "start": self.current_map.get("start"),
336                "width": self.current_map.get("width"),
337                "height": self.current_map.get("height"),
338            }
339            await self.agent_ws.send(json.dumps(payload))
340
341    async def update_frontend(self) -> None:
342        """Sends the current simulation state to the frontend."""
343        if self.frontend_ws:
344            payload = {
345                "type": "update",
346                "map": self.current_map,
347                "state": self.sim_state,
348                "running": self.running,
349                "agent_connected": self.agent_ws is not None,
350            }
351            await self.frontend_ws.send(json.dumps(payload))
352
353    async def send_map_list(self) -> None:
354        """Sends the list of available maps to the frontend."""
355        if self.frontend_ws:
356            try:
357                maps = sorted([f for f in os.listdir(self.maps_dir) if f.endswith(".json")])
358                await self.frontend_ws.send(json.dumps({"type": "map_list", "maps": maps}))
359            except Exception as e:
360                logging.error(f"Failed to read maps directory: {e}")
361
362    def load_map(self, filename: str) -> None:
363        """Load a map from a JSON file.
364
365        Args:
366            filename (str): Name of the map file to load.
367        """
368        try:
369            filename = os.path.basename(filename)
370            filepath = os.path.join(self.maps_dir, filename)
371            with open(filepath, "r") as f:
372                self.current_map = json.load(f)
373
374            if self.current_map is None:
375                logging.error(f"Failed to load map {filename}: file is empty or invalid.")
376                return
377
378            self.reachable_tiles = self.calculate_reachable_tiles()
379            logging.info(f"Map loaded: {filename}. Reachable floor tiles: {self.reachable_tiles}")
380
381            start_pos = self.current_map.get("start", [0, 0])
382            self.sim_state = {
383                "agent_pos": start_pos,
384                "visits": {f"{start_pos[0]},{start_pos[1]}": 1},
385                "hits": {},
386            }
387            self.running = False
388            logging.info(f"Successfully loaded map: {filename}")
389        except Exception as e:
390            logging.error(f"Failed to load map {filename}: {e}")
391
392    def validate_map_data(self, data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
393        """
394        Basic schema validation for incoming map data.
395
396        Args:
397            data (Dict[str, Any]): Map data to validate.
398
399        Returns:
400            Tuple[bool, Optional[str]]: (is_valid, error_message)
401        """
402        try:
403            required = ["width", "height", "type", "grid", "start"]
404            if not all(k in data for k in required):
405                return (
406                    False,
407                    "Missing required fields (width, height, type, grid, start).",
408                )
409
410            if not isinstance(data["grid"], list) or len(data["grid"]) != data["height"]:
411                return False, f"Grid height mismatch. Expected {data['height']} rows."
412
413            for row in data["grid"]:
414                if not isinstance(row, list) or len(row) != data["width"]:
415                    return (
416                        False,
417                        f"Grid width mismatch. Expected {data['width']} columns.",
418                    )
419
420            return True, None
421        except Exception as e:
422            return False, f"Validation error: {str(e)}"
423
424    def save_map(self, filename: str, map_data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
425        """
426        Save map data to a JSON file.
427
428        Args:
429            filename (str): Name of the file to save.
430            map_data (Dict[str, Any]): Map data to save.
431
432        Returns:
433            Tuple[bool, Optional[str]]: (success, error_message)
434        """
435        try:
436            filename = os.path.basename(filename)
437            if not filename.endswith(".json"):
438                filename += ".json"
439
440            is_valid, error_msg = self.validate_map_data(map_data)
441            if not is_valid:
442                logging.warning(f"Rejected invalid map save request: {error_msg}")
443                return False, error_msg
444
445            filepath = os.path.join(self.maps_dir, filename)
446            with open(filepath, "w") as f:
447                json.dump(map_data, f)
448
449            logging.info(f"Successfully saved map: {filepath}")
450            return True, None
451        except PermissionError:
452            err = "Permission denied when saving. Check Docker volume permissions."
453            logging.error(err)
454            return False, err
455        except Exception as e:
456            err = f"Unexpected error saving map: {str(e)}"
457            logging.error(err)
458            return False, err
459
460
if __name__ == "__main__":
    # The PORT environment variable overrides the default WebSocket port.
    raw_port = os.environ.get("PORT", "8765")
    try:
        port = int(raw_port)
    except ValueError:
        raise SystemExit(f"Invalid PORT value: {raw_port!r}") from None
    server = SimulationServer()
    asyncio.run(server.start(port=port))
class SimulationServer:
 17class SimulationServer:
 18    """
 19    Main simulation engine.
 20    Handles map loading, agent movement, and state broadcasting.
 21    """
 22
 23    def __init__(self) -> None:
 24        """Initialize the SimulationServer with default states."""
 25        self.frontend_ws: Optional[Any] = None
 26        self.agent_ws: Optional[Any] = None
 27        self.maps_dir: str = "maps"
 28        self.current_map: Optional[Dict[str, Any]] = None
 29        self.reachable_tiles: int = 0
 30        self.sim_state: Dict[str, Any] = {}
 31        self.running: bool = False
 32
 33        if not os.path.exists(self.maps_dir):
 34            os.makedirs(self.maps_dir)
 35            logging.info(f"Created maps directory at: {os.path.abspath(self.maps_dir)}")
 36
 37    def calculate_reachable_tiles(self) -> int:
 38        r"""Uses BFS to count floor tiles reachable from the start position.
 39
 40        The number of reachable tiles $R$ is defined as:
 41        $R = |\{ (x, y) \in \text{Grid} \mid \text{path}(\text{start}, (x, y)) \}|$
 42
 43        Returns:
 44            int: Number of reachable floor tiles.
 45        """
 46        if self.current_map is None:
 47            return 0
 48
 49        start_pos: Tuple[int, int] = tuple(self.current_map.get("start", [0, 0]))  # type: ignore
 50        width: int = self.current_map["width"]
 51        height: int = self.current_map["height"]
 52        is_teleport: bool = self.current_map.get("teleport", False)
 53
 54        queue: List[Tuple[int, int]] = [start_pos]
 55        visited: set[Tuple[int, int]] = {start_pos}
 56        reachable_count: int = 0
 57
 58        while queue:
 59            cx, cy = queue.pop(0)
 60            if self.current_map["grid"][cy][cx] != "obstacle":
 61                reachable_count += 1
 62
 63                # Check 4-way neighbors
 64                for dx, dy in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
 65                    nx, ny = cx + dx, cy + dy
 66
 67                    if is_teleport:
 68                        nx %= width
 69                        ny %= height
 70                    elif nx < 0 or nx >= width or ny < 0 or ny >= height:
 71                        continue
 72
 73                    if (nx, ny) not in visited and self.current_map["grid"][ny][nx] != "obstacle":
 74                        visited.add((nx, ny))
 75                        queue.append((nx, ny))
 76
 77        return reachable_count
 78
 79    async def start(self, host: str = "0.0.0.0", port: int = 8765) -> None:
 80        """Start the WebSocket server.
 81
 82        Args:
 83            host (str): Host address to bind to.
 84            port (int): Port number to listen on.
 85        """
 86        import websockets
 87
 88        logging.info(f"Starting websocket server on ws://{host}:{port}")
 89        async with websockets.serve(self.handle_client, host, port):
 90            await asyncio.Future()
 91
 92    async def handle_client(self, websocket: Any) -> None:
 93        """Handle incoming WebSocket connections.
 94
 95        Args:
 96            websocket (Any): The WebSocket connection object.
 97        """
 98        client_type: str = "Unknown"
 99        try:
100            init_msg = await websocket.recv()
101            try:
102                data = json.loads(init_msg)
103            except json.JSONDecodeError:
104                logging.warning("Received malformed initial message.")
105                return
106
107            client_type = data.get("client", "Unknown")
108
109            if client_type == "frontend":
110                if self.frontend_ws is not None:
111                    logging.warning("Frontend already connected. Rejecting new connection.")
112                    await websocket.send(json.dumps({"type": "error", "message": "Frontend already connected."}))
113                    await websocket.close()
114                    return
115                logging.info("Frontend connected.")
116                self.frontend_ws = websocket
117                await self.send_map_list()
118                await self.frontend_loop(websocket)
119            elif client_type == "agent":
120                if self.agent_ws is not None:
121                    logging.warning("Agent already connected. Rejecting new connection.")
122                    await websocket.send(json.dumps({"type": "error", "message": "Agent already connected."}))
123                    await websocket.close()
124                    return
125                logging.info("Agent connected.")
126                self.agent_ws = websocket
127                if self.running and self.current_map:
128                    await self.send_agent_state()
129                await self.agent_loop(websocket)
130            else:
131                logging.warning(f"Unknown client type attempted connection: {client_type}")
132
133        except Exception as e:
134            logging.error(f"Error handling client {client_type}: {e}")
135        finally:
136            if websocket == self.frontend_ws:
137                self.frontend_ws = None
138                logging.info("Frontend session cleared.")
139            elif websocket == self.agent_ws:
140                self.agent_ws = None
141                logging.info("Agent session cleared.")
142
143    async def frontend_loop(self, websocket: Any) -> None:
144        """Main loop for handling frontend messages.
145
146        Args:
147            websocket (Any): The frontend WebSocket connection.
148        """
149        async for message in websocket:
150            try:
151                data = json.loads(message)
152                action = data.get("action")
153
154                if action == "load_map":
155                    self.load_map(data.get("filename"))
156                    await self.update_frontend()
157                    if self.agent_ws and self.current_map:
158                        await self.agent_ws.send(json.dumps({"type": "reset"}))
159                        await self.send_agent_state()
160                elif action == "save_map":
161                    filename = data.get("filename")
162                    map_data = data.get("map_data")
163                    success, error = self.save_map(filename, map_data)
164
165                    await websocket.send(
166                        json.dumps(
167                            {
168                                "type": "save_response",
169                                "success": success,
170                                "error": error,
171                            }
172                        )
173                    )
174
175                    if success:
176                        await self.send_map_list()
177                elif action == "start_sim":
178                    if self.current_map:
179                        self.reset_sim()
180                        self.running = True
181                        logging.info("Simulation started via frontend.")
182                        await self.update_frontend()
183                        if self.agent_ws:
184                            await self.send_agent_state()
185                    else:
186                        logging.warning("Attempted to start simulation without a map.")
187                elif action == "stop_sim":
188                    self.running = False
189                    logging.info("Simulation stopped via frontend.")
190                    await self.update_frontend()
191                elif action == "reset_sim":
192                    self.reset_sim()
193                    await self.update_frontend()
194                    if self.agent_ws and self.current_map:
195                        await self.agent_ws.send(json.dumps({"type": "reset"}))
196                        await self.send_agent_state()
197            except Exception as e:
198                logging.error(f"Error processing frontend message: {e}")
199
200    async def agent_loop(self, websocket: Any) -> None:
201        """Main loop for handling agent messages.
202
203        Args:
204            websocket (Any): The agent WebSocket connection.
205        """
206        async for message in websocket:
207            if not self.running or not self.current_map:
208                continue
209            try:
210                data = json.loads(message)
211                if data.get("action") == "move":
212                    direction = data.get("direction")
213                    self.process_move(direction)
214                    self.check_objective()
215                    await self.update_frontend()
216                    await self.send_agent_state()
217                elif data.get("action") == "telemetry":
218                    if self.frontend_ws:
219                        await self.frontend_ws.send(json.dumps({"type": "agent_telemetry", "data": data.get("data")}))
220            except Exception as e:
221                logging.error(f"Error processing agent message: {e}")
222
223    def process_move(self, direction: str) -> None:
224        """Process an agent movement request.
225
226        Args:
227            direction (str): Direction to move ('N', 'S', 'E', 'W').
228        """
229        if self.current_map is None:
230            return
231
232        x, y = self.sim_state["agent_pos"]
233        nx, ny = x, y
234
235        if direction == "N":
236            ny -= 1
237        elif direction == "S":
238            ny += 1
239        elif direction == "E":
240            nx += 1
241        elif direction == "W":
242            nx -= 1
243
244        width = self.current_map["width"]
245        height = self.current_map["height"]
246
247        if self.current_map.get("teleport", False):
248            nx = nx % width
249            ny = ny % height
250        else:
251            if nx < 0 or nx >= width or ny < 0 or ny >= height:
252                return
253
254        cell = self.current_map["grid"][ny][nx]
255        if cell == "obstacle":
256            key = f"{nx},{ny}"
257            self.sim_state["hits"][key] = self.sim_state["hits"].get(key, 0) + 1
258        else:
259            self.sim_state["agent_pos"] = [nx, ny]
260            key = f"{nx},{ny}"
261            self.sim_state["visits"][key] = self.sim_state["visits"].get(key, 0) + 1
262
263    def get_valid_actions(self) -> List[str]:
264        """Get a list of valid actions for the agent at its current position.
265
266        Returns:
267            List[str]: List of valid cardinal directions.
268        """
269        if self.current_map is None:
270            return []
271
272        x, y = self.sim_state["agent_pos"]
273        width = self.current_map["width"]
274        height = self.current_map["height"]
275        actions: List[str] = []
276
277        is_teleport: bool = self.current_map.get("teleport", False)
278
279        # Check North
280        ny = (y - 1) % height if is_teleport else y - 1
281        if (ny >= 0 or is_teleport) and self.current_map["grid"][ny][x] != "obstacle":
282            actions.append("N")
283
284        # Check South
285        ny = (y + 1) % height if is_teleport else y + 1
286        if (ny < height or is_teleport) and self.current_map["grid"][ny][x] != "obstacle":
287            actions.append("S")
288
289        # Check East
290        nx = (x + 1) % width if is_teleport else x + 1
291        if (nx < width or is_teleport) and self.current_map["grid"][y][nx] != "obstacle":
292            actions.append("E")
293
294        # Check West
295        nx = (x - 1) % width if is_teleport else x - 1
296        if (nx >= 0 or is_teleport) and self.current_map["grid"][y][nx] != "obstacle":
297            actions.append("W")
298
299        return actions
300
301    def reset_sim(self) -> None:
302        """Resets the map state and heatmaps to their initial conditions."""
303        if self.current_map:
304            start_pos = self.current_map.get("start", [0, 0])
305            self.sim_state = {
306                "agent_pos": start_pos,
307                "visits": {f"{start_pos[0]},{start_pos[1]}": 1},
308                "hits": {},
309            }
310            self.running = False
311            logging.info("Simulation reset to start state.")
312
def check_objective(self) -> None:
    """Stop the simulation once the loaded map's goal has been met.

    * ``maze`` maps: the agent stands on the map's ``target`` tile.
    * ``room`` maps: the number of distinct visited tiles has reached
      ``self.reachable_tiles``.

    Other map types never complete. Only ``self.running`` is modified.
    """
    if self.current_map is None:
        return

    map_kind = self.current_map["type"]
    if map_kind == "maze":
        reached = self.sim_state["agent_pos"] == self.current_map.get("target")
        message = "Objective Reached: Maze target found!"
    elif map_kind == "room":
        # Each visited tile appears exactly once as a key of the heatmap.
        reached = len(self.sim_state["visits"]) >= self.reachable_tiles
        message = "Objective Reached: Room fully explored!"
    else:
        return

    if reached:
        self.running = False
        logging.info(message)
326
async def send_agent_state(self) -> None:
    """Send the agent its current observation of the simulation.

    Sends nothing unless an agent is connected and a map is loaded. The
    ``target`` field is only filled in for ``maze`` maps.

    Note: ``objective_reached`` is computed as ``not self.running``. It is
    therefore also ``True`` whenever the simulation is stopped, e.g. right
    after a map load or reset, not only when the goal was actually met.
    """
    if not (self.agent_ws and self.current_map):
        return

    level = self.current_map
    observation = {
        "type": "state",
        "position": self.sim_state["agent_pos"],
        "valid_actions": self.get_valid_actions(),
        "objective_reached": not self.running,
        "target": level.get("target") if level["type"] == "maze" else None,
        "start": level.get("start"),
        "width": level.get("width"),
        "height": level.get("height"),
    }
    await self.agent_ws.send(json.dumps(observation))
341
async def update_frontend(self) -> None:
    """Push a full simulation snapshot to the frontend, if one is connected.

    The snapshot carries the loaded map (or ``None``), the run state, the
    running flag and whether an agent connection is currently registered.
    """
    ws = self.frontend_ws
    if not ws:
        return
    snapshot = dict(
        type="update",
        map=self.current_map,
        state=self.sim_state,
        running=self.running,
        agent_connected=self.agent_ws is not None,
    )
    await ws.send(json.dumps(snapshot))
353
async def send_map_list(self) -> None:
    """Send the sorted names of the ``*.json`` files in ``maps_dir`` to the frontend.

    No-op without a connected frontend. Any failure, whether listing the
    directory or sending the message, is logged rather than raised.
    """
    if not self.frontend_ws:
        return
    try:
        names = sorted(entry for entry in os.listdir(self.maps_dir) if entry.endswith(".json"))
        await self.frontend_ws.send(json.dumps({"type": "map_list", "maps": names}))
    except Exception as e:
        logging.error(f"Failed to read maps directory: {e}")
362
def load_map(self, filename: str) -> None:
    """Load a map from a JSON file in ``maps_dir`` and reset the run state.

    Only the base name of *filename* is used, so callers cannot read files
    outside ``maps_dir``. On any failure the error is logged and the
    previously loaded map, reachable-tile count and run state are left
    untouched. Failures include a missing file, invalid JSON, a top-level
    value that is not an object, or a map the reachability computation
    cannot process.

    Args:
        filename (str): Name of the map file to load.
    """
    try:
        filename = os.path.basename(filename)
        filepath = os.path.join(self.maps_dir, filename)
        with open(filepath, "r", encoding="utf-8") as f:
            candidate = json.load(f)

        # `null`, lists, numbers, ... would otherwise be installed as the map
        # and break every later `.get(...)` / `[...]` lookup on it.
        if not isinstance(candidate, dict):
            logging.error(f"Failed to load map {filename}: file is empty or invalid.")
            return

        # calculate_reachable_tiles reads self.current_map, so install the
        # candidate tentatively and roll back if it cannot be processed.
        previous_map = self.current_map
        self.current_map = candidate
        try:
            reachable = self.calculate_reachable_tiles()
        except Exception:
            self.current_map = previous_map
            raise

        self.reachable_tiles = reachable
        logging.info(f"Map loaded: {filename}. Reachable floor tiles: {self.reachable_tiles}")

        start_pos = candidate.get("start", [0, 0])
        self.sim_state = {
            "agent_pos": start_pos,
            "visits": {f"{start_pos[0]},{start_pos[1]}": 1},
            "hits": {},
        }
        self.running = False
        logging.info(f"Successfully loaded map: {filename}")
    except Exception as e:
        logging.error(f"Failed to load map {filename}: {e}")
392
def validate_map_data(self, data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Basic schema validation for incoming map data.

    Checks that *data* is a JSON object with the required fields, that
    ``width``/``height`` are positive integers matching the ``grid`` shape,
    and that ``start`` is an ``[x, y]`` pair inside the grid.

    Args:
        data (Dict[str, Any]): Map data to validate.

    Returns:
        Tuple[bool, Optional[str]]: ``(True, None)`` when valid, otherwise
        ``(False, error_message)``. Never raises.
    """

    def _is_int(value: Any) -> bool:
        # bool is a subclass of int; True/False are not valid sizes/coordinates.
        return isinstance(value, int) and not isinstance(value, bool)

    try:
        if not isinstance(data, dict):
            return False, "Map data must be a JSON object."

        required = ["width", "height", "type", "grid", "start"]
        if not all(k in data for k in required):
            return (
                False,
                "Missing required fields (width, height, type, grid, start).",
            )

        width, height = data["width"], data["height"]
        if not (_is_int(width) and _is_int(height) and width > 0 and height > 0):
            return False, "Width and height must be positive integers."

        if not isinstance(data["grid"], list) or len(data["grid"]) != height:
            return False, f"Grid height mismatch. Expected {data['height']} rows."

        for row in data["grid"]:
            if not isinstance(row, list) or len(row) != width:
                return (
                    False,
                    f"Grid width mismatch. Expected {data['width']} columns.",
                )

        # An out-of-range start would make the reachability BFS and move
        # handling index outside the grid once this map is loaded.
        start = data["start"]
        if not (
            isinstance(start, (list, tuple))
            and len(start) == 2
            and _is_int(start[0])
            and _is_int(start[1])
            and 0 <= start[0] < width
            and 0 <= start[1] < height
        ):
            return False, "Start position must be [x, y] inside the grid."

        return True, None
    except Exception as e:
        return False, f"Validation error: {str(e)}"
424
def save_map(self, filename: str, map_data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Save map data to a JSON file in ``maps_dir``.

    Only the base name of *filename* is used (no directory traversal), and a
    ``.json`` suffix is appended when missing. The data is checked with
    ``self.validate_map_data`` before anything is written.

    Args:
        filename (str): Name of the file to save.
        map_data (Dict[str, Any]): Map data to save.

    Returns:
        Tuple[bool, Optional[str]]: (success, error_message); error_message
        is None on success.
    """
    try:
        if not isinstance(filename, str):
            err = "A map filename is required."
            logging.warning(f"Rejected map save request: {err}")
            return False, err

        filename = os.path.basename(filename)
        if not filename.endswith(".json"):
            filename += ".json"
        # An empty name (e.g. "" or "dir/") would otherwise produce a hidden
        # file literally named ".json".
        if filename == ".json":
            err = "A map filename is required."
            logging.warning(f"Rejected map save request: {err}")
            return False, err

        is_valid, error_msg = self.validate_map_data(map_data)
        if not is_valid:
            logging.warning(f"Rejected invalid map save request: {error_msg}")
            return False, error_msg

        filepath = os.path.join(self.maps_dir, filename)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(map_data, f)

        logging.info(f"Successfully saved map: {filepath}")
        return True, None
    except PermissionError:
        err = "Permission denied when saving. Check Docker volume permissions."
        logging.error(err)
        return False, err
    except Exception as e:
        err = f"Unexpected error saving map: {str(e)}"
        logging.error(err)
        return False, err

Main simulation engine. Handles map loading, agent movement, and state broadcasting.

SimulationServer()
def __init__(self) -> None:
    """Initialize the SimulationServer with default (empty) state.

    Also creates the ``maps`` directory, relative to the current working
    directory, if it does not already exist.
    """
    # Active WebSocket connections; handle_client admits at most one of each.
    self.frontend_ws: Optional[Any] = None
    self.agent_ws: Optional[Any] = None
    # Directory (relative to the CWD) holding the map JSON files.
    self.maps_dir: str = "maps"
    # Currently loaded map; None until a map has been loaded.
    self.current_map: Optional[Dict[str, Any]] = None
    # Floor tiles reachable from the start; recomputed on every map load.
    self.reachable_tiles: int = 0
    # Per-run state: agent position plus visit / obstacle-hit counters.
    self.sim_state: Dict[str, Any] = {}
    self.running: bool = False

    # EAFP instead of exists() followed by makedirs(): avoids the
    # check-then-create race if the directory appears between the two calls.
    try:
        os.makedirs(self.maps_dir)
    except FileExistsError:
        pass
    else:
        logging.info(f"Created maps directory at: {os.path.abspath(self.maps_dir)}")

Initialize the SimulationServer with default states.

frontend_ws: Optional[Any]
agent_ws: Optional[Any]
maps_dir: str
current_map: Optional[Dict[str, Any]]
reachable_tiles: int
sim_state: Dict[str, Any]
running: bool
def calculate_reachable_tiles(self) -> int:
def calculate_reachable_tiles(self) -> int:
    r"""Use BFS to count floor tiles reachable from the start position.

    The number of reachable tiles $R$ is defined as:
    $R = |\{ (x, y) \in \text{Grid} \mid \text{path}(\text{start}, (x, y)) \}|$

    Movement is 4-connected (N/S/E/W) and never enters "obstacle" tiles.
    When the map's ``teleport`` flag is set, moving off one edge wraps to the
    opposite edge.

    Returns:
        int: Number of reachable floor tiles. 0 when no map is loaded, when
        the start lies outside the grid, or when the start tile is an obstacle.
    """
    from collections import deque  # O(1) popleft; list.pop(0) is O(n)

    if self.current_map is None:
        return 0

    start_pos: Tuple[int, int] = tuple(self.current_map.get("start", [0, 0]))  # type: ignore
    width: int = self.current_map["width"]
    height: int = self.current_map["height"]
    is_teleport: bool = self.current_map.get("teleport", False)
    grid = self.current_map["grid"]

    sx, sy = start_pos
    # An out-of-range start would raise IndexError or, for negative
    # coordinates, silently index from the end of a row.
    if not (0 <= sx < width and 0 <= sy < height):
        return 0
    if grid[sy][sx] == "obstacle":
        return 0

    queue = deque([start_pos])
    visited: set[Tuple[int, int]] = {start_pos}

    while queue:
        cx, cy = queue.popleft()
        # Check 4-way neighbors
        for dx, dy in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            nx, ny = cx + dx, cy + dy
            if is_teleport:
                nx %= width
                ny %= height
            elif not (0 <= nx < width and 0 <= ny < height):
                continue
            if (nx, ny) not in visited and grid[ny][nx] != "obstacle":
                visited.add((nx, ny))
                queue.append((nx, ny))

    # Every tile in `visited` is a floor tile: the start was checked above
    # and neighbours are filtered before insertion.
    return len(visited)

Uses BFS to count floor tiles reachable from the start position.

The number of reachable tiles $R$ is defined as: $R = |\{ (x, y) \in \text{Grid} \mid \text{path}(\text{start}, (x, y)) \}|$

Returns: int: Number of reachable floor tiles.

async def start(self, host: str = '0.0.0.0', port: int = 8765) -> None:
async def start(self, host: str = "0.0.0.0", port: int = 8765) -> None:
    """Serve WebSocket clients until the surrounding task is cancelled.

    Every incoming connection is handed to ``self.handle_client``.

    Args:
        host (str): Interface address to bind to.
        port (int): TCP port to listen on.
    """
    # Imported here rather than at module top; presumably so the module can
    # be imported without the third-party package installed — TODO confirm.
    import websockets

    logging.info(f"Starting websocket server on ws://{host}:{port}")
    server = websockets.serve(self.handle_client, host, port)
    async with server:
        # A future nobody ever resolves: keeps the server open indefinitely.
        never_done = asyncio.get_running_loop().create_future()
        await never_done

Start the WebSocket server.

Args: host (str): Host address to bind to. port (int): Port number to listen on.

async def handle_client(self, websocket: Any) -> None:
 # SimulationServer.handle_client: per-connection dispatcher.
async def handle_client(self, websocket: Any) -> None:
    """Route a new WebSocket connection to the frontend or agent loop.

    The first message must be JSON with a ``client`` field of ``"frontend"``
    or ``"agent"``. Only one client of each kind may be connected at a time;
    a second one receives an error message and is closed. When a connection's
    loop ends, the slot it occupied is cleared.

    Args:
        websocket (Any): The WebSocket connection object.
    """
    client_type: str = "Unknown"
    try:
        first_message = await websocket.recv()
        try:
            hello = json.loads(first_message)
        except json.JSONDecodeError:
            logging.warning("Received malformed initial message.")
            return

        client_type = hello.get("client", "Unknown")

        if client_type == "frontend":
            if self.frontend_ws is None:
                logging.info("Frontend connected.")
                self.frontend_ws = websocket
                await self.send_map_list()
                await self.frontend_loop(websocket)
            else:
                # Single-frontend policy: explain the refusal, then hang up.
                logging.warning("Frontend already connected. Rejecting new connection.")
                await websocket.send(json.dumps({"type": "error", "message": "Frontend already connected."}))
                await websocket.close()
        elif client_type == "agent":
            if self.agent_ws is None:
                logging.info("Agent connected.")
                self.agent_ws = websocket
                # An agent joining mid-run gets the current state right away.
                if self.running and self.current_map:
                    await self.send_agent_state()
                await self.agent_loop(websocket)
            else:
                logging.warning("Agent already connected. Rejecting new connection.")
                await websocket.send(json.dumps({"type": "error", "message": "Agent already connected."}))
                await websocket.close()
        else:
            logging.warning(f"Unknown client type attempted connection: {client_type}")

    except Exception as e:
        logging.error(f"Error handling client {client_type}: {e}")
    finally:
        # Release only a slot this connection actually owns; rejected
        # duplicates never occupied one.
        if websocket == self.frontend_ws:
            self.frontend_ws = None
            logging.info("Frontend session cleared.")
        elif websocket == self.agent_ws:
            self.agent_ws = None
            logging.info("Agent session cleared.")

Handle incoming WebSocket connections.

Args: websocket (Any): The WebSocket connection object.

async def frontend_loop(self, websocket: Any) -> None:
async def frontend_loop(self, websocket: Any) -> None:
    """Process control messages from the frontend until it disconnects.

    Each message is JSON with an ``action`` field:

    * ``load_map`` / ``reset_sim``: load or reset, refresh the frontend, and
      send the agent a ``reset`` message plus fresh state.
    * ``save_map``: save and reply with a ``save_response``.
    * ``start_sim`` / ``stop_sim``: toggle the run.

    Unknown actions are ignored. An error while handling one message is
    logged and does not end the loop.

    Args:
        websocket (Any): The frontend WebSocket connection.
    """
    async for raw in websocket:
        try:
            request = json.loads(raw)
            command = request.get("action")

            if command in ("load_map", "reset_sim"):
                if command == "load_map":
                    self.load_map(request.get("filename"))
                else:
                    self.reset_sim()
                await self.update_frontend()
                if self.agent_ws and self.current_map:
                    await self.agent_ws.send(json.dumps({"type": "reset"}))
                    await self.send_agent_state()

            elif command == "save_map":
                success, error = self.save_map(request.get("filename"), request.get("map_data"))
                reply = {"type": "save_response", "success": success, "error": error}
                await websocket.send(json.dumps(reply))
                if success:
                    # Resend the list so it includes the newly saved file.
                    await self.send_map_list()

            elif command == "start_sim":
                if not self.current_map:
                    logging.warning("Attempted to start simulation without a map.")
                    continue
                # Every run starts from a clean state.
                self.reset_sim()
                self.running = True
                logging.info("Simulation started via frontend.")
                await self.update_frontend()
                if self.agent_ws:
                    await self.send_agent_state()

            elif command == "stop_sim":
                self.running = False
                logging.info("Simulation stopped via frontend.")
                await self.update_frontend()
        except Exception as e:
            logging.error(f"Error processing frontend message: {e}")

Main loop for handling frontend messages.

Args: websocket (Any): The frontend WebSocket connection.

async def agent_loop(self, websocket: Any) -> None:
async def agent_loop(self, websocket: Any) -> None:
    """Process messages from the agent until it disconnects.

    While the simulation is not running, or no map is loaded, every message
    (telemetry included) is discarded. Otherwise:

    * ``move`` applies the move, checks the objective, and pushes fresh state
      to both the frontend and the agent.
    * ``telemetry`` forwards its ``data`` payload to the frontend, if one is
      connected.

    Args:
        websocket (Any): The agent WebSocket connection.
    """
    async for raw in websocket:
        if not (self.running and self.current_map):
            continue
        try:
            request = json.loads(raw)
            kind = request.get("action")
            if kind == "move":
                self.process_move(request.get("direction"))
                self.check_objective()
                await self.update_frontend()
                await self.send_agent_state()
            elif kind == "telemetry" and self.frontend_ws:
                forwarded = {"type": "agent_telemetry", "data": request.get("data")}
                await self.frontend_ws.send(json.dumps(forwarded))
        except Exception as e:
            logging.error(f"Error processing agent message: {e}")

Main loop for handling agent messages.

Args: websocket (Any): The agent WebSocket connection.

def process_move(self, direction: str) -> None:
def process_move(self, direction: str) -> None:
    """Apply one agent move request to the simulation state.

    Outcomes:

    * Moving onto a floor tile updates ``agent_pos`` and increments that
      tile's visit count.
    * Moving into an obstacle leaves the agent in place and increments the
      obstacle's hit count.
    * Moves off the grid are ignored, unless the map has ``teleport``
      enabled, in which case they wrap to the opposite edge.
    * Unrecognised directions are ignored.

    Args:
        direction (str): Direction to move ('N', 'S', 'E', 'W').
    """
    if self.current_map is None:
        return

    # Grid offsets per direction; rows are indexed by y, and N decreases y.
    deltas = {"N": (0, -1), "S": (0, 1), "E": (1, 0), "W": (-1, 0)}
    delta = deltas.get(direction) if isinstance(direction, str) else None
    if delta is None:
        # Without this guard an unknown direction becomes a zero-length move
        # that still increments the current tile's visit count.
        return

    x, y = self.sim_state["agent_pos"]
    nx, ny = x + delta[0], y + delta[1]

    width = self.current_map["width"]
    height = self.current_map["height"]

    if self.current_map.get("teleport", False):
        nx %= width
        ny %= height
    elif not (0 <= nx < width and 0 <= ny < height):
        return

    key = f"{nx},{ny}"
    if self.current_map["grid"][ny][nx] == "obstacle":
        hits = self.sim_state["hits"]
        hits[key] = hits.get(key, 0) + 1
    else:
        self.sim_state["agent_pos"] = [nx, ny]
        visits = self.sim_state["visits"]
        visits[key] = visits.get(key, 0) + 1

Process an agent movement request.

Args: direction (str): Direction to move ('N', 'S', 'E', 'W').

def get_valid_actions(self) -> List[str]:
def get_valid_actions(self) -> List[str]:
    """List the directions the agent can move in from its current tile.

    A direction is offered when the neighbouring tile is on the grid (or the
    map wraps via ``teleport``) and is not an obstacle. The result always
    keeps N, S, E, W order.

    Returns:
        List[str]: Subset of ``["N", "S", "E", "W"]``; empty when no map is loaded.
    """
    if self.current_map is None:
        return []

    x, y = self.sim_state["agent_pos"]
    width = self.current_map["width"]
    height = self.current_map["height"]
    wrap: bool = self.current_map.get("teleport", False)

    # Neighbour coordinate along the axis that changes. With teleport it
    # wraps; otherwise it may fall off the grid.
    up = (y - 1) % height if wrap else y - 1
    down = (y + 1) % height if wrap else y + 1
    right = (x + 1) % width if wrap else x + 1
    left = (x - 1) % width if wrap else x - 1

    # (label, column, row, on_grid). on_grid is tested before the grid
    # lookup, so off-grid neighbours are never indexed.
    candidates = (
        ("N", x, up, wrap or up >= 0),
        ("S", x, down, wrap or down < height),
        ("E", right, y, wrap or right < width),
        ("W", left, y, wrap or left >= 0),
    )
    return [
        label
        for label, col, row, on_grid in candidates
        if on_grid and self.current_map["grid"][row][col] != "obstacle"
    ]

Get a list of valid actions for the agent at its current position.

Returns: List[str]: List of valid cardinal directions.

def reset_sim(self) -> None:
def reset_sim(self) -> None:
    """Restore the loaded map's run state and stop the simulation.

    Puts the agent back on the map's ``start`` tile (``[0, 0]`` when the map
    has no ``start``), seeds the visit heatmap with that tile and clears the
    obstacle-hit counters. Does nothing when no map is loaded.
    """
    if not self.current_map:
        return

    origin = self.current_map.get("start", [0, 0])
    origin_key = f"{origin[0]},{origin[1]}"
    # agent_pos shares the map's "start" list object; process_move assigns a
    # new list instead of mutating it, so the map itself is never altered.
    self.sim_state = {"agent_pos": origin, "visits": {origin_key: 1}, "hits": {}}
    self.running = False
    logging.info("Simulation reset to start state.")

Resets the map state and heatmaps to their initial conditions.

def check_objective(self) -> None:
def check_objective(self) -> None:
    """Stop the simulation once the loaded map's goal has been met.

    * ``maze`` maps: the agent stands on the map's ``target`` tile.
    * ``room`` maps: the number of distinct visited tiles has reached
      ``self.reachable_tiles``.

    Other map types never complete. Only ``self.running`` is modified.
    """
    if self.current_map is None:
        return

    map_kind = self.current_map["type"]
    if map_kind == "maze":
        reached = self.sim_state["agent_pos"] == self.current_map.get("target")
        message = "Objective Reached: Maze target found!"
    elif map_kind == "room":
        # Each visited tile appears exactly once as a key of the heatmap.
        reached = len(self.sim_state["visits"]) >= self.reachable_tiles
        message = "Objective Reached: Room fully explored!"
    else:
        return

    if reached:
        self.running = False
        logging.info(message)

Checks if the simulation objective has been reached.

async def send_agent_state(self) -> None:
async def send_agent_state(self) -> None:
    """Send the agent its current observation of the simulation.

    Sends nothing unless an agent is connected and a map is loaded. The
    ``target`` field is only filled in for ``maze`` maps.

    Note: ``objective_reached`` is computed as ``not self.running``. It is
    therefore also ``True`` whenever the simulation is stopped, e.g. right
    after a map load or reset, not only when the goal was actually met.
    """
    if not (self.agent_ws and self.current_map):
        return

    level = self.current_map
    observation = {
        "type": "state",
        "position": self.sim_state["agent_pos"],
        "valid_actions": self.get_valid_actions(),
        "objective_reached": not self.running,
        "target": level.get("target") if level["type"] == "maze" else None,
        "start": level.get("start"),
        "width": level.get("width"),
        "height": level.get("height"),
    }
    await self.agent_ws.send(json.dumps(observation))

Sends the current simulation state to the agent.

async def update_frontend(self) -> None:
async def update_frontend(self) -> None:
    """Push a full simulation snapshot to the frontend, if one is connected.

    The snapshot carries the loaded map (or ``None``), the run state, the
    running flag and whether an agent connection is currently registered.
    """
    ws = self.frontend_ws
    if not ws:
        return
    snapshot = dict(
        type="update",
        map=self.current_map,
        state=self.sim_state,
        running=self.running,
        agent_connected=self.agent_ws is not None,
    )
    await ws.send(json.dumps(snapshot))

Sends the current simulation state to the frontend.

async def send_map_list(self) -> None:
async def send_map_list(self) -> None:
    """Send the sorted names of the ``*.json`` files in ``maps_dir`` to the frontend.

    No-op without a connected frontend. Any failure, whether listing the
    directory or sending the message, is logged rather than raised.
    """
    if not self.frontend_ws:
        return
    try:
        names = sorted(entry for entry in os.listdir(self.maps_dir) if entry.endswith(".json"))
        await self.frontend_ws.send(json.dumps({"type": "map_list", "maps": names}))
    except Exception as e:
        logging.error(f"Failed to read maps directory: {e}")

Sends the list of available maps to the frontend.

def load_map(self, filename: str) -> None:
def load_map(self, filename: str) -> None:
    """Load a map from a JSON file in ``maps_dir`` and reset the run state.

    Only the base name of *filename* is used, so callers cannot read files
    outside ``maps_dir``. On any failure the error is logged and the
    previously loaded map, reachable-tile count and run state are left
    untouched. Failures include a missing file, invalid JSON, a top-level
    value that is not an object, or a map the reachability computation
    cannot process.

    Args:
        filename (str): Name of the map file to load.
    """
    try:
        filename = os.path.basename(filename)
        filepath = os.path.join(self.maps_dir, filename)
        with open(filepath, "r", encoding="utf-8") as f:
            candidate = json.load(f)

        # `null`, lists, numbers, ... would otherwise be installed as the map
        # and break every later `.get(...)` / `[...]` lookup on it.
        if not isinstance(candidate, dict):
            logging.error(f"Failed to load map {filename}: file is empty or invalid.")
            return

        # calculate_reachable_tiles reads self.current_map, so install the
        # candidate tentatively and roll back if it cannot be processed.
        previous_map = self.current_map
        self.current_map = candidate
        try:
            reachable = self.calculate_reachable_tiles()
        except Exception:
            self.current_map = previous_map
            raise

        self.reachable_tiles = reachable
        logging.info(f"Map loaded: {filename}. Reachable floor tiles: {self.reachable_tiles}")

        start_pos = candidate.get("start", [0, 0])
        self.sim_state = {
            "agent_pos": start_pos,
            "visits": {f"{start_pos[0]},{start_pos[1]}": 1},
            "hits": {},
        }
        self.running = False
        logging.info(f"Successfully loaded map: {filename}")
    except Exception as e:
        logging.error(f"Failed to load map {filename}: {e}")

Load a map from a JSON file.

Args: filename (str): Name of the map file to load.

def validate_map_data(self, data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
def validate_map_data(self, data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Basic schema validation for incoming map data.

    Checks that *data* is a JSON object with the required fields, that
    ``width``/``height`` are positive integers matching the ``grid`` shape,
    and that ``start`` is an ``[x, y]`` pair inside the grid.

    Args:
        data (Dict[str, Any]): Map data to validate.

    Returns:
        Tuple[bool, Optional[str]]: ``(True, None)`` when valid, otherwise
        ``(False, error_message)``. Never raises.
    """

    def _is_int(value: Any) -> bool:
        # bool is a subclass of int; True/False are not valid sizes/coordinates.
        return isinstance(value, int) and not isinstance(value, bool)

    try:
        if not isinstance(data, dict):
            return False, "Map data must be a JSON object."

        required = ["width", "height", "type", "grid", "start"]
        if not all(k in data for k in required):
            return (
                False,
                "Missing required fields (width, height, type, grid, start).",
            )

        width, height = data["width"], data["height"]
        if not (_is_int(width) and _is_int(height) and width > 0 and height > 0):
            return False, "Width and height must be positive integers."

        if not isinstance(data["grid"], list) or len(data["grid"]) != height:
            return False, f"Grid height mismatch. Expected {data['height']} rows."

        for row in data["grid"]:
            if not isinstance(row, list) or len(row) != width:
                return (
                    False,
                    f"Grid width mismatch. Expected {data['width']} columns.",
                )

        # An out-of-range start would make the reachability BFS and move
        # handling index outside the grid once this map is loaded.
        start = data["start"]
        if not (
            isinstance(start, (list, tuple))
            and len(start) == 2
            and _is_int(start[0])
            and _is_int(start[1])
            and 0 <= start[0] < width
            and 0 <= start[1] < height
        ):
            return False, "Start position must be [x, y] inside the grid."

        return True, None
    except Exception as e:
        return False, f"Validation error: {str(e)}"

Basic schema validation for incoming map data.

Args: data (Dict[str, Any]): Map data to validate.

Returns: Tuple[bool, Optional[str]]: (is_valid, error_message)

def save_map(self, filename: str, map_data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
def save_map(self, filename: str, map_data: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """
    Save map data to a JSON file in ``maps_dir``.

    Only the base name of *filename* is used (no directory traversal), and a
    ``.json`` suffix is appended when missing. The data is checked with
    ``self.validate_map_data`` before anything is written.

    Args:
        filename (str): Name of the file to save.
        map_data (Dict[str, Any]): Map data to save.

    Returns:
        Tuple[bool, Optional[str]]: (success, error_message); error_message
        is None on success.
    """
    try:
        if not isinstance(filename, str):
            err = "A map filename is required."
            logging.warning(f"Rejected map save request: {err}")
            return False, err

        filename = os.path.basename(filename)
        if not filename.endswith(".json"):
            filename += ".json"
        # An empty name (e.g. "" or "dir/") would otherwise produce a hidden
        # file literally named ".json".
        if filename == ".json":
            err = "A map filename is required."
            logging.warning(f"Rejected map save request: {err}")
            return False, err

        is_valid, error_msg = self.validate_map_data(map_data)
        if not is_valid:
            logging.warning(f"Rejected invalid map save request: {error_msg}")
            return False, error_msg

        filepath = os.path.join(self.maps_dir, filename)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(map_data, f)

        logging.info(f"Successfully saved map: {filepath}")
        return True, None
    except PermissionError:
        err = "Permission denied when saving. Check Docker volume permissions."
        logging.error(err)
        return False, err
    except Exception as e:
        err = f"Unexpected error saving map: {str(e)}"
        logging.error(err)
        return False, err

Save map data to a JSON file.

Args: filename (str): Name of the file to save. map_data (Dict[str, Any]): Map data to save.

Returns: Tuple[bool, Optional[str]]: (success, error_message)