From 64609148e74ead73b572c1ce77ba5771cc768ba5 Mon Sep 17 00:00:00 2001 From: intelava Date: Wed, 7 Jan 2026 19:11:38 +0100 Subject: [PATCH 1/2] initial commit --- src/jaxatari/core.py | 2 + src/jaxatari/games/jax_jamesbond.py | 404 ++++++++++++++++++++++++++++ 2 files changed, 406 insertions(+) create mode 100644 src/jaxatari/games/jax_jamesbond.py diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index 8be1e3aea..0a55a21a0 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -50,6 +50,8 @@ "videocube": "jaxatari.games.jax_videocube", "videopinball": "jaxatari.games.jax_videopinball", "wordzapper": "jaxatari.games.jax_wordzapper", + "Jamesbond": "jaxatari.games.jax_jamesbond", + "jamesbond": "jaxatari.games.jax_jamesbond", # Add new games here } diff --git a/src/jaxatari/games/jax_jamesbond.py b/src/jaxatari/games/jax_jamesbond.py new file mode 100644 index 000000000..4f2422d3a --- /dev/null +++ b/src/jaxatari/games/jax_jamesbond.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +from functools import partial +from typing import NamedTuple, Tuple + +import chex +import jax +import jax.numpy as jnp + +import jaxatari.spaces as spaces +from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action +from jaxatari.renderers import JAXGameRenderer +from jaxatari.rendering import jax_rendering_utils as render_utils + + +class JamesBondConstants(NamedTuple): + WIDTH: int = 160 + HEIGHT: int = 210 + + ROAD_LEFT: int = 40 + ROAD_RIGHT: int = 120 + ROAD_EDGE_LINE_W: int = 2 + + PLAYER_Y: int = 170 + PLAYER_SIZE: Tuple[int, int] = (8, 12) + PLAYER_STEP_X: int = 4 + + MAX_ENEMIES: int = 10 + ENEMY_SIZE: Tuple[int, int] = (8, 12) + ENEMY_SPEED_Y: int = 2 + + NUM_LANES: int = 5 + SPAWN_PROB: float = 0.12 + SPAWN_Y: int = 0 + + SCROLL_SPEED: int = 2 + DASH_W: int = 2 + DASH_H: int = 8 + DASH_SPACING: int = 22 + + SURVIVE_REWARD: float = 0.01 + PASS_REWARD: float = 0.10 + + GRASS_COLOR: Tuple[int, int, int] = (34, 139, 34) + ROAD_COLOR: Tuple[int, int, int] = (80, 80, 80) + LINE_COLOR: Tuple[int, int, int] = (236, 236, 236) + PLAYER_COLOR: Tuple[int, int, int] = (236, 236, 0) + ENEMY_COLOR: Tuple[int, int, int] = (200, 0, 0) + + +class EntityPosition(NamedTuple): + x: jnp.ndarray + y: jnp.ndarray + width: jnp.ndarray + height: jnp.ndarray + + +class EntityPositionBatch(NamedTuple): + x: jnp.ndarray + y: jnp.ndarray + width: jnp.ndarray + height: jnp.ndarray + + +class JamesBondObservation(NamedTuple): + player: EntityPosition + enemies: EntityPositionBatch + + +class JamesBondInfo(NamedTuple): + time: jnp.ndarray + passed: jnp.ndarray + collisions: jnp.ndarray + + +class JamesBondState(NamedTuple): + player_x: chex.Array + enemy_x: chex.Array + enemy_y: chex.Array + enemy_active: chex.Array + scroll_offset: chex.Array + step_counter: chex.Array + passed_counter: chex.Array + collided: chex.Array + key: chex.PRNGKey + + +def _rects_overlap(px: jnp.ndarray, py: jnp.ndarray, pw: int, ph: int, + ex: jnp.ndarray, ey: jnp.ndarray, ew: int, eh: int) -> jnp.ndarray: + overlap_x = (px < (ex + ew)) & ((px + pw) > ex) + overlap_y = (py < (ey + eh)) & ((py + ph) > ey) + return overlap_x & overlap_y + + +class JaxJamesbond(JaxEnvironment[JamesBondState, JamesBondObservation, JamesBondInfo, JamesBondConstants]): + def __init__(self, consts: JamesBondConstants | None = None): + consts = consts or JamesBondConstants() + super().__init__(consts) + self.renderer = JamesBondRenderer(self.consts) + self.action_set = [ + Action.NOOP, + Action.FIRE, + Action.RIGHT, + Action.LEFT, + Action.RIGHTFIRE, + Action.LEFTFIRE, + ] + + def _player_step(self, state: JamesBondState, action: chex.Array) -> JamesBondState: + left = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE) + right = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE) + + dx = jnp.where(left, -self.consts.PLAYER_STEP_X, 0) + dx = jnp.where(right, self.consts.PLAYER_STEP_X, dx) + + min_x = jnp.array(self.consts.ROAD_LEFT, dtype=jnp.int32) + max_x = jnp.array(self.consts.ROAD_RIGHT - self.consts.PLAYER_SIZE[0], dtype=jnp.int32) + new_player_x = jnp.clip(state.player_x + dx.astype(jnp.int32), min_x, max_x).astype(jnp.int32) + + return state._replace(player_x=new_player_x) + + def _enemy_step_and_spawn(self, state: JamesBondState) -> Tuple[JamesBondState, jnp.ndarray]: + active_i32 = state.enemy_active.astype(jnp.int32) + new_enemy_y = (state.enemy_y + active_i32 * jnp.array(self.consts.ENEMY_SPEED_Y, dtype=jnp.int32)).astype(jnp.int32) + + passed = state.enemy_active & (new_enemy_y > jnp.array(self.consts.HEIGHT, dtype=jnp.int32)) + passed_count = passed.astype(jnp.int32).sum() + + enemy_active = state.enemy_active & (~passed) + enemy_y = jnp.where(enemy_active, new_enemy_y, jnp.array(0, dtype=jnp.int32)).astype(jnp.int32) + enemy_x = jnp.where(enemy_active, state.enemy_x, jnp.array(0, dtype=jnp.int32)).astype(jnp.int32) + + spawn_key, lane_key = jax.random.split(state.key) + roll = jax.random.uniform(spawn_key, ()) + should_spawn = roll < jnp.array(self.consts.SPAWN_PROB) + + free_mask = ~enemy_active + any_free = jnp.any(free_mask) + spawn_idx = jnp.argmax(free_mask.astype(jnp.int32)) + + road_w = self.consts.ROAD_RIGHT - self.consts.ROAD_LEFT + lane_w = jnp.array(road_w // self.consts.NUM_LANES, dtype=jnp.int32) + lane_margin = jnp.array((road_w - lane_w * self.consts.NUM_LANES) // 2, dtype=jnp.int32) + lane_centers = ( + jnp.arange(self.consts.NUM_LANES, dtype=jnp.int32) * lane_w + + jnp.array(self.consts.ROAD_LEFT, dtype=jnp.int32) + + lane_margin + + (lane_w // 2) + ) + lane_idx = jax.random.randint(lane_key, (), 0, self.consts.NUM_LANES, dtype=jnp.int32) + spawn_x = (lane_centers[lane_idx] - (self.consts.ENEMY_SIZE[0] // 2)).astype(jnp.int32) + spawn_y = jnp.array(self.consts.SPAWN_Y, dtype=jnp.int32) + + def do_spawn(carry): + ex, ey, ea = carry + ex = ex.at[spawn_idx].set(spawn_x) + ey = ey.at[spawn_idx].set(spawn_y) + ea = ea.at[spawn_idx].set(True) + return ex, ey, ea + + enemy_x, enemy_y, enemy_active = jax.lax.cond( + should_spawn & any_free, + do_spawn, + lambda carry: carry, + operand=(enemy_x, enemy_y, enemy_active), + ) + + new_scroll = (state.scroll_offset + jnp.array(self.consts.SCROLL_SPEED, dtype=jnp.int32)) % jnp.array( + self.consts.DASH_SPACING, dtype=jnp.int32 + ) + + new_state = state._replace( + enemy_x=enemy_x.astype(jnp.int32), + enemy_y=enemy_y.astype(jnp.int32), + enemy_active=enemy_active, + scroll_offset=new_scroll.astype(jnp.int32), + ) + return new_state, passed_count + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, state: JamesBondState, action: chex.Array + ) -> Tuple[JamesBondObservation, JamesBondState, float, bool, JamesBondInfo]: + new_state_key, step_key = jax.random.split(state.key) + + state = state._replace(key=step_key) + + prev_state = state + + state = self._player_step(state, action) + state, passed_count = self._enemy_step_and_spawn(state) + + px = state.player_x + py = jnp.array(self.consts.PLAYER_Y, dtype=jnp.int32) + collisions = _rects_overlap( + px, py, + self.consts.PLAYER_SIZE[0], self.consts.PLAYER_SIZE[1], + state.enemy_x, state.enemy_y, + self.consts.ENEMY_SIZE[0], self.consts.ENEMY_SIZE[1], + ) + collided = jnp.any(state.enemy_active & collisions) + + state = state._replace( + step_counter=(state.step_counter + jnp.array(1, dtype=jnp.int32)).astype(jnp.int32), + passed_counter=(state.passed_counter + passed_count).astype(jnp.int32), + collided=collided, + key=new_state_key, + ) + + done = self._get_done(state) + env_reward = self._get_reward(prev_state, state) + info = self._get_info(state) + obs = self._get_observation(state) + return obs, state, env_reward, done, info + + def reset(self, key: chex.PRNGKey = jax.random.PRNGKey(42)) -> Tuple[JamesBondObservation, JamesBondState]: + state_key, _ = jax.random.split(key) + + start_x = jnp.array((self.consts.ROAD_LEFT + self.consts.ROAD_RIGHT) // 2, dtype=jnp.int32) + start_x = (start_x - (self.consts.PLAYER_SIZE[0] // 2)).astype(jnp.int32) + + n = self.consts.MAX_ENEMIES + state = JamesBondState( + player_x=start_x, + enemy_x=jnp.zeros((n,), dtype=jnp.int32), + enemy_y=jnp.zeros((n,), dtype=jnp.int32), + enemy_active=jnp.zeros((n,), dtype=bool), + scroll_offset=jnp.array(0, dtype=jnp.int32), + step_counter=jnp.array(0, dtype=jnp.int32), + passed_counter=jnp.array(0, dtype=jnp.int32), + collided=jnp.array(False), + key=state_key, + ) + return self._get_observation(state), state + + def render(self, state: JamesBondState) -> jnp.ndarray: + return self.renderer.render(state) + + def _get_observation(self, state: JamesBondState) -> JamesBondObservation: + player = EntityPosition( + x=state.player_x.astype(jnp.int32), + y=jnp.array(self.consts.PLAYER_Y, dtype=jnp.int32), + width=jnp.array(self.consts.PLAYER_SIZE[0], dtype=jnp.int32), + height=jnp.array(self.consts.PLAYER_SIZE[1], dtype=jnp.int32), + ) + n = self.consts.MAX_ENEMIES + enemies = EntityPositionBatch( + x=state.enemy_x.astype(jnp.int32).reshape((n,)), + y=state.enemy_y.astype(jnp.int32).reshape((n,)), + width=jnp.full((n,), self.consts.ENEMY_SIZE[0], dtype=jnp.int32), + height=jnp.full((n,), self.consts.ENEMY_SIZE[1], dtype=jnp.int32), + ) + return JamesBondObservation(player=player, enemies=enemies) + + @partial(jax.jit, static_argnums=(0,)) + def obs_to_flat_array(self, obs: JamesBondObservation) -> jnp.ndarray: + return jnp.concatenate( + [ + obs.player.x.flatten(), + obs.player.y.flatten(), + obs.player.width.flatten(), + obs.player.height.flatten(), + obs.enemies.x.flatten(), + obs.enemies.y.flatten(), + obs.enemies.width.flatten(), + obs.enemies.height.flatten(), + ], + axis=0, + ) + + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(6) + + def observation_space(self) -> spaces.Space: + n = self.consts.MAX_ENEMIES + return spaces.Dict( + { + "player": spaces.Dict( + { + "x": spaces.Box(low=0, high=self.consts.WIDTH, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=self.consts.WIDTH, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(), dtype=jnp.int32), + } + ), + "enemies": spaces.Dict( + { + "x": spaces.Box(low=0, high=self.consts.WIDTH, shape=(n,), dtype=jnp.int32), + "y": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(n,), dtype=jnp.int32), + "width": spaces.Box(low=0, high=self.consts.WIDTH, shape=(n,), dtype=jnp.int32), + "height": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(n,), dtype=jnp.int32), + } + ), + } + ) + + def image_space(self) -> spaces.Box: + return spaces.Box(low=0, high=255, shape=(self.consts.HEIGHT, self.consts.WIDTH, 3), dtype=jnp.uint8) + + @partial(jax.jit, static_argnums=(0,)) + def _get_info(self, state: JamesBondState) -> JamesBondInfo: + return JamesBondInfo( + time=state.step_counter, + passed=state.passed_counter, + collisions=state.collided.astype(jnp.int32), + ) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: JamesBondState, state: JamesBondState) -> float: + passed_delta = (state.passed_counter - previous_state.passed_counter).astype(jnp.int32) + pass_reward = passed_delta.astype(jnp.float32) * jnp.array(self.consts.PASS_REWARD, dtype=jnp.float32) + survive_reward = jnp.array(self.consts.SURVIVE_REWARD, dtype=jnp.float32) * (~state.collided) + return (pass_reward + survive_reward).astype(jnp.float32) + + @partial(jax.jit, static_argnums=(0,)) + def _get_done(self, state: JamesBondState) -> bool: + return state.collided + + +class JamesBondRenderer(JAXGameRenderer): + _ID_GRASS = 0 + _ID_ROAD = 1 + _ID_LINE = 2 + _ID_PLAYER = 3 + _ID_ENEMY = 4 + + def __init__(self, consts: JamesBondConstants | None = None): + super().__init__(consts) + self.consts = consts or JamesBondConstants() + + self.config = render_utils.RendererConfig(game_dimensions=(self.consts.HEIGHT, self.consts.WIDTH), channels=3) + self.jr = render_utils.JaxRenderingUtils(self.config) + + self.PALETTE = jnp.array( + [ + self.consts.GRASS_COLOR, + self.consts.ROAD_COLOR, + self.consts.LINE_COLOR, + self.consts.PLAYER_COLOR, + self.consts.ENEMY_COLOR, + ], + dtype=jnp.uint8, + ) + + self.BACKGROUND = jnp.full( + (self.consts.HEIGHT, self.consts.WIDTH), + jnp.array(self._ID_GRASS, dtype=jnp.uint8), + dtype=jnp.uint8, + ) + + @partial(jax.jit, static_argnums=(0,)) + def render(self, state: JamesBondState) -> jnp.ndarray: + raster = self.jr.create_object_raster(self.BACKGROUND) + + road_pos = jnp.array([[self.consts.ROAD_LEFT, 0]], dtype=jnp.int32) + road_size = jnp.array([[self.consts.ROAD_RIGHT - self.consts.ROAD_LEFT, self.consts.HEIGHT]], dtype=jnp.int32) + raster = self.jr.draw_rects(raster, road_pos, road_size, self._ID_ROAD) + + edge_pos = jnp.array( + [ + [self.consts.ROAD_LEFT - self.consts.ROAD_EDGE_LINE_W, 0], + [self.consts.ROAD_RIGHT, 0], + ], + dtype=jnp.int32, + ) + edge_size = jnp.array( + [ + [self.consts.ROAD_EDGE_LINE_W, self.consts.HEIGHT], + [self.consts.ROAD_EDGE_LINE_W, self.consts.HEIGHT], + ], + dtype=jnp.int32, + ) + raster = self.jr.draw_rects(raster, edge_pos, edge_size, self._ID_LINE) + + num_dashes = (self.consts.HEIGHT // self.consts.DASH_SPACING) + 3 + dash_x = jnp.array((self.consts.ROAD_LEFT + self.consts.ROAD_RIGHT) // 2 - (self.consts.DASH_W // 2), dtype=jnp.int32) + ys = ( + (jnp.arange(num_dashes, dtype=jnp.int32) * jnp.array(self.consts.DASH_SPACING, dtype=jnp.int32) + state.scroll_offset) + % jnp.array(self.consts.HEIGHT + self.consts.DASH_SPACING, dtype=jnp.int32) + ) - jnp.array(self.consts.DASH_H, dtype=jnp.int32) + visible = (ys >= 0) & (ys < jnp.array(self.consts.HEIGHT, dtype=jnp.int32)) + dash_xs = jnp.where(visible, dash_x, jnp.array(-1, dtype=jnp.int32)) + dash_pos = jnp.stack([dash_xs, ys], axis=1).astype(jnp.int32) + dash_size = jnp.tile( + jnp.array([[self.consts.DASH_W, self.consts.DASH_H]], dtype=jnp.int32), + (num_dashes, 1), + ) + raster = self.jr.draw_rects(raster, dash_pos, dash_size, self._ID_LINE) + + player_pos = jnp.array([[state.player_x, self.consts.PLAYER_Y]], dtype=jnp.int32) + player_size = jnp.array([[self.consts.PLAYER_SIZE[0], self.consts.PLAYER_SIZE[1]]], dtype=jnp.int32) + raster = self.jr.draw_rects(raster, player_pos, player_size, self._ID_PLAYER) + + enemy_x = jnp.where(state.enemy_active, state.enemy_x, jnp.array(-1, dtype=jnp.int32)) + enemy_pos = jnp.stack([enemy_x, state.enemy_y], axis=1).astype(jnp.int32) + enemy_size = jnp.tile( + jnp.array([[self.consts.ENEMY_SIZE[0], self.consts.ENEMY_SIZE[1]]], dtype=jnp.int32), + (self.consts.MAX_ENEMIES, 1), + ) + raster = self.jr.draw_rects(raster, enemy_pos, enemy_size, self._ID_ENEMY) + + return self.jr.render_from_palette(raster, self.PALETTE) + + From dbeb0d83aad17eca6aae98b64499cadbe89b5266 Mon Sep 17 00:00:00 2001 From: Atak Date: Tue, 31 Mar 2026 21:04:00 +0200 Subject: [PATCH 2/2] full implementation of jamesbond --- src/jaxatari/games/jax_jamesbond.py | 1024 +++++++++++++++++++-------- 1 file changed, 740 insertions(+), 284 deletions(-) diff --git a/src/jaxatari/games/jax_jamesbond.py b/src/jaxatari/games/jax_jamesbond.py index 4f2422d3a..7ae9f3b77 100644 --- a/src/jaxatari/games/jax_jamesbond.py +++ b/src/jaxatari/games/jax_jamesbond.py @@ -1,3 +1,19 @@ +"""JAXAtari implementation of James Bond 007 (Atari 2600 / ALE). + +The ALE version of James Bond 007 is a horizontal side-scroller in which the +player's vehicle (a car/boat/submarine hybrid loosely inspired by Bond's Lotus +Esprit) appears on the left side of the screen. The world scrolls from right +to left and enemies come from the right in three distinct vertical zones: + + * Sky zone (top) — helicopter/jet enemies + * Sea zone (middle) — speedboat enemies on the water surface + * Sub zone (bottom) — submarine enemies underwater + +The player can move up and down to change zones and fires bullets to the right. +Enemy vehicles also fire back. The episode ends when the player is hit and all +lives are lost. +""" + from __future__ import annotations from functools import partial @@ -13,392 +29,832 @@ from jaxatari.rendering import jax_rendering_utils as render_utils -class JamesBondConstants(NamedTuple): - WIDTH: int = 160 - HEIGHT: int = 210 - - ROAD_LEFT: int = 40 - ROAD_RIGHT: int = 120 - ROAD_EDGE_LINE_W: int = 2 - - PLAYER_Y: int = 170 - PLAYER_SIZE: Tuple[int, int] = (8, 12) - PLAYER_STEP_X: int = 4 +# --------------------------------------------------------------------------- +# Enemy type constants (integer tags stored in state arrays) +# --------------------------------------------------------------------------- +ENEMY_JET = 0 # sky zone +ENEMY_BOAT = 1 # sea zone +ENEMY_SUB = 2 # underwater zone - MAX_ENEMIES: int = 10 - ENEMY_SIZE: Tuple[int, int] = (8, 12) - ENEMY_SPEED_Y: int = 2 - NUM_LANES: int = 5 - SPAWN_PROB: float = 0.12 - SPAWN_Y: int = 0 +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- - SCROLL_SPEED: int = 2 - DASH_W: int = 2 - DASH_H: int = 8 - DASH_SPACING: int = 22 +class JamesBondConstants(NamedTuple): + # Screen + WIDTH: int = 160 + HEIGHT: int = 210 - SURVIVE_REWARD: float = 0.01 - PASS_REWARD: float = 0.10 + # HUD / status bars + HUD_TOP_H: int = 16 # height of top score bar + HUD_BOTTOM_H: int = 18 # height of bottom lives bar + + # Zone y-centres (vertical centre of each gameplay zone) + SKY_Y: int = 48 + SEA_Y: int = 100 + SUB_Y: int = 155 + + # Zone boundary lines (used for background drawing) + HORIZON_Y: int = 76 # top of sea surface (sky / sea boundary) + SEABED_Y: int = 130 # top of underwater zone (sea / sub boundary) + + # Player + PLAYER_START_X: int = 20 + PLAYER_W: int = 16 + PLAYER_H: int = 8 + PLAYER_SPEED_Y: int = 3 # pixels per frame (up / down) + PLAYER_SPEED_X: int = 2 # small horizontal drift + PLAYER_MIN_X: int = 0 + PLAYER_MAX_X: int = 36 # player stays on left side + + # Player y bounds (top of sprite must stay inside playfield) + PLAYER_MIN_Y: int = 16 # just below HUD + PLAYER_MAX_Y: int = 182 # just above bottom HUD + + # Player bullets + MAX_BULLETS: int = 3 + BULLET_W: int = 4 + BULLET_H: int = 2 + BULLET_SPEED: int = 6 + + # Enemy bullets + MAX_E_BULLETS: int = 4 + E_BULLET_W: int = 2 + E_BULLET_H: int = 2 + E_BULLET_SPEED: int = 2 + ENEMY_FIRE_PROB: float = 0.015 + + # Enemies + MAX_ENEMIES: int = 6 + ENEMY_W: int = 16 + ENEMY_H: int = 8 + ENEMY_SPEED: int = 1 + ENEMY_SPAWN_INTERVAL: int = 45 # frames between enemy spawns + ENEMY_SPAWN_X: int = 152 # x at which enemies first appear + + # Scroll + SCROLL_SPEED: int = 1 # pixels per frame + BG_STRIPE_W: int = 8 # width of alternating cloud / wave stripe + + # Lives + NUM_LIVES: int = 3 + + # Rewards + SCORE_PER_KILL: int = 100 + + + +# --------------------------------------------------------------------------- +# State / Observation / Info +# --------------------------------------------------------------------------- - GRASS_COLOR: Tuple[int, int, int] = (34, 139, 34) - ROAD_COLOR: Tuple[int, int, int] = (80, 80, 80) - LINE_COLOR: Tuple[int, int, int] = (236, 236, 236) - PLAYER_COLOR: Tuple[int, int, int] = (236, 236, 0) - ENEMY_COLOR: Tuple[int, int, int] = (200, 0, 0) +class JamesBondState(NamedTuple): + player_x: chex.Array # scalar int32 + player_y: chex.Array # scalar int32 + # Player bullets — shape (MAX_BULLETS,) + bullet_x: chex.Array + bullet_y: chex.Array + bullet_active: chex.Array # bool -class EntityPosition(NamedTuple): - x: jnp.ndarray - y: jnp.ndarray - width: jnp.ndarray - height: jnp.ndarray + # Enemies — shape (MAX_ENEMIES,) + enemy_x: chex.Array + enemy_y: chex.Array + enemy_type: chex.Array # 0=jet, 1=boat, 2=sub + enemy_active: chex.Array # bool + # Enemy bullets — shape (MAX_E_BULLETS,) + e_bullet_x: chex.Array + e_bullet_y: chex.Array + e_bullet_active: chex.Array # bool -class EntityPositionBatch(NamedTuple): - x: jnp.ndarray - y: jnp.ndarray - width: jnp.ndarray - height: jnp.ndarray + scroll_offset: chex.Array # for background animation + spawn_timer: chex.Array # counts down to next enemy spawn + score: chex.Array + lives: chex.Array + step_count: chex.Array + key: chex.PRNGKey class JamesBondObservation(NamedTuple): - player: EntityPosition - enemies: EntityPositionBatch + player_x: chex.Array + player_y: chex.Array + enemy_x: chex.Array + enemy_y: chex.Array + enemy_type: chex.Array + enemy_active: chex.Array + bullet_x: chex.Array + bullet_y: chex.Array + bullet_active: chex.Array + score: chex.Array + lives: chex.Array class JamesBondInfo(NamedTuple): - time: jnp.ndarray - passed: jnp.ndarray - collisions: jnp.ndarray + score: chex.Array + lives: chex.Array + step_count: chex.Array -class JamesBondState(NamedTuple): - player_x: chex.Array - enemy_x: chex.Array - enemy_y: chex.Array - enemy_active: chex.Array - scroll_offset: chex.Array - step_counter: chex.Array - passed_counter: chex.Array - collided: chex.Array - key: chex.PRNGKey +# --------------------------------------------------------------------------- +# Helper: axis-aligned rectangle overlap test (vectorised over one axis) +# --------------------------------------------------------------------------- +def _rect_overlap(ax, ay, aw, ah, bx, by, bw, bh): + """Returns True if rectangle A overlaps rectangle B (broadcast-safe).""" + return (ax < bx + bw) & (ax + aw > bx) & (ay < by + bh) & (ay + ah > by) -def _rects_overlap(px: jnp.ndarray, py: jnp.ndarray, pw: int, ph: int, - ex: jnp.ndarray, ey: jnp.ndarray, ew: int, eh: int) -> jnp.ndarray: - overlap_x = (px < (ex + ew)) & ((px + pw) > ex) - overlap_y = (py < (ey + eh)) & ((py + ph) > ey) - return overlap_x & overlap_y +# --------------------------------------------------------------------------- +# Main environment +# --------------------------------------------------------------------------- class JaxJamesbond(JaxEnvironment[JamesBondState, JamesBondObservation, JamesBondInfo, JamesBondConstants]): + """ + JAX implementation of the ALE James Bond 007 game. + + Gameplay (matching the Atari 2600 / ALE version): + - Horizontal side-scroller, world moves right → left. + - Player vehicle on the left; can move up/down (and slightly left/right). + - Three vertical zones: sky (jets), sea surface (boats), underwater (subs). + - Player fires bullets to the right. + - Enemy vehicles fire back towards the left. + - Destroying an enemy scores SCORE_PER_KILL points. + - Being hit by an enemy or enemy bullet costs a life. + - Episode ends when all lives are gone. + """ + def __init__(self, consts: JamesBondConstants | None = None): consts = consts or JamesBondConstants() super().__init__(consts) self.renderer = JamesBondRenderer(self.consts) + # Full 18-action set, matching ALE JamesBond-v5 self.action_set = [ Action.NOOP, Action.FIRE, + Action.UP, Action.RIGHT, Action.LEFT, + Action.DOWN, + Action.UPRIGHT, + Action.UPLEFT, + Action.DOWNRIGHT, + Action.DOWNLEFT, + Action.UPFIRE, Action.RIGHTFIRE, Action.LEFTFIRE, + Action.DOWNFIRE, + Action.UPRIGHTFIRE, + Action.UPLEFTFIRE, + Action.DOWNRIGHTFIRE, + Action.DOWNLEFTFIRE, ] - def _player_step(self, state: JamesBondState, action: chex.Array) -> JamesBondState: - left = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE) - right = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE) + # ------------------------------------------------------------------ + # reset + # ------------------------------------------------------------------ - dx = jnp.where(left, -self.consts.PLAYER_STEP_X, 0) - dx = jnp.where(right, self.consts.PLAYER_STEP_X, dx) - - min_x = jnp.array(self.consts.ROAD_LEFT, dtype=jnp.int32) - max_x = jnp.array(self.consts.ROAD_RIGHT - self.consts.PLAYER_SIZE[0], dtype=jnp.int32) - new_player_x = jnp.clip(state.player_x + dx.astype(jnp.int32), min_x, max_x).astype(jnp.int32) - - return state._replace(player_x=new_player_x) - - def _enemy_step_and_spawn(self, state: JamesBondState) -> Tuple[JamesBondState, jnp.ndarray]: - active_i32 = state.enemy_active.astype(jnp.int32) - new_enemy_y = (state.enemy_y + active_i32 * jnp.array(self.consts.ENEMY_SPEED_Y, dtype=jnp.int32)).astype(jnp.int32) + def reset(self, key: chex.PRNGKey = jax.random.PRNGKey(42)) -> Tuple[JamesBondObservation, JamesBondState]: + n_e = self.consts.MAX_ENEMIES + n_b = self.consts.MAX_BULLETS + n_eb = self.consts.MAX_E_BULLETS - passed = state.enemy_active & (new_enemy_y > jnp.array(self.consts.HEIGHT, dtype=jnp.int32)) - passed_count = passed.astype(jnp.int32).sum() + state = JamesBondState( + player_x=jnp.array(self.consts.PLAYER_START_X, dtype=jnp.int32), + player_y=jnp.array(self.consts.SEA_Y - self.consts.PLAYER_H // 2, dtype=jnp.int32), - enemy_active = state.enemy_active & (~passed) - enemy_y = jnp.where(enemy_active, new_enemy_y, jnp.array(0, dtype=jnp.int32)).astype(jnp.int32) - enemy_x = jnp.where(enemy_active, state.enemy_x, jnp.array(0, dtype=jnp.int32)).astype(jnp.int32) + bullet_x=jnp.zeros((n_b,), dtype=jnp.int32), + bullet_y=jnp.zeros((n_b,), dtype=jnp.int32), + bullet_active=jnp.zeros((n_b,), dtype=bool), - spawn_key, lane_key = jax.random.split(state.key) - roll = jax.random.uniform(spawn_key, ()) - should_spawn = roll < jnp.array(self.consts.SPAWN_PROB) + enemy_x=jnp.zeros((n_e,), dtype=jnp.int32), + enemy_y=jnp.zeros((n_e,), dtype=jnp.int32), + enemy_type=jnp.zeros((n_e,), dtype=jnp.int32), + enemy_active=jnp.zeros((n_e,), dtype=bool), - free_mask = ~enemy_active - any_free = jnp.any(free_mask) - spawn_idx = jnp.argmax(free_mask.astype(jnp.int32)) - - road_w = self.consts.ROAD_RIGHT - self.consts.ROAD_LEFT - lane_w = jnp.array(road_w // self.consts.NUM_LANES, dtype=jnp.int32) - lane_margin = jnp.array((road_w - lane_w * self.consts.NUM_LANES) // 2, dtype=jnp.int32) - lane_centers = ( - jnp.arange(self.consts.NUM_LANES, dtype=jnp.int32) * lane_w - + jnp.array(self.consts.ROAD_LEFT, dtype=jnp.int32) - + lane_margin - + (lane_w // 2) - ) - lane_idx = jax.random.randint(lane_key, (), 0, self.consts.NUM_LANES, dtype=jnp.int32) - spawn_x = (lane_centers[lane_idx] - (self.consts.ENEMY_SIZE[0] // 2)).astype(jnp.int32) - spawn_y = jnp.array(self.consts.SPAWN_Y, dtype=jnp.int32) - - def do_spawn(carry): - ex, ey, ea = carry - ex = ex.at[spawn_idx].set(spawn_x) - ey = ey.at[spawn_idx].set(spawn_y) - ea = ea.at[spawn_idx].set(True) - return ex, ey, ea - - enemy_x, enemy_y, enemy_active = jax.lax.cond( - should_spawn & any_free, - do_spawn, - lambda carry: carry, - operand=(enemy_x, enemy_y, enemy_active), - ) + e_bullet_x=jnp.zeros((n_eb,), dtype=jnp.int32), + e_bullet_y=jnp.zeros((n_eb,), dtype=jnp.int32), + e_bullet_active=jnp.zeros((n_eb,), dtype=bool), - new_scroll = (state.scroll_offset + jnp.array(self.consts.SCROLL_SPEED, dtype=jnp.int32)) % jnp.array( - self.consts.DASH_SPACING, dtype=jnp.int32 + scroll_offset=jnp.array(0, dtype=jnp.int32), + spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + score=jnp.array(0, dtype=jnp.int32), + lives=jnp.array(self.consts.NUM_LIVES, dtype=jnp.int32), + step_count=jnp.array(0, dtype=jnp.int32), + key=key, ) + return self._get_observation(state), state - new_state = state._replace( - enemy_x=enemy_x.astype(jnp.int32), - enemy_y=enemy_y.astype(jnp.int32), - enemy_active=enemy_active, - scroll_offset=new_scroll.astype(jnp.int32), - ) - return new_state, passed_count + # ------------------------------------------------------------------ + # step + # ------------------------------------------------------------------ @partial(jax.jit, static_argnums=(0,)) def step( - self, state: JamesBondState, action: chex.Array + self, + state: JamesBondState, + action: chex.Array, ) -> Tuple[JamesBondObservation, JamesBondState, float, bool, JamesBondInfo]: - new_state_key, step_key = jax.random.split(state.key) - state = state._replace(key=step_key) + # 1. Move player + state = self._step_player(state, action) - prev_state = state + # 2. Fire player bullet + state = self._step_player_fire(state, action) - state = self._player_step(state, action) - state, passed_count = self._enemy_step_and_spawn(state) + # 3. Move existing player bullets + state = self._step_bullets(state) - px = state.player_x - py = jnp.array(self.consts.PLAYER_Y, dtype=jnp.int32) - collisions = _rects_overlap( - px, py, - self.consts.PLAYER_SIZE[0], self.consts.PLAYER_SIZE[1], - state.enemy_x, state.enemy_y, - self.consts.ENEMY_SIZE[0], self.consts.ENEMY_SIZE[1], - ) - collided = jnp.any(state.enemy_active & collisions) + # 4. Spawn + move enemies + state = self._step_enemies(state) + + # 5. Enemy fire + state = self._step_enemy_fire(state) + + # 6. Move enemy bullets + state = self._step_enemy_bullets(state) + + # 7. Collision: player bullets vs enemies + state, kills = self._resolve_bullet_enemy_collisions(state) + + # 8. Collision: player vs enemies / enemy bullets + state, hit = self._resolve_player_hit(state) + + # 9. Score, lives, done + score = (state.score + kills * self.consts.SCORE_PER_KILL).astype(jnp.int32) + lives = (state.lives - hit.astype(jnp.int32)).astype(jnp.int32) + lives = jnp.maximum(lives, 0) state = state._replace( - step_counter=(state.step_counter + jnp.array(1, dtype=jnp.int32)).astype(jnp.int32), - passed_counter=(state.passed_counter + passed_count).astype(jnp.int32), - collided=collided, - key=new_state_key, + score=score, + lives=lives, + step_count=(state.step_count + 1).astype(jnp.int32), + scroll_offset=( + (state.scroll_offset + self.consts.SCROLL_SPEED) % self.consts.WIDTH + ).astype(jnp.int32), ) - done = self._get_done(state) - env_reward = self._get_reward(prev_state, state) - info = self._get_info(state) - obs = self._get_observation(state) - return obs, state, env_reward, done, info + reward = (kills * self.consts.SCORE_PER_KILL).astype(jnp.float32) + done = state.lives <= 0 + obs = self._get_observation(state) + info = self._get_info(state) + return obs, state, reward, done, info + + # ------------------------------------------------------------------ + # Player movement + # ------------------------------------------------------------------ + + def _step_player(self, state: JamesBondState, action: chex.Array) -> JamesBondState: + # Decode directional bits from action + move_up = ( + (action == Action.UP) | (action == Action.UPRIGHT) | + (action == Action.UPLEFT) | (action == Action.UPFIRE) | + (action == Action.UPRIGHTFIRE) | (action == Action.UPLEFTFIRE) + ) + move_down = ( + (action == Action.DOWN) | (action == Action.DOWNRIGHT) | + (action == Action.DOWNLEFT) | (action == Action.DOWNFIRE) | + (action == Action.DOWNRIGHTFIRE) | (action == Action.DOWNLEFTFIRE) + ) + move_right = ( + (action == Action.RIGHT) | (action == Action.UPRIGHT) | + (action == Action.DOWNRIGHT) | (action == Action.RIGHTFIRE) | + (action == Action.UPRIGHTFIRE) | (action == Action.DOWNRIGHTFIRE) + ) + move_left = ( + (action == Action.LEFT) | (action == Action.UPLEFT) | + (action == Action.DOWNLEFT) | (action == Action.LEFTFIRE) | + (action == Action.UPLEFTFIRE) | (action == Action.DOWNLEFTFIRE) + ) - def reset(self, key: chex.PRNGKey = jax.random.PRNGKey(42)) -> Tuple[JamesBondObservation, JamesBondState]: - state_key, _ = jax.random.split(key) + dy = jnp.where(move_up, -self.consts.PLAYER_SPEED_Y, 0) + dy = jnp.where(move_down, self.consts.PLAYER_SPEED_Y, dy) + dx = jnp.where(move_right, self.consts.PLAYER_SPEED_X, 0) + dx = jnp.where(move_left, -self.consts.PLAYER_SPEED_X, dx) + + new_y = jnp.clip( + state.player_y + dy.astype(jnp.int32), + self.consts.PLAYER_MIN_Y, + self.consts.PLAYER_MAX_Y, + ).astype(jnp.int32) + new_x = jnp.clip( + state.player_x + dx.astype(jnp.int32), + self.consts.PLAYER_MIN_X, + self.consts.PLAYER_MAX_X, + ).astype(jnp.int32) + + return state._replace(player_x=new_x, player_y=new_y) + + # ------------------------------------------------------------------ + # Player firing + # ------------------------------------------------------------------ + + def _step_player_fire(self, state: JamesBondState, action: chex.Array) -> JamesBondState: + fired = ( + (action == Action.FIRE) | (action == Action.UPFIRE) | + (action == Action.DOWNFIRE) | (action == Action.RIGHTFIRE) | + (action == Action.LEFTFIRE) | (action == Action.UPRIGHTFIRE) | + (action == Action.UPLEFTFIRE) | (action == Action.DOWNRIGHTFIRE) | + (action == Action.DOWNLEFTFIRE) + ) - start_x = jnp.array((self.consts.ROAD_LEFT + self.consts.ROAD_RIGHT) // 2, dtype=jnp.int32) - start_x = (start_x - (self.consts.PLAYER_SIZE[0] // 2)).astype(jnp.int32) + # Find a free bullet slot + free_mask = ~state.bullet_active + any_free = jnp.any(free_mask) + slot = jnp.argmax(free_mask.astype(jnp.int32)) - n = self.consts.MAX_ENEMIES - state = JamesBondState( - player_x=start_x, - enemy_x=jnp.zeros((n,), dtype=jnp.int32), - enemy_y=jnp.zeros((n,), dtype=jnp.int32), - enemy_active=jnp.zeros((n,), dtype=bool), - scroll_offset=jnp.array(0, dtype=jnp.int32), - step_counter=jnp.array(0, dtype=jnp.int32), - passed_counter=jnp.array(0, dtype=jnp.int32), - collided=jnp.array(False), - key=state_key, - ) - return self._get_observation(state), state + spawn_x = (state.player_x + self.consts.PLAYER_W).astype(jnp.int32) + spawn_y = (state.player_y + self.consts.PLAYER_H // 2 - self.consts.BULLET_H // 2).astype(jnp.int32) - def render(self, state: JamesBondState) -> jnp.ndarray: - return self.renderer.render(state) + should_spawn = fired & any_free - def _get_observation(self, state: JamesBondState) -> JamesBondObservation: - player = EntityPosition( - x=state.player_x.astype(jnp.int32), - y=jnp.array(self.consts.PLAYER_Y, dtype=jnp.int32), - width=jnp.array(self.consts.PLAYER_SIZE[0], dtype=jnp.int32), - height=jnp.array(self.consts.PLAYER_SIZE[1], dtype=jnp.int32), + new_bx = jnp.where( + jnp.arange(self.consts.MAX_BULLETS) == slot, + jnp.where(should_spawn, spawn_x, state.bullet_x), + state.bullet_x, ) - n = self.consts.MAX_ENEMIES - enemies = EntityPositionBatch( - x=state.enemy_x.astype(jnp.int32).reshape((n,)), - y=state.enemy_y.astype(jnp.int32).reshape((n,)), - width=jnp.full((n,), self.consts.ENEMY_SIZE[0], dtype=jnp.int32), - height=jnp.full((n,), self.consts.ENEMY_SIZE[1], dtype=jnp.int32), + new_by = jnp.where( + jnp.arange(self.consts.MAX_BULLETS) == slot, + jnp.where(should_spawn, spawn_y, state.bullet_y), + state.bullet_y, + ) + new_ba = jnp.where( + jnp.arange(self.consts.MAX_BULLETS) == slot, + jnp.where(should_spawn, True, state.bullet_active), + state.bullet_active, ) - return JamesBondObservation(player=player, enemies=enemies) - @partial(jax.jit, static_argnums=(0,)) - def obs_to_flat_array(self, obs: JamesBondObservation) -> jnp.ndarray: - return jnp.concatenate( + return state._replace(bullet_x=new_bx, bullet_y=new_by, bullet_active=new_ba) + + # ------------------------------------------------------------------ + # Move player bullets (to the right) + # ------------------------------------------------------------------ + + def _step_bullets(self, state: JamesBondState) -> JamesBondState: + new_bx = (state.bullet_x + state.bullet_active.astype(jnp.int32) * self.consts.BULLET_SPEED).astype(jnp.int32) + # Deactivate bullets that left the screen + still_on = new_bx < self.consts.WIDTH + new_ba = state.bullet_active & still_on + new_bx = jnp.where(new_ba, new_bx, 0) + new_by = jnp.where(new_ba, state.bullet_y, 0) + return state._replace(bullet_x=new_bx, bullet_y=new_by, bullet_active=new_ba) + + # ------------------------------------------------------------------ + # Spawn and move enemies + # ------------------------------------------------------------------ + + def _step_enemies(self, state: JamesBondState) -> JamesBondState: + # Move active enemies to the left + new_ex = (state.enemy_x - state.enemy_active.astype(jnp.int32) * self.consts.ENEMY_SPEED).astype(jnp.int32) + + # Deactivate enemies that scrolled off the left edge + still_on = new_ex + self.consts.ENEMY_W > 0 + enemy_active = state.enemy_active & still_on + new_ex = jnp.where(enemy_active, new_ex, 0) + new_ey = jnp.where(enemy_active, state.enemy_y, 0) + new_et = jnp.where(enemy_active, state.enemy_type, 0) + + # Countdown spawn timer + new_timer = (state.spawn_timer - 1).astype(jnp.int32) + + # Spawn a new enemy when timer reaches zero + key, type_key, y_key = jax.random.split(state.key, 3) + + # Randomly choose enemy type (0=jet, 1=boat, 2=sub) + etype = jax.random.randint(type_key, (), 0, 3, dtype=jnp.int32) + ey_table = jnp.array( [ - obs.player.x.flatten(), - obs.player.y.flatten(), - obs.player.width.flatten(), - obs.player.height.flatten(), - obs.enemies.x.flatten(), - obs.enemies.y.flatten(), - obs.enemies.width.flatten(), - obs.enemies.height.flatten(), + self.consts.SKY_Y - self.consts.ENEMY_H // 2, + self.consts.SEA_Y - self.consts.ENEMY_H // 2, + self.consts.SUB_Y - self.consts.ENEMY_H // 2, ], - axis=0, + dtype=jnp.int32, ) + ey = ey_table[etype] - def action_space(self) -> spaces.Discrete: - return spaces.Discrete(6) + # Find free enemy slot + free_mask = ~enemy_active + any_free = jnp.any(free_mask) + slot = jnp.argmax(free_mask.astype(jnp.int32)) - def observation_space(self) -> spaces.Space: - n = self.consts.MAX_ENEMIES - return spaces.Dict( - { - "player": spaces.Dict( - { - "x": spaces.Box(low=0, high=self.consts.WIDTH, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=self.consts.WIDTH, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(), dtype=jnp.int32), - } - ), - "enemies": spaces.Dict( - { - "x": spaces.Box(low=0, high=self.consts.WIDTH, shape=(n,), dtype=jnp.int32), - "y": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(n,), dtype=jnp.int32), - "width": spaces.Box(low=0, high=self.consts.WIDTH, shape=(n,), dtype=jnp.int32), - "height": spaces.Box(low=0, high=self.consts.HEIGHT, shape=(n,), dtype=jnp.int32), - } - ), - } + should_spawn = (new_timer <= 0) & any_free + new_timer = jnp.where(new_timer <= 0, self.consts.ENEMY_SPAWN_INTERVAL, new_timer).astype(jnp.int32) + + indices = jnp.arange(self.consts.MAX_ENEMIES) + is_slot = indices == slot + + new_ex = jnp.where(is_slot & should_spawn, self.consts.ENEMY_SPAWN_X, new_ex) + new_ey = jnp.where(is_slot & should_spawn, ey, new_ey) + new_et = jnp.where(is_slot & should_spawn, etype, new_et) + enemy_active = jnp.where(is_slot & should_spawn, True, enemy_active) + + return state._replace( + enemy_x=new_ex.astype(jnp.int32), + enemy_y=new_ey.astype(jnp.int32), + enemy_type=new_et.astype(jnp.int32), + enemy_active=enemy_active, + spawn_timer=new_timer, + key=key, ) - def image_space(self) -> spaces.Box: - return spaces.Box(low=0, high=255, shape=(self.consts.HEIGHT, self.consts.WIDTH, 3), dtype=jnp.uint8) + # ------------------------------------------------------------------ + # Enemies fire at the player + # ------------------------------------------------------------------ + + def _step_enemy_fire(self, state: JamesBondState) -> JamesBondState: + key, roll_key = jax.random.split(state.key) + + rolls = jax.random.uniform(roll_key, (self.consts.MAX_ENEMIES,)) + wants_fire = state.enemy_active & (rolls < self.consts.ENEMY_FIRE_PROB) + + # For each enemy that fires, try to place a bullet in a free slot + free_mask = ~state.e_bullet_active + + def fire_one(carry, i): + bx, by, ba = carry + free_slots = ~ba + any_free = jnp.any(free_slots) + slot = jnp.argmax(free_slots.astype(jnp.int32)) + should = wants_fire[i] & any_free + spawn_bx = state.enemy_x[i] + spawn_by = (state.enemy_y[i] + self.consts.ENEMY_H // 2 - self.consts.E_BULLET_H // 2).astype(jnp.int32) + indices = jnp.arange(self.consts.MAX_E_BULLETS) + bx = jnp.where((indices == slot) & should, spawn_bx, bx) + by = jnp.where((indices == slot) & should, spawn_by, by) + ba = jnp.where((indices == slot) & should, True, ba) + return (bx, by, ba), None + + (new_ebx, new_eby, new_eba), _ = jax.lax.scan( + fire_one, + (state.e_bullet_x, state.e_bullet_y, state.e_bullet_active), + jnp.arange(self.consts.MAX_ENEMIES), + ) + + return state._replace( + e_bullet_x=new_ebx.astype(jnp.int32), + e_bullet_y=new_eby.astype(jnp.int32), + e_bullet_active=new_eba, + key=key, + ) + + # ------------------------------------------------------------------ + # Move enemy bullets (to the left) + # ------------------------------------------------------------------ + + def _step_enemy_bullets(self, state: JamesBondState) -> JamesBondState: + new_ebx = (state.e_bullet_x - state.e_bullet_active.astype(jnp.int32) * self.consts.E_BULLET_SPEED).astype(jnp.int32) + still_on = new_ebx + self.consts.E_BULLET_W > 0 + new_eba = state.e_bullet_active & still_on + new_ebx = jnp.where(new_eba, new_ebx, 0) + new_eby = jnp.where(new_eba, state.e_bullet_y, 0) + return state._replace(e_bullet_x=new_ebx, e_bullet_y=new_eby, e_bullet_active=new_eba) + + # ------------------------------------------------------------------ + # Player bullet vs enemy collisions + # ------------------------------------------------------------------ + + def _resolve_bullet_enemy_collisions( + self, state: JamesBondState + ) -> Tuple[JamesBondState, chex.Array]: + """Returns updated state and the number of kills this frame.""" + + # Build hit matrix: (MAX_BULLETS, MAX_ENEMIES) + bx = state.bullet_x[:, None] # (B, 1) + by = state.bullet_y[:, None] + ba = state.bullet_active[:, None] + + ex = state.enemy_x[None, :] # (1, E) + ey = state.enemy_y[None, :] + ea = state.enemy_active[None, :] + + hit = ( + ba & ea & + _rect_overlap(bx, by, self.consts.BULLET_W, self.consts.BULLET_H, + ex, ey, self.consts.ENEMY_W, self.consts.ENEMY_H) + ) # (B, E) + + # A bullet is consumed if it hit any enemy + bullet_hit = jnp.any(hit, axis=1) # (B,) + # An enemy is killed if hit by any bullet + enemy_hit = jnp.any(hit, axis=0) # (E,) + + kills = enemy_hit.astype(jnp.int32).sum() + + new_ba = state.bullet_active & ~bullet_hit + new_bx = jnp.where(new_ba, state.bullet_x, 0) + new_by = jnp.where(new_ba, state.bullet_y, 0) + + new_ea = state.enemy_active & ~enemy_hit + new_ex = jnp.where(new_ea, state.enemy_x, 0) + new_ey = jnp.where(new_ea, state.enemy_y, 0) + new_et = jnp.where(new_ea, state.enemy_type, 0) + + state = state._replace( + bullet_active=new_ba, bullet_x=new_bx, bullet_y=new_by, + enemy_active=new_ea, enemy_x=new_ex, enemy_y=new_ey, enemy_type=new_et, + ) + return state, kills + + # ------------------------------------------------------------------ + # Player hit detection + # ------------------------------------------------------------------ + + def _resolve_player_hit( + self, state: JamesBondState + ) -> Tuple[JamesBondState, chex.Array]: + """Returns updated state and bool (was player hit this frame?).""" + px, py = state.player_x, state.player_y + pw, ph = self.consts.PLAYER_W, self.consts.PLAYER_H + + # Hit by enemy body + enemy_hit = _rect_overlap( + px, py, pw, ph, + state.enemy_x, state.enemy_y, self.consts.ENEMY_W, self.consts.ENEMY_H, + ) & state.enemy_active # (E,) + + # Hit by enemy bullet + ebullet_hit = _rect_overlap( + px, py, pw, ph, + state.e_bullet_x, state.e_bullet_y, self.consts.E_BULLET_W, self.consts.E_BULLET_H, + ) & state.e_bullet_active # (EB,) + + hit = jnp.any(enemy_hit) | jnp.any(ebullet_hit) + + # Clear the bullet that caused the hit + new_eba = jnp.where(ebullet_hit, False, state.e_bullet_active) + state = state._replace(e_bullet_active=new_eba) + + return state, hit + + # ------------------------------------------------------------------ + # Observations / info + # ------------------------------------------------------------------ + + def _get_observation(self, state: JamesBondState) -> JamesBondObservation: + return JamesBondObservation( + player_x=state.player_x, + player_y=state.player_y, + enemy_x=state.enemy_x, + enemy_y=state.enemy_y, + enemy_type=state.enemy_type, + enemy_active=state.enemy_active, + bullet_x=state.bullet_x, + bullet_y=state.bullet_y, + bullet_active=state.bullet_active, + score=state.score, + lives=state.lives, + ) @partial(jax.jit, static_argnums=(0,)) def _get_info(self, state: JamesBondState) -> JamesBondInfo: return JamesBondInfo( - time=state.step_counter, - passed=state.passed_counter, - collisions=state.collided.astype(jnp.int32), + score=state.score, + lives=state.lives, + step_count=state.step_count, ) @partial(jax.jit, static_argnums=(0,)) - def _get_reward(self, previous_state: JamesBondState, state: JamesBondState) -> float: - passed_delta = (state.passed_counter - previous_state.passed_counter).astype(jnp.int32) - pass_reward = passed_delta.astype(jnp.float32) * jnp.array(self.consts.PASS_REWARD, dtype=jnp.float32) - survive_reward = jnp.array(self.consts.SURVIVE_REWARD, dtype=jnp.float32) * (~state.collided) - return (pass_reward + survive_reward).astype(jnp.float32) + def _get_reward(self, prev: JamesBondState, state: JamesBondState) -> float: + return ((state.score - prev.score)).astype(jnp.float32) @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: JamesBondState) -> bool: - return state.collided + return state.lives <= 0 + + @partial(jax.jit, static_argnums=(0,)) + def obs_to_flat_array(self, obs: JamesBondObservation) -> jnp.ndarray: + return jnp.concatenate([ + obs.player_x.flatten().astype(jnp.float32), + obs.player_y.flatten().astype(jnp.float32), + obs.enemy_x.flatten().astype(jnp.float32), + obs.enemy_y.flatten().astype(jnp.float32), + obs.enemy_type.flatten().astype(jnp.float32), + obs.enemy_active.flatten().astype(jnp.float32), + obs.bullet_x.flatten().astype(jnp.float32), + obs.bullet_y.flatten().astype(jnp.float32), + obs.bullet_active.flatten().astype(jnp.float32), + obs.score.flatten().astype(jnp.float32), + obs.lives.flatten().astype(jnp.float32), + ]) + + def render(self, state: JamesBondState) -> jnp.ndarray: + return self.renderer.render(state) + + # ------------------------------------------------------------------ + # Spaces + # ------------------------------------------------------------------ + + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(18) + + def observation_space(self) -> spaces.Space: + n_e = self.consts.MAX_ENEMIES + n_b = self.consts.MAX_BULLETS + W, H = self.consts.WIDTH, self.consts.HEIGHT + return spaces.Dict({ + "player_x": spaces.Box(0, W, shape=(), dtype=jnp.int32), + "player_y": spaces.Box(0, H, shape=(), dtype=jnp.int32), + "enemy_x": spaces.Box(0, W, shape=(n_e,), dtype=jnp.int32), + "enemy_y": spaces.Box(0, H, shape=(n_e,), dtype=jnp.int32), + "enemy_type": spaces.Box(0, 3, shape=(n_e,), dtype=jnp.int32), + "enemy_active": spaces.Box(0, 1, shape=(n_e,), dtype=bool), + "bullet_x": spaces.Box(0, W, shape=(n_b,), dtype=jnp.int32), + "bullet_y": spaces.Box(0, H, shape=(n_b,), dtype=jnp.int32), + "bullet_active": spaces.Box(0, 1, shape=(n_b,), dtype=bool), + "score": spaces.Box(0, 999999, shape=(), dtype=jnp.int32), + "lives": spaces.Box(0, self.consts.NUM_LIVES, shape=(), dtype=jnp.int32), + }) + + def image_space(self) -> spaces.Box: + return spaces.Box(0, 255, shape=(self.consts.HEIGHT, self.consts.WIDTH, 3), dtype=jnp.uint8) +# --------------------------------------------------------------------------- +# Renderer +# --------------------------------------------------------------------------- + class JamesBondRenderer(JAXGameRenderer): - _ID_GRASS = 0 - _ID_ROAD = 1 - _ID_LINE = 2 - _ID_PLAYER = 3 - _ID_ENEMY = 4 + """ + Simple rectangle-based renderer approximating the ALE visual layout. + + Background zones: + - Black HUD strip at top and bottom + - Sky (light blue) from HUD to horizon + - Sea surface (dark blue) from horizon to seabed + - Underwater (deep blue/green) below seabed + + Sprites are coloured rectangles: + - Player → yellow + - Jet → red + - Boat → orange + - Sub → green + - Player bullet → white + - Enemy bullet → light orange + """ + + # Palette index constants (kept here so NamedTuple fields stay valid identifiers) + _ID_HUD = 0 + _ID_SKY = 1 + _ID_SEA_SURF = 2 + _ID_SEA_DEEP = 3 + _ID_HORIZON = 4 + _ID_SEABED = 5 + _ID_PLAYER = 6 + _ID_JET = 7 + _ID_BOAT = 8 + _ID_SUB = 9 + _ID_BULLET = 10 + _ID_E_BULLET = 11 + + # RGB colours (approximate ALE palette) + _COLOR_HUD = (0, 0, 0) + _COLOR_SKY = (80, 128, 192) + _COLOR_SEA_SURF = (0, 72, 148) + _COLOR_SEA_DEEP = (0, 48, 108) + _COLOR_HORIZON = (200, 220, 180) + _COLOR_SEABED = (120, 100, 60) + _COLOR_PLAYER = (210, 210, 64) + _COLOR_JET = (200, 72, 72) + _COLOR_BOAT = (210, 140, 60) + _COLOR_SUB = (84, 184, 84) + _COLOR_BULLET = (255, 255, 255) + _COLOR_E_BULLET = (255, 180, 100) def __init__(self, consts: JamesBondConstants | None = None): - super().__init__(consts) self.consts = consts or JamesBondConstants() + super().__init__(self.consts) - self.config = render_utils.RendererConfig(game_dimensions=(self.consts.HEIGHT, self.consts.WIDTH), channels=3) + self.config = render_utils.RendererConfig( + game_dimensions=(self.consts.HEIGHT, self.consts.WIDTH), + channels=3, + ) self.jr = render_utils.JaxRenderingUtils(self.config) + # Palette: index → RGB (order must match _ID_* constants above) self.PALETTE = jnp.array( [ - self.consts.GRASS_COLOR, - self.consts.ROAD_COLOR, - self.consts.LINE_COLOR, - self.consts.PLAYER_COLOR, - self.consts.ENEMY_COLOR, + self._COLOR_HUD, # 0 + self._COLOR_SKY, # 1 + self._COLOR_SEA_SURF, # 2 + self._COLOR_SEA_DEEP, # 3 + self._COLOR_HORIZON, # 4 + self._COLOR_SEABED, # 5 + self._COLOR_PLAYER, # 6 + self._COLOR_JET, # 7 + self._COLOR_BOAT, # 8 + self._COLOR_SUB, # 9 + self._COLOR_BULLET, # 10 + self._COLOR_E_BULLET, # 11 ], dtype=jnp.uint8, ) - self.BACKGROUND = jnp.full( - (self.consts.HEIGHT, self.consts.WIDTH), - jnp.array(self._ID_GRASS, dtype=jnp.uint8), - dtype=jnp.uint8, - ) + # Pre-build the static background raster (does not change with game state) + self._STATIC_BG = self._build_background() + + def _build_background(self) -> jnp.ndarray: + """Builds the static zone background as a palette-index raster (H, W).""" + c = self.consts + H, W = c.HEIGHT, c.WIDTH + + # Start with sky colour everywhere + bg = jnp.full((H, W), self._ID_SKY, dtype=jnp.uint8) + + rows = jnp.arange(H)[:, None] # (H, 1) broadcasts over width + + # Top HUD bar + bg = jnp.where(rows < c.HUD_TOP_H, self._ID_HUD, bg) + + # Horizon line (2 px thick) + bg = jnp.where((rows >= c.HORIZON_Y) & (rows < c.HORIZON_Y + 2), self._ID_HORIZON, bg) + + # Sea surface zone (below horizon) + bg = jnp.where(rows >= c.HORIZON_Y + 2, self._ID_SEA_SURF, bg) + + # Seabed transition line (2 px) + bg = jnp.where((rows >= c.SEABED_Y) & (rows < c.SEABED_Y + 2), self._ID_SEABED, bg) + + # Deep water / underwater zone + bg = jnp.where(rows >= c.SEABED_Y + 2, self._ID_SEA_DEEP, bg) + + # Bottom HUD bar + bg = jnp.where(rows >= H - c.HUD_BOTTOM_H, self._ID_HUD, bg) + + return bg.astype(jnp.uint8) @partial(jax.jit, static_argnums=(0,)) def render(self, state: JamesBondState) -> jnp.ndarray: - raster = self.jr.create_object_raster(self.BACKGROUND) + c = self.consts - road_pos = jnp.array([[self.consts.ROAD_LEFT, 0]], dtype=jnp.int32) - road_size = jnp.array([[self.consts.ROAD_RIGHT - self.consts.ROAD_LEFT, self.consts.HEIGHT]], dtype=jnp.int32) - raster = self.jr.draw_rects(raster, road_pos, road_size, self._ID_ROAD) + # --- Background (with animated scroll stripes) --- + raster = self._draw_scroll_stripes(self._STATIC_BG, state.scroll_offset) - edge_pos = jnp.array( - [ - [self.consts.ROAD_LEFT - self.consts.ROAD_EDGE_LINE_W, 0], - [self.consts.ROAD_RIGHT, 0], - ], - dtype=jnp.int32, - ) - edge_size = jnp.array( - [ - [self.consts.ROAD_EDGE_LINE_W, self.consts.HEIGHT], - [self.consts.ROAD_EDGE_LINE_W, self.consts.HEIGHT], - ], - dtype=jnp.int32, + # --- Player --- + raster = self.jr.draw_rects( + raster, + jnp.stack([state.player_x, state.player_y])[None, :], + jnp.array([[c.PLAYER_W, c.PLAYER_H]], dtype=jnp.int32), + self._ID_PLAYER, ) - raster = self.jr.draw_rects(raster, edge_pos, edge_size, self._ID_LINE) - - num_dashes = (self.consts.HEIGHT // self.consts.DASH_SPACING) + 3 - dash_x = jnp.array((self.consts.ROAD_LEFT + self.consts.ROAD_RIGHT) // 2 - (self.consts.DASH_W // 2), dtype=jnp.int32) - ys = ( - (jnp.arange(num_dashes, dtype=jnp.int32) * jnp.array(self.consts.DASH_SPACING, dtype=jnp.int32) + state.scroll_offset) - % jnp.array(self.consts.HEIGHT + self.consts.DASH_SPACING, dtype=jnp.int32) - ) - jnp.array(self.consts.DASH_H, dtype=jnp.int32) - visible = (ys >= 0) & (ys < jnp.array(self.consts.HEIGHT, dtype=jnp.int32)) - dash_xs = jnp.where(visible, dash_x, jnp.array(-1, dtype=jnp.int32)) - dash_pos = jnp.stack([dash_xs, ys], axis=1).astype(jnp.int32) - dash_size = jnp.tile( - jnp.array([[self.consts.DASH_W, self.consts.DASH_H]], dtype=jnp.int32), - (num_dashes, 1), - ) - raster = self.jr.draw_rects(raster, dash_pos, dash_size, self._ID_LINE) - player_pos = jnp.array([[state.player_x, self.consts.PLAYER_Y]], dtype=jnp.int32) - player_size = jnp.array([[self.consts.PLAYER_SIZE[0], self.consts.PLAYER_SIZE[1]]], dtype=jnp.int32) - raster = self.jr.draw_rects(raster, player_pos, player_size, self._ID_PLAYER) - - enemy_x = jnp.where(state.enemy_active, state.enemy_x, jnp.array(-1, dtype=jnp.int32)) - enemy_pos = jnp.stack([enemy_x, state.enemy_y], axis=1).astype(jnp.int32) - enemy_size = jnp.tile( - jnp.array([[self.consts.ENEMY_SIZE[0], self.consts.ENEMY_SIZE[1]]], dtype=jnp.int32), - (self.consts.MAX_ENEMIES, 1), - ) - raster = self.jr.draw_rects(raster, enemy_pos, enemy_size, self._ID_ENEMY) + # --- Enemies (grouped by type to keep color_id a Python int) --- + for etype, color_id in ( + (ENEMY_JET, self._ID_JET), + (ENEMY_BOAT, self._ID_BOAT), + (ENEMY_SUB, self._ID_SUB), + ): + # Collect positions for enemies of this type; hide others off-screen + mask = state.enemy_active & (state.enemy_type == etype) + ex = jnp.where(mask, state.enemy_x, jnp.full_like(state.enemy_x, -c.ENEMY_W)) + ey = state.enemy_y + positions = jnp.stack([ex, ey], axis=1) # (MAX_ENEMIES, 2) + sizes = jnp.tile(jnp.array([[c.ENEMY_W, c.ENEMY_H]], dtype=jnp.int32), + (c.MAX_ENEMIES, 1)) + raster = self.jr.draw_rects(raster, positions, sizes, color_id) + + # --- Player bullets --- + bx = jnp.where(state.bullet_active, state.bullet_x, + jnp.full_like(state.bullet_x, -c.BULLET_W)) + b_positions = jnp.stack([bx, state.bullet_y], axis=1) + b_sizes = jnp.tile(jnp.array([[c.BULLET_W, c.BULLET_H]], dtype=jnp.int32), + (c.MAX_BULLETS, 1)) + raster = self.jr.draw_rects(raster, b_positions, b_sizes, self._ID_BULLET) + + # --- Enemy bullets --- + ebx = jnp.where(state.e_bullet_active, state.e_bullet_x, + jnp.full_like(state.e_bullet_x, -c.E_BULLET_W)) + eb_positions = jnp.stack([ebx, state.e_bullet_y], axis=1) + eb_sizes = jnp.tile(jnp.array([[c.E_BULLET_W, c.E_BULLET_H]], dtype=jnp.int32), + (c.MAX_E_BULLETS, 1)) + raster = self.jr.draw_rects(raster, eb_positions, eb_sizes, self._ID_E_BULLET) return self.jr.render_from_palette(raster, self.PALETTE) - + @partial(jax.jit, static_argnums=(0,)) + def _draw_scroll_stripes( + self, bg: jnp.ndarray, scroll_offset: chex.Array + ) -> jnp.ndarray: + """ + Adds animated scrolling stripes: + - sparse horizontal dashes in the sky zone (simulate clouds/distance marks) + - sparse ripple lines in the sea zone (simulate waves) + """ + c = self.consts + H, W = c.HEIGHT, c.WIDTH + cols = jnp.arange(W)[None, :] # (1, W) + rows = jnp.arange(H)[:, None] # (H, 1) + + shifted_col = (cols + scroll_offset) % (c.BG_STRIPE_W * 2) + + # Sky dashes + in_sky = (rows >= c.HUD_TOP_H) & (rows < c.HORIZON_Y) + cloud_stripe = in_sky & (shifted_col < c.BG_STRIPE_W) & (rows % 10 < 2) + raster = jnp.where(cloud_stripe, jnp.uint8(self._ID_HORIZON), bg) + + # Sea ripples + in_sea = (rows >= c.HORIZON_Y + 2) & (rows < c.SEABED_Y) + wave_stripe = in_sea & (shifted_col < 2) & (rows % 8 == 0) + raster = jnp.where(wave_stripe, jnp.uint8(self._ID_HORIZON), raster) + + return raster.astype(jnp.uint8)